diff --git a/internal/logic/auth/deviceLoginLogic.go b/internal/logic/auth/deviceLoginLogic.go index 9218713..2d71fd5 100644 --- a/internal/logic/auth/deviceLoginLogic.go +++ b/internal/logic/auth/deviceLoginLogic.go @@ -108,6 +108,7 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("CtxLoginType", "device"), jwt.WithOption("LoginType", "device"), ) if err != nil { diff --git a/internal/logic/auth/emailLoginLogic.go b/internal/logic/auth/emailLoginLogic.go index 8312cc2..1cc594f 100644 --- a/internal/logic/auth/emailLoginLogic.go +++ b/internal/logic/auth/emailLoginLogic.go @@ -200,6 +200,7 @@ func (l *EmailLoginLogic) EmailLogin(req *types.EmailLoginRequest) (resp *types. l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("CtxLoginType", req.LoginType), jwt.WithOption("LoginType", req.LoginType), ) if err != nil { diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index d960f6a..80cf27b 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -188,6 +188,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("CtxLoginType", req.LoginType), jwt.WithOption("LoginType", req.LoginType), ) if err != nil { diff --git a/internal/middleware/authMiddleware.go b/internal/middleware/authMiddleware.go index 9007ae9..6310fe6 100644 --- a/internal/middleware/authMiddleware.go +++ b/internal/middleware/authMiddleware.go @@ -41,10 +41,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { return } - loginType := "" - if claims["CtxLoginType"] != nil { - loginType = claims["CtxLoginType"].(string) - } + loginType := parseLoginType(claims) if claims["identifier"] != nil { ctx = context.WithValue(ctx, constant.CtxKeyIdentifier, claims["identifier"].(string)) } @@ -93,3 +90,17 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { c.Next() } } + +func parseLoginType(claims map[string]interface{}) string { + if raw, exists := claims["CtxLoginType"]; exists { + if loginType, ok := raw.(string); ok && loginType != "" { + return loginType + } + } + if raw, exists := claims["LoginType"]; exists { + if loginType, ok := raw.(string); ok && loginType != "" { + return loginType + } + } + return "" +} diff --git a/internal/middleware/authMiddleware_test.go b/internal/middleware/authMiddleware_test.go new file mode 100644 index 0000000..ca6d5cf --- /dev/null +++ b/internal/middleware/authMiddleware_test.go @@ -0,0 +1,46 @@ +package middleware + +import "testing" + +func TestParseLoginType(t *testing.T) { + tests := []struct { + name string + claims map[string]interface{} + want string + }{ + { + name: "prefer CtxLoginType when both exist", + claims: map[string]interface{}{"CtxLoginType": "device", "LoginType": "email"}, + want: "device", + }, + { + name: "fallback to legacy LoginType", + claims: map[string]interface{}{"LoginType": "device"}, + want: "device", + }, + { + name: "ignore non-string values", + claims: map[string]interface{}{"CtxLoginType": 123, "LoginType": true}, + want: "", + }, + { + name: "empty values return empty", + claims: map[string]interface{}{"CtxLoginType": "", "LoginType": ""}, + want: "", + }, + { + name: "missing values return empty", + claims: map[string]interface{}{}, + want: "", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + got := parseLoginType(testCase.claims) + if got != testCase.want { + t.Fatalf("parseLoginType() = %q, want %q", got, testCase.want) + } + }) + } +} diff --git a/internal/middleware/deviceMiddleware.go b/internal/middleware/deviceMiddleware.go index db5c18d..7b7c385 100644 --- a/internal/middleware/deviceMiddleware.go +++ b/internal/middleware/deviceMiddleware.go @@ -24,17 +24,28 @@ import ( const ( noWritten = -1 defaultStatus = http.StatusOK + + ctxDeviceDecryptStatusKey = "device_decrypt_status" + ctxDeviceDecryptReasonKey = "device_decrypt_reason" + ctxEncryptedQueryKey = "encrypted_query" + ctxDecryptedQueryKey = "decrypted_query" + ctxEncryptedBodyKey = "encrypted_request_body" + ctxDecryptedBodyKey = "decrypted_request_body" ) func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) { return func(c *gin.Context) { if !srvCtx.Config.Device.Enable { + c.Set(ctxDeviceDecryptStatusKey, "skipped") + c.Set(ctxDeviceDecryptReasonKey, "device_encryption_disabled") c.Next() return } if srvCtx.Config.Device.SecuritySecret == "" { + c.Set(ctxDeviceDecryptStatusKey, "failed") + c.Set(ctxDeviceDecryptReasonKey, "device_secret_empty") result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.SecretIsEmpty), "Secret is empty")) c.Abort() return @@ -48,12 +59,22 @@ func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) { loginType, ok := ctx.Value(constant.CtxLoginType).(string) if !ok || loginType != "device" { + c.Set(ctxDeviceDecryptStatusKey, "skipped") + if ok { + c.Set(ctxDeviceDecryptReasonKey, fmt.Sprintf("login_type_%s_not_device", loginType)) + } else { + c.Set(ctxDeviceDecryptReasonKey, "login_type_not_found") + } c.Next() return } rw := NewResponseWriter(c, srvCtx) if !rw.Decrypt() { + c.Set(ctxDeviceDecryptStatusKey, "failed") + if _, exists := c.Get(ctxDeviceDecryptReasonKey); !exists { + c.Set(ctxDeviceDecryptReasonKey, "decrypt_failed") + } result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidCiphertext), "Invalid ciphertext")) c.Abort() return @@ -112,13 +133,21 @@ func (rw *ResponseWriter) Encrypt() { func (rw *ResponseWriter) Decrypt() bool { if !rw.encryption { + rw.c.Set(ctxDeviceDecryptStatusKey, "skipped") + rw.c.Set(ctxDeviceDecryptReasonKey, "response_encryption_disabled") return true } //判断url链接中是否存在data和iv数据,存在就进行解密并设置回去 query := rw.c.Request.URL.Query() + originalRawQuery := rw.c.Request.URL.RawQuery dataStr := query.Get("data") timeStr := query.Get("time") + hasEncryptedQuery := dataStr != "" && timeStr != "" + queryDecrypted := false + if hasEncryptedQuery { + rw.c.Set(ctxEncryptedQueryKey, originalRawQuery) + } if dataStr != "" && timeStr != "" { decrypt, err := pkgaes.Decrypt(dataStr, rw.encryptionKey, timeStr) if err == nil { @@ -132,42 +161,73 @@ func (rw *ResponseWriter) Decrypt() bool { query.Del("time") rw.c.Request.RequestURI = fmt.Sprintf("%s?%s", rw.c.Request.RequestURI[:strings.Index(rw.c.Request.RequestURI, "?")], query.Encode()) rw.c.Request.URL.RawQuery = query.Encode() + rw.c.Set(ctxDecryptedQueryKey, query.Encode()) + queryDecrypted = true } + } else { + rw.c.Set(ctxDeviceDecryptReasonKey, fmt.Sprintf("query_decrypt_failed:%v", err)) } } //判断body是否存在数据,存在就尝试解密,并设置回去 body, err := io.ReadAll(rw.c.Request.Body) if err != nil { + if queryDecrypted { + rw.c.Set(ctxDeviceDecryptStatusKey, "success") + } else { + rw.c.Set(ctxDeviceDecryptStatusKey, "skipped") + } + rw.c.Set(ctxDeviceDecryptReasonKey, fmt.Sprintf("read_body_failed:%v", err)) return true } if len(body) == 0 { + if queryDecrypted { + rw.c.Set(ctxDeviceDecryptStatusKey, "success") + } else { + rw.c.Set(ctxDeviceDecryptStatusKey, "skipped") + if hasEncryptedQuery { + rw.c.Set(ctxDeviceDecryptReasonKey, "query_decrypt_failed") + } else { + rw.c.Set(ctxDeviceDecryptReasonKey, "empty_body") + } + } return true } + rw.c.Set(ctxEncryptedBodyKey, string(body)) params := map[string]interface{}{} err = json.Unmarshal(body, ¶ms) data := params["data"] nonce := params["time"] if err != nil || data == nil { + if err != nil { + rw.c.Set(ctxDeviceDecryptReasonKey, fmt.Sprintf("body_unmarshal_failed:%v", err)) + } else { + rw.c.Set(ctxDeviceDecryptReasonKey, "body_data_field_missing") + } return false } str, ok := data.(string) if !ok { + rw.c.Set(ctxDeviceDecryptReasonKey, "body_data_not_string") return false } iv, ok := nonce.(string) if !ok { + rw.c.Set(ctxDeviceDecryptReasonKey, "body_time_not_string") return false } decrypt, err := pkgaes.Decrypt(str, rw.encryptionKey, iv) if err != nil { + rw.c.Set(ctxDeviceDecryptReasonKey, fmt.Sprintf("body_decrypt_failed:%v", err)) return false } rw.c.Request.Body = io.NopCloser(bytes.NewBuffer([]byte(decrypt))) + rw.c.Set(ctxDecryptedBodyKey, decrypt) + rw.c.Set(ctxDeviceDecryptStatusKey, "success") return true } diff --git a/internal/middleware/loggerMiddleware.go b/internal/middleware/loggerMiddleware.go index 7bb2def..a5b5639 100644 --- a/internal/middleware/loggerMiddleware.go +++ b/internal/middleware/loggerMiddleware.go @@ -27,6 +27,7 @@ func (r responseBodyWriter) Write(b []byte) (int, error) { func LoggerMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { return func(c *gin.Context) { + sensitiveFields := []string{"password", "old_password", "new_password"} // get response body w := &responseBodyWriter{body: &bytes.Buffer{}, ResponseWriter: c.Writer} c.Writer = w @@ -79,9 +80,26 @@ func LoggerMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { } logs = append(logs, logger.Field("error", errMessage)) } + if status, ok := c.Get(ctxDeviceDecryptStatusKey); ok { + logs = append(logs, logger.Field("device_decrypt_status", status)) + } + if reason, ok := c.Get(ctxDeviceDecryptReasonKey); ok { + logs = append(logs, logger.Field("device_decrypt_reason", reason)) + } + if encryptedQuery, ok := c.Get(ctxEncryptedQueryKey); ok { + logs = append(logs, logger.Field("encrypted_query", encryptedQuery)) + } + if decryptedQuery, ok := c.Get(ctxDecryptedQueryKey); ok { + logs = append(logs, logger.Field("decrypted_query", decryptedQuery)) + } if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "DELETE" { // request content - logs = append(logs, logger.Field("request_body", string(maskSensitiveFields(requestBody, []string{"password", "old_password", "new_password"})))) + logs = append(logs, logger.Field("request_body", string(maskSensitiveFields(requestBody, sensitiveFields)))) + if decryptedBody, ok := c.Get(ctxDecryptedBodyKey); ok { + if bodyText, isString := decryptedBody.(string); isString { + logs = append(logs, logger.Field("decrypted_request_body", string(maskSensitiveFields([]byte(bodyText), sensitiveFields)))) + } + } // response content logs = append(logs, logger.Field("response_body", w.body.String())) }