From a01570b59d589466cadf5e5e2474bf68ca432bc0 Mon Sep 17 00:00:00 2001 From: shanshanzhong Date: Wed, 4 Mar 2026 06:33:14 -0800 Subject: [PATCH] fix gitea workflow path and runner label --- internal/config/config.go | 23 +- .../handler/auth/checkCodeLegacyHandler.go | 50 ++++ .../auth/checkCodeLegacyHandler_test.go | 167 +++++++++++ .../common/checkverificationcodehandler.go | 11 +- .../checkverificationcodehandler_test.go | 143 ++++++++++ internal/handler/routes.go | 3 + .../common/checkverificationcodelogic.go | 50 +--- .../common/checkverificationcodelogic_test.go | 259 ++++++++++++++++++ internal/logic/common/sendEmailCodeLogic.go | 2 + internal/logic/common/verifyCodeChecker.go | 227 +++++++++++++++ internal/middleware/apiVersionMiddleware.go | 25 ++ internal/middleware/corsMiddleware.go | 2 +- internal/middleware/loggerMiddleware.go | 3 + internal/model/auth/auth.go | 34 ++- internal/server.go | 2 +- internal/types/types.go | 14 + pkg/apiversion/version.go | 83 ++++++ pkg/apiversion/version_test.go | 55 ++++ pkg/constant/context.go | 22 +- queue/logic/email/sendEmailLogic.go | 41 ++- queue/logic/email/sendEmailLogic_test.go | 28 ++ queue/types/email.go | 1 + 22 files changed, 1153 insertions(+), 92 deletions(-) create mode 100644 internal/handler/auth/checkCodeLegacyHandler.go create mode 100644 internal/handler/auth/checkCodeLegacyHandler_test.go create mode 100644 internal/handler/common/checkverificationcodehandler_test.go create mode 100644 internal/logic/common/checkverificationcodelogic_test.go create mode 100644 internal/logic/common/verifyCodeChecker.go create mode 100644 internal/middleware/apiVersionMiddleware.go create mode 100644 pkg/apiversion/version.go create mode 100644 pkg/apiversion/version_test.go create mode 100644 queue/logic/email/sendEmailLogic_test.go diff --git a/internal/config/config.go b/internal/config/config.go index a1de61e..eb7cc06 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -91,17 +91,18 @@ type RegisterConfig struct { } type EmailConfig struct { - Enable bool `yaml:"Enable" default:"true"` - Platform string `yaml:"platform"` - PlatformConfig string `yaml:"platform_config"` - EnableVerify bool `yaml:"enable_verify"` - EnableNotify bool `yaml:"enable_notify"` - EnableDomainSuffix bool `yaml:"enable_domain_suffix"` - DomainSuffixList string `yaml:"domain_suffix_list"` - VerifyEmailTemplate string `yaml:"verify_email_template"` - ExpirationEmailTemplate string `yaml:"expiration_email_template"` - MaintenanceEmailTemplate string `yaml:"maintenance_email_template"` - TrafficExceedEmailTemplate string `yaml:"traffic_exceed_email_template"` + Enable bool `yaml:"Enable" default:"true"` + Platform string `yaml:"platform"` + PlatformConfig string `yaml:"platform_config"` + EnableVerify bool `yaml:"enable_verify"` + EnableNotify bool `yaml:"enable_notify"` + EnableDomainSuffix bool `yaml:"enable_domain_suffix"` + DomainSuffixList string `yaml:"domain_suffix_list"` + VerifyEmailTemplate string `yaml:"verify_email_template"` + VerifyEmailTemplates map[string]string `yaml:"verify_email_templates"` + ExpirationEmailTemplate string `yaml:"expiration_email_template"` + MaintenanceEmailTemplate string `yaml:"maintenance_email_template"` + TrafficExceedEmailTemplate string `yaml:"traffic_exceed_email_template"` } type MobileConfig struct { diff --git a/internal/handler/auth/checkCodeLegacyHandler.go b/internal/handler/auth/checkCodeLegacyHandler.go new file mode 100644 index 0000000..87db15c --- /dev/null +++ b/internal/handler/auth/checkCodeLegacyHandler.go @@ -0,0 +1,50 @@ +package auth + +import ( + "github.com/gin-gonic/gin" + commonLogic "github.com/perfect-panel/server/internal/logic/common" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/result" +) + +// Check legacy verification code +func CheckCodeLegacyHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.LegacyCheckVerificationCodeRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + normalizedReq, legacyType3Mapped, err := commonLogic.NormalizeLegacyCheckVerificationCodeRequest(&req) + if err != nil { + result.ParamErrorResult(c, err) + return + } + + l := commonLogic.NewCheckVerificationCodeLogic(c.Request.Context(), svcCtx) + useLatest := false + if value, ok := c.Request.Context().Value(constant.CtxKeyAPIVersionUseLatest).(bool); ok { + useLatest = value + } + + resp, err := l.CheckVerificationCodeWithBehavior(normalizedReq, commonLogic.VerifyCodeCheckBehavior{ + Source: "legacy", + Consume: useLatest, + LegacyType3Mapped: legacyType3Mapped, + AllowSceneFallback: constant.ParseVerifyType(normalizedReq.Type) != constant.DeleteAccount, + }) + + legacyResp := &types.LegacyCheckVerificationCodeResponse{} + if resp != nil { + legacyResp.Status = resp.Status + legacyResp.Exist = resp.Status + } + + result.HttpResult(c, legacyResp, err) + } +} diff --git a/internal/handler/auth/checkCodeLegacyHandler_test.go b/internal/handler/auth/checkCodeLegacyHandler_test.go new file mode 100644 index 0000000..c6c0b0c --- /dev/null +++ b/internal/handler/auth/checkCodeLegacyHandler_test.go @@ -0,0 +1,167 @@ +package auth + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/middleware" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/constant" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type legacyCheckCodeResponse struct { + Code uint32 `json:"code"` + Data struct { + Status bool `json:"status"` + Exist bool `json:"exist"` + } `json:"data"` +} + +func newLegacyCheckCodeTestRouter(svcCtx *svc.ServiceContext) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(middleware.ApiVersionMiddleware(svcCtx)) + router.POST("/v1/auth/check-code", CheckCodeLegacyHandler(svcCtx)) + return router +} + +func newLegacyCheckCodeTestSvcCtx(t *testing.T) (*svc.ServiceContext, *redis.Client) { + t.Helper() + + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + return svcCtx, redisClient +} + +func seedLegacyVerifyCode(t *testing.T, redisClient *redis.Client, scene string, email string, code string) string { + t.Helper() + + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, email) + payload := map[string]interface{}{ + "code": code, + "lastAt": time.Now().Unix(), + } + payloadRaw, err := json.Marshal(payload) + require.NoError(t, err) + err = redisClient.Set(context.Background(), cacheKey, payloadRaw, time.Minute*15).Err() + require.NoError(t, err) + return cacheKey +} + +func callLegacyCheckCode(t *testing.T, router *gin.Engine, apiHeader string, body string) legacyCheckCodeResponse { + t.Helper() + + reqBody := bytes.NewBufferString(body) + req := httptest.NewRequest(http.MethodPost, "/v1/auth/check-code", reqBody) + req.Header.Set("Content-Type", "application/json") + if apiHeader != "" { + req.Header.Set("api-header", apiHeader) + } + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + var resp legacyCheckCodeResponse + err := json.Unmarshal(recorder.Body.Bytes(), &resp) + require.NoError(t, err) + return resp +} + +func TestCheckCodeLegacyHandler_NoHeaderNotConsumed(t *testing.T) { + svcCtx, redisClient := newLegacyCheckCodeTestSvcCtx(t) + router := newLegacyCheckCodeTestRouter(svcCtx) + + email := "legacy@example.com" + code := "123456" + cacheKey := seedLegacyVerifyCode(t, redisClient, constant.Security.String(), email, code) + + resp := callLegacyCheckCode(t, router, "", `{"email":"legacy@example.com","code":"123456","type":3}`) + assert.Equal(t, uint32(200), resp.Code) + assert.True(t, resp.Data.Status) + assert.True(t, resp.Data.Exist) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) +} + +func TestCheckCodeLegacyHandler_GreaterVersionConsumed(t *testing.T) { + svcCtx, redisClient := newLegacyCheckCodeTestSvcCtx(t) + router := newLegacyCheckCodeTestRouter(svcCtx) + + email := "latest@example.com" + code := "999888" + cacheKey := seedLegacyVerifyCode(t, redisClient, constant.Security.String(), email, code) + + resp := callLegacyCheckCode(t, router, "1.0.1", `{"email":"latest@example.com","code":"999888","type":3}`) + assert.Equal(t, uint32(200), resp.Code) + assert.True(t, resp.Data.Status) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(0), exists) + + resp = callLegacyCheckCode(t, router, "1.0.1", `{"email":"latest@example.com","code":"999888","type":3}`) + assert.Equal(t, uint32(200), resp.Code) + assert.False(t, resp.Data.Status) + assert.False(t, resp.Data.Exist) +} + +func TestCheckCodeLegacyHandler_EqualThresholdNotConsumed(t *testing.T) { + svcCtx, redisClient := newLegacyCheckCodeTestSvcCtx(t) + router := newLegacyCheckCodeTestRouter(svcCtx) + + email := "equal@example.com" + code := "112233" + cacheKey := seedLegacyVerifyCode(t, redisClient, constant.Security.String(), email, code) + + resp := callLegacyCheckCode(t, router, "1.0.0", `{"email":"equal@example.com","code":"112233","type":3}`) + assert.Equal(t, uint32(200), resp.Code) + assert.True(t, resp.Data.Status) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) +} + +func TestCheckCodeLegacyHandler_InvalidVersionNotConsumed(t *testing.T) { + svcCtx, redisClient := newLegacyCheckCodeTestSvcCtx(t) + router := newLegacyCheckCodeTestRouter(svcCtx) + + email := "invalid@example.com" + code := "445566" + cacheKey := seedLegacyVerifyCode(t, redisClient, constant.Security.String(), email, code) + + resp := callLegacyCheckCode(t, router, "abc", `{"email":"invalid@example.com","code":"445566","type":3}`) + assert.Equal(t, uint32(200), resp.Code) + assert.True(t, resp.Data.Status) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) +} diff --git a/internal/handler/common/checkverificationcodehandler.go b/internal/handler/common/checkverificationcodehandler.go index 70c743f..1d2ffee 100644 --- a/internal/handler/common/checkverificationcodehandler.go +++ b/internal/handler/common/checkverificationcodehandler.go @@ -5,6 +5,7 @@ import ( "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/result" ) @@ -20,7 +21,15 @@ func CheckVerificationCodeHandler(svcCtx *svc.ServiceContext) func(c *gin.Contex } l := common.NewCheckVerificationCodeLogic(c.Request.Context(), svcCtx) - resp, err := l.CheckVerificationCode(&req) + useLatest := false + if value, ok := c.Request.Context().Value(constant.CtxKeyAPIVersionUseLatest).(bool); ok { + useLatest = value + } + + resp, err := l.CheckVerificationCodeWithBehavior(&req, common.VerifyCodeCheckBehavior{ + Source: "canonical", + Consume: useLatest, + }) result.HttpResult(c, resp, err) } } diff --git a/internal/handler/common/checkverificationcodehandler_test.go b/internal/handler/common/checkverificationcodehandler_test.go new file mode 100644 index 0000000..721e0b5 --- /dev/null +++ b/internal/handler/common/checkverificationcodehandler_test.go @@ -0,0 +1,143 @@ +package common + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/middleware" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/authmethod" + "github.com/perfect-panel/server/pkg/constant" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type canonicalCheckCodeResponse struct { + Code uint32 `json:"code"` + Data struct { + Status bool `json:"status"` + Exist bool `json:"exist"` + } `json:"data"` +} + +func newCanonicalCheckCodeTestSvcCtx(t *testing.T) (*svc.ServiceContext, *redis.Client) { + t.Helper() + + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + return svcCtx, redisClient +} + +func newCanonicalCheckCodeTestRouter(svcCtx *svc.ServiceContext) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(middleware.ApiVersionMiddleware(svcCtx)) + router.POST("/v1/common/check_verification_code", CheckVerificationCodeHandler(svcCtx)) + return router +} + +func seedCanonicalVerifyCode(t *testing.T, redisClient *redis.Client, scene string, account string, code string) string { + t.Helper() + + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, account) + payload := map[string]interface{}{ + "code": code, + "lastAt": time.Now().Unix(), + } + payloadRaw, err := json.Marshal(payload) + require.NoError(t, err) + err = redisClient.Set(context.Background(), cacheKey, payloadRaw, time.Minute*15).Err() + require.NoError(t, err) + return cacheKey +} + +func callCanonicalCheckCode(t *testing.T, router *gin.Engine, apiHeader string, body string) canonicalCheckCodeResponse { + t.Helper() + + reqBody := bytes.NewBufferString(body) + req := httptest.NewRequest(http.MethodPost, "/v1/common/check_verification_code", reqBody) + req.Header.Set("Content-Type", "application/json") + if apiHeader != "" { + req.Header.Set("api-header", apiHeader) + } + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + var resp canonicalCheckCodeResponse + err := json.Unmarshal(recorder.Body.Bytes(), &resp) + require.NoError(t, err) + return resp +} + +func TestCheckVerificationCodeHandler_ApiHeaderGate(t *testing.T) { + tests := []struct { + name string + apiHeader string + expectConsume bool + }{ + {name: "no header", apiHeader: "", expectConsume: false}, + {name: "invalid header", apiHeader: "invalid", expectConsume: false}, + {name: "equal threshold", apiHeader: "1.0.0", expectConsume: false}, + {name: "greater threshold", apiHeader: "1.0.1", expectConsume: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svcCtx, redisClient := newCanonicalCheckCodeTestSvcCtx(t) + router := newCanonicalCheckCodeTestRouter(svcCtx) + + account := "header-gate@example.com" + code := "123123" + cacheKey := seedCanonicalVerifyCode(t, redisClient, constant.Register.String(), account, code) + body := fmt.Sprintf(`{"method":"%s","account":"%s","code":"%s","type":%d}`, + authmethod.Email, + account, + code, + constant.Register, + ) + + resp := callCanonicalCheckCode(t, router, tt.apiHeader, body) + assert.Equal(t, uint32(200), resp.Code) + assert.True(t, resp.Data.Status) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + if tt.expectConsume { + assert.Equal(t, int64(0), exists) + } else { + assert.Equal(t, int64(1), exists) + } + + resp = callCanonicalCheckCode(t, router, tt.apiHeader, body) + if tt.expectConsume { + assert.False(t, resp.Data.Status) + } else { + assert.True(t, resp.Data.Status) + } + }) + } +} diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 1317e33..3d18628 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -635,6 +635,9 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { // Check user telephone is exist authGroupRouter.GET("/check/telephone", auth.CheckUserTelephoneHandler(serverCtx)) + // Check legacy verification code + authGroupRouter.POST("/check-code", auth.CheckCodeLegacyHandler(serverCtx)) + // User login authGroupRouter.POST("/login", auth.UserLoginHandler(serverCtx)) diff --git a/internal/logic/common/checkverificationcodelogic.go b/internal/logic/common/checkverificationcodelogic.go index e0fb6cb..1ed2f7c 100644 --- a/internal/logic/common/checkverificationcodelogic.go +++ b/internal/logic/common/checkverificationcodelogic.go @@ -2,19 +2,9 @@ package common import ( "context" - "encoding/json" - "fmt" - "strings" - - "github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" - "github.com/perfect-panel/server/pkg/authmethod" - "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" - "github.com/perfect-panel/server/pkg/phone" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) type CheckVerificationCodeLogic struct { @@ -33,40 +23,8 @@ func NewCheckVerificationCodeLogic(ctx context.Context, svcCtx *svc.ServiceConte } func (l *CheckVerificationCodeLogic) CheckVerificationCode(req *types.CheckVerificationCodeRequest) (resp *types.CheckVerificationCodeRespone, err error) { - resp = &types.CheckVerificationCodeRespone{} - req.Account = strings.ToLower(strings.TrimSpace(req.Account)) - if req.Method == authmethod.Email { - cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.ParseVerifyType(req.Type), req.Account) - value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result() - if err != nil { - return resp, nil - } - var payload CacheKeyPayload - if err := json.Unmarshal([]byte(value), &payload); err != nil { - return resp, nil - } - if payload.Code != req.Code { - return resp, nil - } - resp.Status = true - } - if req.Method == authmethod.Mobile { - if !phone.CheckPhone(req.Account) { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.TelephoneError), "Invalid phone number") - } - cacheKey := fmt.Sprintf("%s:%s:+%s", config.AuthCodeTelephoneCacheKey, constant.ParseVerifyType(req.Type), req.Account) - value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result() - if err != nil { - return resp, nil - } - var payload CacheKeyPayload - if err := json.Unmarshal([]byte(value), &payload); err != nil { - return resp, nil - } - if payload.Code != req.Code { - return resp, nil - } - resp.Status = true - } - return resp, nil + return l.CheckVerificationCodeWithBehavior(req, VerifyCodeCheckBehavior{ + Source: "canonical", + Consume: true, + }) } diff --git a/internal/logic/common/checkverificationcodelogic_test.go b/internal/logic/common/checkverificationcodelogic_test.go new file mode 100644 index 0000000..6f064a6 --- /dev/null +++ b/internal/logic/common/checkverificationcodelogic_test.go @@ -0,0 +1,259 @@ +package common + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/apiversion" + "github.com/perfect-panel/server/pkg/authmethod" + "github.com/perfect-panel/server/pkg/constant" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCheckVerificationCodeCanonicalConsume(t *testing.T) { + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + + email := "user@example.com" + code := "123456" + scene := constant.Register.String() + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, email) + setEmailCodePayload(t, redisClient, cacheKey, code, time.Now().Unix()) + + logic := NewCheckVerificationCodeLogic(context.Background(), svcCtx) + req := &types.CheckVerificationCodeRequest{ + Method: authmethod.Email, + Account: email, + Code: code, + Type: uint8(constant.Register), + } + + resp, err := logic.CheckVerificationCode(req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.True(t, resp.Status) + assert.True(t, resp.Exist) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(0), exists) + + resp, err = logic.CheckVerificationCode(req) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, resp.Status) + assert.False(t, resp.Exist) +} + +func TestCheckVerificationCodeLegacyNoConsumeAndType3Mapping(t *testing.T) { + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + + email := "legacy@example.com" + code := "654321" + scene := constant.Security.String() + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, email) + setEmailCodePayload(t, redisClient, cacheKey, code, time.Now().Unix()) + + legacyReq := &types.LegacyCheckVerificationCodeRequest{ + Email: email, + Code: code, + Type: 3, + } + + normalizedReq, type3Mapped, err := NormalizeLegacyCheckVerificationCodeRequest(legacyReq) + require.NoError(t, err) + assert.True(t, type3Mapped) + assert.Equal(t, uint8(constant.Security), normalizedReq.Type) + assert.Equal(t, authmethod.Email, normalizedReq.Method) + assert.Equal(t, email, normalizedReq.Account) + + logic := NewCheckVerificationCodeLogic(context.Background(), svcCtx) + legacyBehavior := VerifyCodeCheckBehavior{ + Source: "legacy", + Consume: false, + LegacyType3Mapped: true, + AllowSceneFallback: true, + } + + resp, err := logic.CheckVerificationCodeWithBehavior(normalizedReq, legacyBehavior) + require.NoError(t, err) + require.NotNil(t, resp) + assert.True(t, resp.Status) + assert.True(t, resp.Exist) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(1), exists) + + resp, err = logic.CheckVerificationCodeWithBehavior(normalizedReq, legacyBehavior) + require.NoError(t, err) + assert.True(t, resp.Status) + + resp, err = logic.CheckVerificationCode(normalizedReq) + require.NoError(t, err) + assert.True(t, resp.Status) + + exists, err = redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + assert.Equal(t, int64(0), exists) +} + +func TestCheckVerificationCodeLegacySceneFallback(t *testing.T) { + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + + email := "fallback@example.com" + code := "778899" + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Register.String(), email) + setEmailCodePayload(t, redisClient, cacheKey, code, time.Now().Unix()) + + logic := NewCheckVerificationCodeLogic(context.Background(), svcCtx) + req := &types.CheckVerificationCodeRequest{ + Method: authmethod.Email, + Account: email, + Code: code, + Type: uint8(constant.Security), + } + + resp, err := logic.CheckVerificationCodeWithBehavior(req, VerifyCodeCheckBehavior{ + Source: "legacy", + Consume: false, + AllowSceneFallback: true, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.True(t, resp.Status) + + resp, err = logic.CheckVerificationCodeWithBehavior(req, VerifyCodeCheckBehavior{ + Source: "legacy", + Consume: false, + AllowSceneFallback: false, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.False(t, resp.Status) +} + +func setEmailCodePayload(t *testing.T, redisClient *redis.Client, cacheKey string, code string, lastAt int64) { + t.Helper() + + payload := CacheKeyPayload{ + Code: code, + LastAt: lastAt, + } + value, err := json.Marshal(payload) + require.NoError(t, err) + err = redisClient.Set(context.Background(), cacheKey, value, time.Minute*15).Err() + require.NoError(t, err) +} + +func TestCheckVerificationCodeWithApiHeaderGate(t *testing.T) { + tests := []struct { + name string + header string + expectConsume bool + }{ + {name: "missing header", header: "", expectConsume: false}, + {name: "invalid header", header: "invalid", expectConsume: false}, + {name: "equal threshold", header: "1.0.0", expectConsume: false}, + {name: "greater threshold", header: "1.0.1", expectConsume: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + miniRedis := miniredis.RunT(t) + redisClient := redis.NewClient(&redis.Options{Addr: miniRedis.Addr()}) + t.Cleanup(func() { + redisClient.Close() + miniRedis.Close() + }) + + svcCtx := &svc.ServiceContext{ + Redis: redisClient, + Config: config.Config{ + VerifyCode: config.VerifyCode{ + VerifyCodeExpireTime: 900, + }, + }, + } + + email := "gate@example.com" + code := "101010" + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Register.String(), email) + setEmailCodePayload(t, redisClient, cacheKey, code, time.Now().Unix()) + + logic := NewCheckVerificationCodeLogic(context.Background(), svcCtx) + req := &types.CheckVerificationCodeRequest{ + Method: authmethod.Email, + Account: email, + Code: code, + Type: uint8(constant.Register), + } + + resp, err := logic.CheckVerificationCodeWithBehavior(req, VerifyCodeCheckBehavior{ + Source: "canonical", + Consume: apiversion.UseLatest(tt.header, apiversion.DefaultThreshold), + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.True(t, resp.Status) + + exists, err := redisClient.Exists(context.Background(), cacheKey).Result() + require.NoError(t, err) + if tt.expectConsume { + assert.Equal(t, int64(0), exists) + } else { + assert.Equal(t, int64(1), exists) + } + }) + } +} diff --git a/internal/logic/common/sendEmailCodeLogic.go b/internal/logic/common/sendEmailCodeLogic.go index 516d47f..0170b99 100644 --- a/internal/logic/common/sendEmailCodeLogic.go +++ b/internal/logic/common/sendEmailCodeLogic.go @@ -88,7 +88,9 @@ func (l *SendEmailCodeLogic) SendEmailCode(req *types.SendCodeRequest) (resp *ty var taskPayload queue.SendEmailPayload // Generate verification code code := random.Key(6, 0) + scene := constant.ParseVerifyType(req.Type).String() taskPayload.Type = queue.EmailTypeVerify + taskPayload.Scene = scene taskPayload.Email = req.Email taskPayload.Subject = "Verification code" diff --git a/internal/logic/common/verifyCodeChecker.go b/internal/logic/common/verifyCodeChecker.go new file mode 100644 index 0000000..4cfaee9 --- /dev/null +++ b/internal/logic/common/verifyCodeChecker.go @@ -0,0 +1,227 @@ +package common + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/authmethod" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/phone" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "github.com/redis/go-redis/v9" +) + +var consumeVerifyCodeScript = redis.NewScript(` +local current = redis.call("GET", KEYS[1]) +if not current then + return 0 +end +if current == ARGV[1] then + redis.call("DEL", KEYS[1]) + return 1 +end +return -1 +`) + +type VerifyCodeCheckBehavior struct { + Source string + Consume bool + LegacyType3Mapped bool + AllowSceneFallback bool +} + +func NormalizeLegacyCheckVerificationCodeRequest(req *types.LegacyCheckVerificationCodeRequest) (*types.CheckVerificationCodeRequest, bool, error) { + if req == nil { + return nil, false, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "empty request") + } + + method := strings.ToLower(strings.TrimSpace(req.Method)) + account := strings.TrimSpace(req.Account) + email := strings.ToLower(strings.TrimSpace(req.Email)) + if account == "" { + account = email + } + if method == "" { + method = authmethod.Email + } + + mappedType := req.Type + legacyType3Mapped := false + if mappedType == 3 { + mappedType = uint8(constant.Security) + legacyType3Mapped = true + } + + normalizedReq := &types.CheckVerificationCodeRequest{ + Method: method, + Account: account, + Code: strings.TrimSpace(req.Code), + Type: mappedType, + } + + normalizedReq, err := normalizeCheckVerificationCodeRequest(normalizedReq) + if err != nil { + return nil, false, err + } + return normalizedReq, legacyType3Mapped, nil +} + +func (l *CheckVerificationCodeLogic) CheckVerificationCodeWithBehavior(req *types.CheckVerificationCodeRequest, behavior VerifyCodeCheckBehavior) (*types.CheckVerificationCodeRespone, error) { + resp := &types.CheckVerificationCodeRespone{} + + normalizedReq, err := normalizeCheckVerificationCodeRequest(req) + if err != nil { + return nil, err + } + + source := strings.TrimSpace(behavior.Source) + if source == "" { + source = "canonical" + } + + verifyType := constant.ParseVerifyType(normalizedReq.Type) + scenes := resolveVerifyScenes(verifyType, behavior.AllowSceneFallback) + if len(scenes) == 0 { + l.Infow("[CheckVerificationCode] unsupported verify type", + logger.Field("verify_check_source", source), + logger.Field("verify_consume", behavior.Consume), + logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped), + logger.Field("legacy_scene_fallback_hit", false), + logger.Field("type", normalizedReq.Type), + logger.Field("method", normalizedReq.Method), + ) + return resp, nil + } + + expireTime := l.svcCtx.Config.VerifyCode.VerifyCodeExpireTime + if expireTime <= 0 { + expireTime = 900 + } + + for idx, scene := range scenes { + cacheKey := buildVerifyCodeCacheKey(normalizedReq.Method, scene, normalizedReq.Account) + value, redisErr := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result() + if redisErr != nil || value == "" { + continue + } + + var payload CacheKeyPayload + if err = json.Unmarshal([]byte(value), &payload); err != nil { + continue + } + if payload.Code != normalizedReq.Code { + continue + } + if time.Now().Unix()-payload.LastAt > expireTime { + continue + } + + if behavior.Consume { + consumed, consumeErr := consumeVerificationCodeAtomically(l.ctx, l.svcCtx.Redis, cacheKey, value) + if consumeErr != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "consume verification code failed") + } + if !consumed { + continue + } + } + + fallbackHit := idx > 0 + resp.Status = true + resp.Exist = true + + l.Infow("[CheckVerificationCode] verify success", + logger.Field("verify_check_source", source), + logger.Field("verify_consume", behavior.Consume), + logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped), + logger.Field("legacy_scene_fallback_hit", fallbackHit), + logger.Field("scene", scene), + logger.Field("method", normalizedReq.Method), + ) + return resp, nil + } + + l.Infow("[CheckVerificationCode] verify failed", + logger.Field("verify_check_source", source), + logger.Field("verify_consume", behavior.Consume), + logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped), + logger.Field("legacy_scene_fallback_hit", false), + logger.Field("type", normalizedReq.Type), + logger.Field("method", normalizedReq.Method), + ) + + return resp, nil +} + +func normalizeCheckVerificationCodeRequest(req *types.CheckVerificationCodeRequest) (*types.CheckVerificationCodeRequest, error) { + if req == nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "empty request") + } + + method := strings.ToLower(strings.TrimSpace(req.Method)) + account := strings.TrimSpace(req.Account) + code := strings.TrimSpace(req.Code) + + switch method { + case authmethod.Email: + account = strings.ToLower(account) + case authmethod.Mobile: + if !phone.CheckPhone(account) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.TelephoneError), "Invalid phone number") + } + default: + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "invalid method") + } + + if account == "" { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "account is required") + } + + return &types.CheckVerificationCodeRequest{ + Method: method, + Account: account, + Code: code, + Type: req.Type, + }, nil +} + +func resolveVerifyScenes(verifyType constant.VerifyType, allowFallback bool) []string { + switch verifyType { + case constant.Register: + if allowFallback { + return []string{constant.Register.String(), constant.Security.String()} + } + return []string{constant.Register.String()} + case constant.Security: + if allowFallback { + return []string{constant.Security.String(), constant.Register.String()} + } + return []string{constant.Security.String()} + case constant.DeleteAccount: + return []string{constant.DeleteAccount.String()} + default: + return nil + } +} + +func buildVerifyCodeCacheKey(method string, scene string, account string) string { + if method == authmethod.Mobile { + return fmt.Sprintf("%s:%s:+%s", config.AuthCodeTelephoneCacheKey, scene, account) + } + return fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, account) +} + +func consumeVerificationCodeAtomically(ctx context.Context, redisClient *redis.Client, cacheKey string, expectedValue string) (bool, error) { + result, err := consumeVerifyCodeScript.Run(ctx, redisClient, []string{cacheKey}, expectedValue).Int() + if err != nil { + return false, err + } + return result == 1, nil +} diff --git a/internal/middleware/apiVersionMiddleware.go b/internal/middleware/apiVersionMiddleware.go new file mode 100644 index 0000000..845a54d --- /dev/null +++ b/internal/middleware/apiVersionMiddleware.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "context" + "strings" + + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/apiversion" + "github.com/perfect-panel/server/pkg/constant" +) + +func ApiVersionMiddleware(_ *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + rawVersion := strings.TrimSpace(c.GetHeader("api-header")) + useLatest := apiversion.UseLatest(rawVersion, apiversion.DefaultThreshold) + + ctx := context.WithValue(c.Request.Context(), constant.CtxKeyAPIVersionUseLatest, useLatest) + ctx = context.WithValue(ctx, constant.CtxKeyAPIHeaderRaw, rawVersion) + c.Request = c.Request.WithContext(ctx) + + c.Set("api_header", rawVersion) + c.Next() + } +} diff --git a/internal/middleware/corsMiddleware.go b/internal/middleware/corsMiddleware.go index e6c94df..51c92b0 100644 --- a/internal/middleware/corsMiddleware.go +++ b/internal/middleware/corsMiddleware.go @@ -15,7 +15,7 @@ func CorsMiddleware(c *gin.Context) { } // c.Writer.Header().Set("Access-Control-Allow-Origin", c.Request.Host) c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range, api-header") c.Writer.Header().Set("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Max-Age", "172800") diff --git a/internal/middleware/loggerMiddleware.go b/internal/middleware/loggerMiddleware.go index a5b5639..c598daf 100644 --- a/internal/middleware/loggerMiddleware.go +++ b/internal/middleware/loggerMiddleware.go @@ -70,6 +70,9 @@ func LoggerMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { Value: c.Request.UserAgent(), }, } + if apiHeader, ok := c.Get("api_header"); ok { + logs = append(logs, logger.Field("api_header", apiHeader)) + } if c.Errors.Last() != nil { var e *xerr.CodeError var errMessage string diff --git a/internal/model/auth/auth.go b/internal/model/auth/auth.go index ffeab4c..be40d92 100644 --- a/internal/model/auth/auth.go +++ b/internal/model/auth/auth.go @@ -113,16 +113,17 @@ func (l *TelegramAuthConfig) Unmarshal(data string) error { } type EmailAuthConfig struct { - Platform string `json:"platform"` - PlatformConfig interface{} `json:"platform_config"` - EnableVerify bool `json:"enable_verify"` - EnableNotify bool `json:"enable_notify"` - EnableDomainSuffix bool `json:"enable_domain_suffix"` - DomainSuffixList string `json:"domain_suffix_list"` - VerifyEmailTemplate string `json:"verify_email_template"` - ExpirationEmailTemplate string `json:"expiration_email_template"` - MaintenanceEmailTemplate string `json:"maintenance_email_template"` - TrafficExceedEmailTemplate string `json:"traffic_exceed_email_template"` + Platform string `json:"platform"` + PlatformConfig interface{} `json:"platform_config"` + EnableVerify bool `json:"enable_verify"` + EnableNotify bool `json:"enable_notify"` + EnableDomainSuffix bool `json:"enable_domain_suffix"` + DomainSuffixList string `json:"domain_suffix_list"` + VerifyEmailTemplate string `json:"verify_email_template"` + VerifyEmailTemplates map[string]string `json:"verify_email_templates"` + ExpirationEmailTemplate string `json:"expiration_email_template"` + MaintenanceEmailTemplate string `json:"maintenance_email_template"` + TrafficExceedEmailTemplate string `json:"traffic_exceed_email_template"` } func (l *EmailAuthConfig) Marshal() string { @@ -138,6 +139,9 @@ func (l *EmailAuthConfig) Marshal() string { if l.VerifyEmailTemplate == "" { l.VerifyEmailTemplate = email.DefaultEmailVerifyTemplate } + if l.VerifyEmailTemplates == nil { + l.VerifyEmailTemplates = map[string]string{} + } bytes, err := json.Marshal(l) if err != nil { config := &EmailAuthConfig{ @@ -148,6 +152,7 @@ func (l *EmailAuthConfig) Marshal() string { EnableDomainSuffix: false, DomainSuffixList: "", VerifyEmailTemplate: email.DefaultEmailVerifyTemplate, + VerifyEmailTemplates: map[string]string{}, ExpirationEmailTemplate: email.DefaultExpirationEmailTemplate, MaintenanceEmailTemplate: email.DefaultMaintenanceEmailTemplate, TrafficExceedEmailTemplate: email.DefaultTrafficExceedEmailTemplate, @@ -169,11 +174,20 @@ func (l *EmailAuthConfig) Unmarshal(data string) { EnableDomainSuffix: false, DomainSuffixList: "", VerifyEmailTemplate: email.DefaultEmailVerifyTemplate, + VerifyEmailTemplates: map[string]string{}, ExpirationEmailTemplate: email.DefaultExpirationEmailTemplate, MaintenanceEmailTemplate: email.DefaultMaintenanceEmailTemplate, TrafficExceedEmailTemplate: email.DefaultTrafficExceedEmailTemplate, } _ = json.Unmarshal([]byte(config.Marshal()), &l) + return + } + + if l.VerifyEmailTemplate == "" { + l.VerifyEmailTemplate = email.DefaultEmailVerifyTemplate + } + if l.VerifyEmailTemplates == nil { + l.VerifyEmailTemplates = map[string]string{} } } diff --git a/internal/server.go b/internal/server.go index 78d6422..98db20a 100644 --- a/internal/server.go +++ b/internal/server.go @@ -50,7 +50,7 @@ func initServer(svc *svc.ServiceContext) *gin.Engine { } r.Use(sessions.Sessions("ppanel", sessionStore)) // use cors middleware - r.Use(middleware.TraceMiddleware(svc), middleware.LoggerMiddleware(svc), middleware.CorsMiddleware, gin.Recovery()) + r.Use(middleware.TraceMiddleware(svc), middleware.ApiVersionMiddleware(svc), middleware.LoggerMiddleware(svc), middleware.CorsMiddleware, gin.Recovery()) // register handlers handler.RegisterHandlers(r, svc) diff --git a/internal/types/types.go b/internal/types/types.go index af15577..0bfff14 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -218,6 +218,20 @@ type CheckVerificationCodeRequest struct { type CheckVerificationCodeRespone struct { Status bool `json:"status"` + Exist bool `json:"exist,omitempty"` +} + +type LegacyCheckVerificationCodeRequest struct { + Method string `json:"method" form:"method"` + Account string `json:"account" form:"account"` + Email string `json:"email" form:"email"` + Code string `json:"code" form:"code" validate:"required"` + Type uint8 `json:"type" form:"type" validate:"required"` +} + +type LegacyCheckVerificationCodeResponse struct { + Status bool `json:"status"` + Exist bool `json:"exist"` } type CheckoutOrderRequest struct { diff --git a/pkg/apiversion/version.go b/pkg/apiversion/version.go new file mode 100644 index 0000000..1e0cbd6 --- /dev/null +++ b/pkg/apiversion/version.go @@ -0,0 +1,83 @@ +package apiversion + +import ( + "regexp" + "strconv" + "strings" +) + +const DefaultThreshold = "1.0.0" + +type Version struct { + Major int + Minor int + Patch int +} + +var versionPattern = regexp.MustCompile(`^v?(\d+)\.(\d+)\.(\d+)$`) + +func Parse(header string) (Version, bool) { + normalized := strings.TrimSpace(header) + if normalized == "" { + return Version{}, false + } + + matches := versionPattern.FindStringSubmatch(normalized) + if len(matches) != 4 { + return Version{}, false + } + + major, err := strconv.Atoi(matches[1]) + if err != nil { + return Version{}, false + } + minor, err := strconv.Atoi(matches[2]) + if err != nil { + return Version{}, false + } + patch, err := strconv.Atoi(matches[3]) + if err != nil { + return Version{}, false + } + + return Version{Major: major, Minor: minor, Patch: patch}, true +} + +func UseLatest(header string, threshold string) bool { + currentVersion, ok := Parse(header) + if !ok { + return false + } + + thresholdVersion, ok := Parse(strings.TrimSpace(threshold)) + if !ok { + thresholdVersion, _ = Parse(DefaultThreshold) + } + + return compare(currentVersion, thresholdVersion) > 0 +} + +func compare(left Version, right Version) int { + if left.Major != right.Major { + if left.Major > right.Major { + return 1 + } + return -1 + } + + if left.Minor != right.Minor { + if left.Minor > right.Minor { + return 1 + } + return -1 + } + + if left.Patch != right.Patch { + if left.Patch > right.Patch { + return 1 + } + return -1 + } + + return 0 +} diff --git a/pkg/apiversion/version_test.go b/pkg/apiversion/version_test.go new file mode 100644 index 0000000..41e5585 --- /dev/null +++ b/pkg/apiversion/version_test.go @@ -0,0 +1,55 @@ +package apiversion + +import "testing" + +func TestParse(t *testing.T) { + tests := []struct { + name string + raw string + valid bool + version Version + }{ + {name: "empty", raw: "", valid: false}, + {name: "invalid text", raw: "abc", valid: false}, + {name: "missing patch", raw: "1.0", valid: false}, + {name: "exact", raw: "1.0.0", valid: true, version: Version{Major: 1, Minor: 0, Patch: 0}}, + {name: "with prefix", raw: "v1.2.3", valid: true, version: Version{Major: 1, Minor: 2, Patch: 3}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + version, ok := Parse(tt.raw) + if ok != tt.valid { + t.Fatalf("expected valid=%v, got %v", tt.valid, ok) + } + if tt.valid && version != tt.version { + t.Fatalf("expected version=%+v, got %+v", tt.version, version) + } + }) + } +} + +func TestUseLatest(t *testing.T) { + tests := []struct { + name string + header string + threshold string + expect bool + }{ + {name: "missing header", header: "", threshold: "1.0.0", expect: false}, + {name: "invalid header", header: "invalid", threshold: "1.0.0", expect: false}, + {name: "equal threshold", header: "1.0.0", threshold: "1.0.0", expect: false}, + {name: "greater threshold", header: "1.0.1", threshold: "1.0.0", expect: true}, + {name: "greater with v prefix", header: "v1.2.3", threshold: "1.0.0", expect: true}, + {name: "less than threshold", header: "0.9.9", threshold: "1.0.0", expect: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UseLatest(tt.header, tt.threshold) + if result != tt.expect { + t.Fatalf("expected %v, got %v", tt.expect, result) + } + }) + } +} diff --git a/pkg/constant/context.go b/pkg/constant/context.go index 1c368d4..ea20e2a 100644 --- a/pkg/constant/context.go +++ b/pkg/constant/context.go @@ -3,14 +3,16 @@ package constant type CtxKey string const ( - CtxKeyUser CtxKey = "user" - CtxKeySessionID CtxKey = "sessionId" - CtxKeyRequestHost CtxKey = "requestHost" - CtxKeyPlatform CtxKey = "platform" - CtxKeyPayment CtxKey = "payment" - CtxLoginType CtxKey = "loginType" - LoginType CtxKey = "loginType" - CtxKeyIdentifier CtxKey = "identifier" - CtxKeyDeviceID CtxKey = "deviceId" - CtxKeyIncludeExpired CtxKey = "includeExpired" + CtxKeyUser CtxKey = "user" + CtxKeySessionID CtxKey = "sessionId" + CtxKeyRequestHost CtxKey = "requestHost" + CtxKeyPlatform CtxKey = "platform" + CtxKeyPayment CtxKey = "payment" + CtxLoginType CtxKey = "loginType" + LoginType CtxKey = "loginType" + CtxKeyIdentifier CtxKey = "identifier" + CtxKeyDeviceID CtxKey = "deviceId" + CtxKeyIncludeExpired CtxKey = "includeExpired" + CtxKeyAPIVersionUseLatest CtxKey = "apiVersionUseLatest" + CtxKeyAPIHeaderRaw CtxKey = "apiHeaderRaw" ) diff --git a/queue/logic/email/sendEmailLogic.go b/queue/logic/email/sendEmailLogic.go index 37c6e16..6779942 100644 --- a/queue/logic/email/sendEmailLogic.go +++ b/queue/logic/email/sendEmailLogic.go @@ -13,6 +13,7 @@ import ( "github.com/hibiken/asynq" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/email" "github.com/perfect-panel/server/queue/types" ) @@ -49,9 +50,6 @@ func (l *SendEmailLogic) ProcessTask(ctx context.Context, task *asynq.Task) erro var content string switch payload.Type { case types.EmailTypeVerify: - tplStr := l.svcCtx.Config.Email.VerifyEmailTemplate - - // Use int for better template compatibility if t, ok := payload.Content["Type"].(float64); ok { payload.Content["Type"] = int(t) } else if t, ok := payload.Content["Type"].(int); ok { @@ -59,18 +57,14 @@ func (l *SendEmailLogic) ProcessTask(ctx context.Context, task *asynq.Task) erro } typeVal, _ := payload.Content["Type"].(int) + scene := resolveVerifyScene(payload.Scene, typeVal) + tplStr := selectVerifyTemplate(l.svcCtx.Config.Email.VerifyEmailTemplates, l.svcCtx.Config.Email.VerifyEmailTemplate, scene) - // Smart Fallback: If template is empty OR (Type is 4 but template doesn't support it), use default - // We check for "Type 4" or "Type eq 4" string in the template as a heuristic - needDefault := tplStr == "" - if !needDefault && typeVal == 4 && + if tplStr == l.svcCtx.Config.Email.VerifyEmailTemplate && + scene == constant.DeleteAccount.String() && !strings.Contains(tplStr, "Type 4") && !strings.Contains(tplStr, "Type eq 4") { - logger.WithContext(ctx).Infow("[SendEmailLogic] Configured template might not support DeleteAccount (Type 4), forcing default template") - needDefault = true - } - - if needDefault { + logger.WithContext(ctx).Infow("[SendEmailLogic] configured legacy verify template may not support DeleteAccount, fallback to default template") tplStr = email.DefaultEmailVerifyTemplate } @@ -189,3 +183,26 @@ func (l *SendEmailLogic) ProcessTask(ctx context.Context, task *asynq.Task) erro } return nil } + +func resolveVerifyScene(scene string, typeVal int) string { + scene = strings.ToLower(strings.TrimSpace(scene)) + if scene != "" { + return scene + } + return constant.ParseVerifyType(uint8(typeVal)).String() +} + +func selectVerifyTemplate(sceneTemplates map[string]string, legacyTemplate string, scene string) string { + if sceneTemplates != nil { + if tpl := strings.TrimSpace(sceneTemplates[scene]); tpl != "" { + return tpl + } + if tpl := strings.TrimSpace(sceneTemplates["default"]); tpl != "" { + return tpl + } + } + if strings.TrimSpace(legacyTemplate) != "" { + return legacyTemplate + } + return email.DefaultEmailVerifyTemplate +} diff --git a/queue/logic/email/sendEmailLogic_test.go b/queue/logic/email/sendEmailLogic_test.go new file mode 100644 index 0000000..20c8323 --- /dev/null +++ b/queue/logic/email/sendEmailLogic_test.go @@ -0,0 +1,28 @@ +package emailLogic + +import ( + "testing" + + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/email" + "github.com/stretchr/testify/assert" +) + +func TestSelectVerifyTemplate(t *testing.T) { + sceneTemplates := map[string]string{ + "default": "DEFAULT_TEMPLATE", + "register": "REGISTER_TEMPLATE", + "delete_account": "DELETE_TEMPLATE", + } + + assert.Equal(t, "REGISTER_TEMPLATE", selectVerifyTemplate(sceneTemplates, "LEGACY_TEMPLATE", "register")) + assert.Equal(t, "DEFAULT_TEMPLATE", selectVerifyTemplate(sceneTemplates, "LEGACY_TEMPLATE", "security")) + assert.Equal(t, "LEGACY_TEMPLATE", selectVerifyTemplate(nil, "LEGACY_TEMPLATE", "security")) + assert.Equal(t, email.DefaultEmailVerifyTemplate, selectVerifyTemplate(nil, "", "security")) +} + +func TestResolveVerifyScene(t *testing.T) { + assert.Equal(t, "register", resolveVerifyScene("register", 0)) + assert.Equal(t, constant.DeleteAccount.String(), resolveVerifyScene("", int(constant.DeleteAccount))) + assert.Equal(t, "unknown", resolveVerifyScene("", 99)) +} diff --git a/queue/types/email.go b/queue/types/email.go index 4fee979..abeebca 100644 --- a/queue/types/email.go +++ b/queue/types/email.go @@ -16,6 +16,7 @@ const ( type ( SendEmailPayload struct { Type string `json:"type"` + Scene string `json:"scene,omitempty"` Email string `json:"to"` Subject string `json:"subject"` Content map[string]interface{} `json:"content"`