From 46e6a9784d01036fcfc16aec2f513d80148225ef Mon Sep 17 00:00:00 2001 From: EUForest Date: Sun, 12 Oct 2025 16:23:29 +0800 Subject: [PATCH] add: User transmission interface encryption --- apis/auth/auth.api | 3 + apis/common.api | 1 + apis/public/announcement.api | 2 +- apis/public/document.api | 2 +- apis/public/order.api | 2 +- apis/public/payment.api | 2 +- apis/public/portal.api | 1 + apis/public/subscribe.api | 2 +- apis/public/ticket.api | 2 +- apis/public/user.api | 2 +- initialize/device.go | 26 ++ initialize/init.go | 1 + internal/config/config.go | 10 + internal/handler/routes.go | 17 +- .../authMethod/updateAuthMethodConfigLogic.go | 3 + internal/logic/auth/deviceLoginLogic.go | 1 + internal/logic/auth/resetPasswordLogic.go | 18 ++ internal/logic/auth/telephoneLoginLogic.go | 3 + .../logic/auth/telephoneResetPasswordLogic.go | 17 ++ .../logic/auth/telephoneUserRegisterLogic.go | 4 + internal/logic/auth/userLoginLogic.go | 3 + internal/logic/auth/userRegisterLogic.go | 4 +- internal/middleware/authMiddleware.go | 3 + internal/middleware/deviceMiddleware.go | 286 ++++++++++++++++++ internal/types/types.go | 14 +- pkg/constant/context.go | 1 + pkg/tool/encryption_test.go | 2 +- 27 files changed, 410 insertions(+), 22 deletions(-) create mode 100644 initialize/device.go create mode 100644 internal/middleware/deviceMiddleware.go diff --git a/apis/auth/auth.api b/apis/auth/auth.api index 9ee085e..68fe700 100644 --- a/apis/auth/auth.api +++ b/apis/auth/auth.api @@ -39,6 +39,7 @@ type ( } // User login response ResetPasswordRequest { + Identifier string `json:"identifier"` Email string `json:"email" validate:"required"` Password string `json:"password" validate:"required"` Code string `json:"code,optional"` @@ -94,6 +95,7 @@ type ( } // User login response TelephoneResetPasswordRequest { + Identifier string `json:"identifier"` Telephone string `json:"telephone" validate:"required"` TelephoneAreaCode string `json:"telephone_area_code" validate:"required"` Password string `json:"password" validate:"required"` @@ -122,6 +124,7 @@ type ( @server ( prefix: v1/auth group: auth + middleware: DeviceMiddleware ) service ppanel { @doc "User login" diff --git a/apis/common.api b/apis/common.api index 6617a17..d246099 100644 --- a/apis/common.api +++ b/apis/common.api @@ -92,6 +92,7 @@ type ( @server ( prefix: v1/common group: common + middleware: DeviceMiddleware ) service ppanel { @doc "Get global config" diff --git a/apis/public/announcement.api b/apis/public/announcement.api index 5afd09b..7122e4e 100644 --- a/apis/public/announcement.api +++ b/apis/public/announcement.api @@ -13,7 +13,7 @@ import "../types.api" @server ( prefix: v1/public/announcement group: public/announcement - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Query announcement" diff --git a/apis/public/document.api b/apis/public/document.api index 4a5e6f9..660bea6 100644 --- a/apis/public/document.api +++ b/apis/public/document.api @@ -13,7 +13,7 @@ import "../types.api" @server ( prefix: v1/public/document group: public/document - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Get document list" diff --git a/apis/public/order.api b/apis/public/order.api index 0db556f..4e83b0f 100644 --- a/apis/public/order.api +++ b/apis/public/order.api @@ -13,7 +13,7 @@ import "../types.api" @server ( prefix: v1/public/order group: public/order - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Pre create order" diff --git a/apis/public/payment.api b/apis/public/payment.api index 4876abd..a4893ab 100644 --- a/apis/public/payment.api +++ b/apis/public/payment.api @@ -13,7 +13,7 @@ import "../types.api" @server ( prefix: v1/public/payment group: public/payment - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Get available payment methods" diff --git a/apis/public/portal.api b/apis/public/portal.api index 33d9948..aba8e25 100644 --- a/apis/public/portal.api +++ b/apis/public/portal.api @@ -70,6 +70,7 @@ type ( @server ( prefix: v1/public/portal group: public/portal + middleware: DeviceMiddleware ) service ppanel { @doc "Get available payment methods" diff --git a/apis/public/subscribe.api b/apis/public/subscribe.api index c5eaa63..13024f8 100644 --- a/apis/public/subscribe.api +++ b/apis/public/subscribe.api @@ -19,7 +19,7 @@ type ( @server ( prefix: v1/public/subscribe group: public/subscribe - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Get subscribe list" diff --git a/apis/public/ticket.api b/apis/public/ticket.api index 0f39304..69bff62 100644 --- a/apis/public/ticket.api +++ b/apis/public/ticket.api @@ -43,7 +43,7 @@ type ( @server ( prefix: v1/public/ticket group: public/ticket - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Get ticket list" diff --git a/apis/public/user.api b/apis/public/user.api index 3236bc5..2547e5b 100644 --- a/apis/public/user.api +++ b/apis/public/user.api @@ -102,7 +102,7 @@ type ( @server ( prefix: v1/public/user group: public/user - middleware: AuthMiddleware + middleware: AuthMiddleware,DeviceMiddleware ) service ppanel { @doc "Query User Info" diff --git a/initialize/device.go b/initialize/device.go new file mode 100644 index 0000000..1b8c527 --- /dev/null +++ b/initialize/device.go @@ -0,0 +1,26 @@ +package initialize + +import ( + "context" + + "github.com/perfect-panel/server/pkg/logger" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/model/auth" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/tool" +) + +func Device(ctx *svc.ServiceContext) { + logger.Debug("device config initialization") + method, err := ctx.AuthModel.FindOneByMethod(context.Background(), "device") + if err != nil { + panic(err) + } + var cfg config.DeviceConfig + var deviceConfig auth.DeviceConfig + deviceConfig.Unmarshal(method.Config) + tool.DeepCopy(&cfg, deviceConfig) + cfg.Enable = *method.Enabled + ctx.Config.Device = cfg +} diff --git a/initialize/init.go b/initialize/init.go index 02ce905..8023ce5 100644 --- a/initialize/init.go +++ b/initialize/init.go @@ -9,6 +9,7 @@ func StartInitSystemConfig(svc *svc.ServiceContext) { Site(svc) Node(svc) Email(svc) + Device(svc) Invite(svc) Verify(svc) Subscribe(svc) diff --git a/internal/config/config.go b/internal/config/config.go index 4d2dde5..59ece74 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/orm" ) @@ -20,6 +21,7 @@ type Config struct { Node NodeConfig `yaml:"Node"` Mobile MobileConfig `yaml:"Mobile"` Email EmailConfig `yaml:"Email"` + Device DeviceConfig `yaml:"device"` Verify Verify `yaml:"Verify"` VerifyCode VerifyCode `yaml:"VerifyCode"` Register RegisterConfig `yaml:"Register"` @@ -95,6 +97,14 @@ type MobileConfig struct { Whitelist []string `yaml:"whitelist"` } +type DeviceConfig struct { + Enable bool `yaml:"enable" default:"true"` + ShowAds bool `yaml:"show_ads"` + EnableSecurity bool `yaml:"enable_security"` + OnlyRealDevice bool `yaml:"only_real_device"` + SecuritySecret string `yaml:"security_secret"` +} + type SiteConfig struct { Host string `yaml:"Host" default:""` SiteName string `yaml:"SiteName" default:""` diff --git a/internal/handler/routes.go b/internal/handler/routes.go index f386b79..7109585 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -578,6 +578,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } authGroupRouter := router.Group("/v1/auth") + authGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) { // Check user is exist @@ -622,6 +623,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } commonGroupRouter := router.Group("/v1/common") + commonGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) { // Get Ads @@ -653,7 +655,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicAnnouncementGroupRouter := router.Group("/v1/public/announcement") - publicAnnouncementGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicAnnouncementGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Query announcement @@ -661,7 +663,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicDocumentGroupRouter := router.Group("/v1/public/document") - publicDocumentGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicDocumentGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Get document detail @@ -672,7 +674,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicOrderGroupRouter := router.Group("/v1/public/order") - publicOrderGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicOrderGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Close order @@ -701,7 +703,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicPaymentGroupRouter := router.Group("/v1/public/payment") - publicPaymentGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicPaymentGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Get available payment methods @@ -709,6 +711,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicPortalGroupRouter := router.Group("/v1/public/portal") + publicPortalGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) { // Purchase Checkout @@ -731,7 +734,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicSubscribeGroupRouter := router.Group("/v1/public/subscribe") - publicSubscribeGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicSubscribeGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Get subscribe list @@ -739,7 +742,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicTicketGroupRouter := router.Group("/v1/public/ticket") - publicTicketGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicTicketGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Update ticket status @@ -759,7 +762,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { } publicUserGroupRouter := router.Group("/v1/public/user") - publicUserGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + publicUserGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { // Query User Affiliate Count diff --git a/internal/logic/admin/authMethod/updateAuthMethodConfigLogic.go b/internal/logic/admin/authMethod/updateAuthMethodConfigLogic.go index c20a45f..d61e38f 100644 --- a/internal/logic/admin/authMethod/updateAuthMethodConfigLogic.go +++ b/internal/logic/admin/authMethod/updateAuthMethodConfigLogic.go @@ -92,6 +92,9 @@ func (l *UpdateAuthMethodConfigLogic) UpdateGlobal(method string) { if method == "mobile" { initialize.Mobile(l.svcCtx) } + if method == "device" { + initialize.Device(l.svcCtx) + } } func validatePlatformConfig(platform string, cfg map[string]interface{}) (interface{}, error) { diff --git a/internal/logic/auth/deviceLoginLogic.go b/internal/logic/auth/deviceLoginLogic.go index 1901e78..8f7807a 100644 --- a/internal/logic/auth/deviceLoginLogic.go +++ b/internal/logic/auth/deviceLoginLogic.go @@ -107,6 +107,7 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", "device"), ) if err != nil { l.Errorw("token generate error", diff --git a/internal/logic/auth/resetPasswordLogic.go b/internal/logic/auth/resetPasswordLogic.go index d0d3f2f..40a1ba4 100644 --- a/internal/logic/auth/resetPasswordLogic.go +++ b/internal/logic/auth/resetPasswordLogic.go @@ -107,6 +107,23 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res if err := l.svcCtx.UserModel.Update(l.ctx, userInfo); err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update user info failed: %v", err.Error()) } + + loginType := "pc" + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail register if device binding fails, just log the error + } + loginType = "mobile" + } + // Generate session id sessionId := uuidx.NewUUID().String() // Generate token @@ -116,6 +133,7 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/logic/auth/telephoneLoginLogic.go b/internal/logic/auth/telephoneLoginLogic.go index 4a0fc48..4b630b9 100644 --- a/internal/logic/auth/telephoneLoginLogic.go +++ b/internal/logic/auth/telephoneLoginLogic.go @@ -124,6 +124,7 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r l.svcCtx.Redis.Del(l.ctx, cacheKey) } + loginType := "pc" // Bind device to user if identifier is provided if req.Identifier != "" { bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) @@ -135,6 +136,7 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r ) // Don't fail login if device binding fails, just log the error } + loginType = "device" } // Generate session id @@ -146,6 +148,7 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/logic/auth/telephoneResetPasswordLogic.go b/internal/logic/auth/telephoneResetPasswordLogic.go index 18891b0..7f5d928 100644 --- a/internal/logic/auth/telephoneResetPasswordLogic.go +++ b/internal/logic/auth/telephoneResetPasswordLogic.go @@ -83,6 +83,22 @@ func (l *TelephoneResetPasswordLogic) TelephoneResetPassword(req *types.Telephon return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "update user password failed: %v", err.Error()) } + loginType := "pc" + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail register if device binding fails, just log the error + } + loginType = "mobile" + } + // Generate session id sessionId := uuidx.NewUUID().String() // Generate token @@ -92,6 +108,7 @@ func (l *TelephoneResetPasswordLogic) TelephoneResetPassword(req *types.Telephon l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Errorw("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/logic/auth/telephoneUserRegisterLogic.go b/internal/logic/auth/telephoneUserRegisterLogic.go index c28552e..1952990 100644 --- a/internal/logic/auth/telephoneUserRegisterLogic.go +++ b/internal/logic/auth/telephoneUserRegisterLogic.go @@ -139,6 +139,8 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR return nil }) + loginType := "pc" + // Bind device to user if identifier is provided if req.Identifier != "" { bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) @@ -150,6 +152,7 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR ) // Don't fail register if device binding fails, just log the error } + loginType = "mobile" } // Generate session id @@ -161,6 +164,7 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/logic/auth/userLoginLogic.go b/internal/logic/auth/userLoginLogic.go index f5e0f9d..6c53f36 100644 --- a/internal/logic/auth/userLoginLogic.go +++ b/internal/logic/auth/userLoginLogic.go @@ -79,6 +79,7 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log if !tool.VerifyPassWord(req.Password, userInfo.Password) { return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserPasswordError), "user password") } + loginType := "pc" // Bind device to user if identifier is provided if req.Identifier != "" { @@ -91,6 +92,7 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log ) // Don't fail login if device binding fails, just log the error } + loginType = "device" } // Generate session id @@ -102,6 +104,7 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index 8147362..34b2646 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -125,7 +125,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } return nil }) - + loginType := "pc" // Bind device to user if identifier is provided if req.Identifier != "" { bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) @@ -137,6 +137,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * ) // Don't fail register if device binding fails, just log the error } + loginType = "device" } // Generate session id @@ -148,6 +149,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * l.svcCtx.Config.JwtAuth.AccessExpire, jwt.WithOption("UserId", userInfo.Id), jwt.WithOption("SessionId", sessionId), + jwt.WithOption("LoginType", loginType), ) if err != nil { l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) diff --git a/internal/middleware/authMiddleware.go b/internal/middleware/authMiddleware.go index a76d0ee..7c86865 100644 --- a/internal/middleware/authMiddleware.go +++ b/internal/middleware/authMiddleware.go @@ -40,6 +40,8 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { c.Abort() return } + + loginType := claims["LoginType"].(string) // get user id from token userId := int64(claims["UserId"].(float64)) // get session id from token @@ -77,6 +79,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { c.Abort() return } + ctx = context.WithValue(ctx, constant.LoginType, loginType) ctx = context.WithValue(ctx, constant.CtxKeyUser, userInfo) ctx = context.WithValue(ctx, constant.CtxKeySessionID, sessionId) c.Request = c.Request.WithContext(ctx) diff --git a/internal/middleware/deviceMiddleware.go b/internal/middleware/deviceMiddleware.go new file mode 100644 index 0000000..a2d6805 --- /dev/null +++ b/internal/middleware/deviceMiddleware.go @@ -0,0 +1,286 @@ +package middleware + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" + + "github.com/perfect-panel/server/internal/svc" + pkgaes "github.com/perfect-panel/server/pkg/aes" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/result" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + + "github.com/gin-gonic/gin" +) + +const ( + noWritten = -1 + defaultStatus = http.StatusOK +) + +func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + loginType := c.GetString(string(constant.LoginType)) + if loginType == "" { + loginType = c.GetHeader("Login-Type") + } + + if loginType != "device" { + c.Next() + return + } + + if !srvCtx.Config.Device.Enable || srvCtx.Config.Device.SecuritySecret == "" { + c.Next() + return + } + rw := NewResponseWriter(c, srvCtx) + if !rw.Decrypt() { + result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidCiphertext), "Invalid ciphertext")) + c.Abort() + return + } + c.Writer = rw + c.Next() + rw.FlushAbort() + } +} + +func NewResponseWriter(c *gin.Context, srvCtx *svc.ServiceContext) (rw *ResponseWriter) { + rw = &ResponseWriter{ + c: c, + body: new(bytes.Buffer), + ResponseWriter: c.Writer, + } + rw.encryptionKey = srvCtx.Config.Device.SecuritySecret + rw.encryptionMethod = "AES" + rw.encryption = true + return rw +} + +func (rw *ResponseWriter) Encrypt() { + if !rw.encryption { + return + } + buf := rw.body.Bytes() + params := map[string]interface{}{} + err := json.Unmarshal(buf, ¶ms) + if err != nil { + return + } + data := params["data"] + if data != nil { + var jsonData []byte + str, ok := data.(string) + if ok { + jsonData = []byte(str) + } else { + jsonData, _ = json.Marshal(data) + } + encrypt, iv, err := pkgaes.Encrypt(jsonData, rw.encryptionKey) + if err != nil { + return + } + params["data"] = map[string]interface{}{ + "data": encrypt, + "time": iv, + } + + } + marshal, _ := json.Marshal(params) + rw.body.Reset() + rw.body.Write(marshal) +} + +func (rw *ResponseWriter) Decrypt() bool { + if !rw.encryption { + return true + } + + //判断url链接中是否存在data和iv数据,存在就进行解密并设置回去 + query := rw.c.Request.URL.Query() + dataStr := query.Get("data") + timeStr := query.Get("time") + if dataStr != "" && timeStr != "" { + decrypt, err := pkgaes.Decrypt(dataStr, rw.encryptionKey, timeStr) + if err == nil { + params := map[string]interface{}{} + err = json.Unmarshal([]byte(decrypt), ¶ms) + if err == nil { + for k, v := range params { + query.Set(k, fmt.Sprintf("%v", v)) + } + query.Del("data") + query.Del("time") + rw.c.Request.RequestURI = fmt.Sprintf("%s?%s", rw.c.Request.RequestURI[:strings.Index(rw.c.Request.RequestURI, "?")], query.Encode()) + rw.c.Request.URL.RawQuery = query.Encode() + } + } + } else { + return false + } + + //判断body是否存在数据,存在就尝试解密,并设置回去 + body, err := io.ReadAll(rw.c.Request.Body) + if err != nil { + return true + } + + if len(body) == 0 { + return true + } + + params := map[string]interface{}{} + err = json.Unmarshal(body, ¶ms) + data := params["data"] + nonce := params["time"] + if err != nil || data == nil { + return false + } + + str, ok := data.(string) + if !ok { + return false + } + iv, ok := nonce.(string) + if !ok { + return false + } + + decrypt, err := pkgaes.Decrypt(str, rw.encryptionKey, iv) + if err != nil { + return false + } + rw.c.Request.Body = io.NopCloser(bytes.NewBuffer([]byte(decrypt))) + return true +} + +func (rw *ResponseWriter) FlushAbort() { + defer rw.c.Abort() + responseBody := rw.body.String() + fmt.Println("Original Response Body:", responseBody) + rw.flush = true + if rw.encryption { + rw.Encrypt() + } + _, err := rw.Write(rw.body.Bytes()) + if err != nil { + return + } +} + +type ResponseWriter struct { + http.ResponseWriter + size int + status int + flush bool + body *bytes.Buffer + c *gin.Context + encryption bool + encryptionKey string + encryptionMethod string +} + +func (rw *ResponseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +//nolint:unused +func (rw *ResponseWriter) reset(writer http.ResponseWriter) { + rw.ResponseWriter = writer + rw.size = noWritten + rw.status = defaultStatus +} + +func (rw *ResponseWriter) WriteHeader(code int) { + if code > 0 && rw.status != code { + if rw.Written() { + return + } + rw.status = code + } +} + +func (rw *ResponseWriter) WriteHeaderNow() { + if !rw.Written() { + rw.size = 0 + rw.ResponseWriter.WriteHeader(rw.status) + } +} + +func (rw *ResponseWriter) Write(data []byte) (n int, err error) { + if rw.flush { + rw.WriteHeaderNow() + n, err = rw.ResponseWriter.Write(data) + rw.size += n + } else { + rw.body.Write(data) + } + return +} + +func (rw *ResponseWriter) WriteString(s string) (n int, err error) { + if rw.flush { + rw.WriteHeaderNow() + n, err = rw.ResponseWriter.Write([]byte(s)) + rw.size += n + } else { + rw.body.Write([]byte(s)) + } + return +} + +func (rw *ResponseWriter) Status() int { + return rw.status +} + +func (rw *ResponseWriter) Size() int { + return rw.size +} + +func (rw *ResponseWriter) Written() bool { + return rw.size != noWritten +} + +// Hijack implements the http.Hijacker interface. +func (rw *ResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if rw.size < 0 { + rw.size = 0 + } + return rw.ResponseWriter.(http.Hijacker).Hijack() +} + +// CloseNotify implements the http.CloseNotifier interface. +func (rw *ResponseWriter) CloseNotify() <-chan bool { + // 通过 r.Context().Done() 来监听请求的取消 + done := rw.c.Request.Context().Done() + closed := make(chan bool) + + // 当上下文被取消时,通过 closed channel 发送通知 + go func() { + <-done + closed <- true + }() + + return closed +} + +// Flush implements the http.Flusher interface. +func (rw *ResponseWriter) Flush() { + rw.WriteHeaderNow() + rw.ResponseWriter.(http.Flusher).Flush() +} + +func (rw *ResponseWriter) Pusher() (pusher http.Pusher) { + if pusher, ok := rw.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/internal/types/types.go b/internal/types/types.go index 0bdb30d..14e3120 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1752,12 +1752,13 @@ type RenewalOrderResponse struct { } type ResetPasswordRequest struct { - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + CfToken string `json:"cf_token,optional"` } type ResetSortRequest struct { @@ -2128,6 +2129,7 @@ type TelephoneRegisterRequest struct { } type TelephoneResetPasswordRequest struct { + Identifier string `json:"identifier"` Telephone string `json:"telephone" validate:"required"` TelephoneAreaCode string `json:"telephone_area_code" validate:"required"` Password string `json:"password" validate:"required"` diff --git a/pkg/constant/context.go b/pkg/constant/context.go index 4023cd1..45c7f86 100644 --- a/pkg/constant/context.go +++ b/pkg/constant/context.go @@ -8,4 +8,5 @@ const ( CtxKeyRequestHost CtxKey = "requestHost" CtxKeyPlatform CtxKey = "platform" CtxKeyPayment CtxKey = "payment" + LoginType CtxKey = "loginType" ) diff --git a/pkg/tool/encryption_test.go b/pkg/tool/encryption_test.go index 45e0a18..8841072 100644 --- a/pkg/tool/encryption_test.go +++ b/pkg/tool/encryption_test.go @@ -3,5 +3,5 @@ package tool import "testing" func TestEncodePassWord(t *testing.T) { - t.Logf("EncodePassWord: %v", EncodePassWord("")) + t.Logf("EncodePassWord: %v", EncodePassWord("password")) }