package common

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/perfect-panel/server/internal/config"
	"github.com/perfect-panel/server/internal/types"
	"github.com/perfect-panel/server/pkg/authmethod"
	"github.com/perfect-panel/server/pkg/constant"
	"github.com/perfect-panel/server/pkg/logger"
	"github.com/perfect-panel/server/pkg/phone"
	"github.com/perfect-panel/server/pkg/xerr"
	"github.com/pkg/errors"
	"github.com/redis/go-redis/v9"
)

// consumeVerifyCodeScript is a compare-and-delete Lua script. It returns 0 when
// KEYS[1] does not exist, 1 when the stored value equals ARGV[1] (the key is
// deleted in that case), and -1 when the stored value differs from ARGV[1].
var consumeVerifyCodeScript = redis.NewScript(` local current = redis.call("GET", KEYS[1]) if not current then return 0 end if current == ARGV[1] then redis.call("DEL", KEYS[1]) return 1 end return -1 `)

// VerifyCodeCheckBehavior tunes how CheckVerificationCodeWithBehavior performs
// and reports a verification code check.
type VerifyCodeCheckBehavior struct {
	// Source is a label written to the logs as "verify_check_source"; a blank
	// value is reported as "canonical".
	Source string
	// Consume makes a successful check delete the cached code atomically.
	Consume bool
	// LegacyType3Mapped is only echoed into the logs ("legacy_type3_mapped").
	LegacyType3Mapped bool
	// AllowSceneFallback lets the check try a secondary scene when the primary
	// scene does not hold a matching code (see resolveVerifyScenes).
	AllowSceneFallback bool
}

// NormalizeLegacyCheckVerificationCodeRequest converts a legacy request into the
// canonical request form.
//
// The account falls back to the lower-cased Email field when Account is blank,
// the method defaults to authmethod.Email, and legacy type 3 is mapped to
// constant.Security. The boolean result reports whether that type mapping was
// applied. The converted request is then validated by
// normalizeCheckVerificationCodeRequest, whose error is returned unchanged.
func NormalizeLegacyCheckVerificationCodeRequest(req *types.LegacyCheckVerificationCodeRequest) (*types.CheckVerificationCodeRequest, bool, error) {
	if req == nil {
		return nil, false, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "empty request")
	}

	method := strings.ToLower(strings.TrimSpace(req.Method))
	account := strings.TrimSpace(req.Account)
	email := strings.ToLower(strings.TrimSpace(req.Email))
	if account == "" {
		account = email
	}
	if method == "" {
		method = authmethod.Email
	}

	mappedType := req.Type
	legacyType3Mapped := false
	if mappedType == 3 {
		mappedType = uint8(constant.Security)
		legacyType3Mapped = true
	}

	normalizedReq, err := normalizeCheckVerificationCodeRequest(&types.CheckVerificationCodeRequest{
		Method:  method,
		Account: account,
		Code:    strings.TrimSpace(req.Code),
		Type:    mappedType,
	})
	if err != nil {
		return nil, false, err
	}
	return normalizedReq, legacyType3Mapped, nil
}

// CheckVerificationCodeWithBehavior checks the submitted code against the cached
// payload of every scene resolved for the request type, in order.
//
// A scene matches when its cached payload decodes, carries a non-empty code equal
// to the submitted one, and is not older than the configured expiry (900 seconds
// when the configured value is not positive). With behavior.Consume set, the
// cached entry must additionally be deleted atomically; losing that race moves on
// to the next scene.
//
// On the first match the response has Status and Exist set to true. When nothing
// matches, or the verify type is unsupported, the zero response is returned with a
// nil error. An error is returned only for an invalid request or when the atomic
// consume itself fails.
func (l *CheckVerificationCodeLogic) CheckVerificationCodeWithBehavior(req *types.CheckVerificationCodeRequest, behavior VerifyCodeCheckBehavior) (*types.CheckVerificationCodeRespone, error) {
	resp := &types.CheckVerificationCodeRespone{}

	normalizedReq, err := normalizeCheckVerificationCodeRequest(req)
	if err != nil {
		return nil, err
	}

	source := strings.TrimSpace(behavior.Source)
	if source == "" {
		source = "canonical"
	}

	verifyType := constant.ParseVerifyType(normalizedReq.Type)
	scenes := resolveVerifyScenes(verifyType, behavior.AllowSceneFallback)
	if len(scenes) == 0 {
		l.Infow("[CheckVerificationCode] unsupported verify type",
			logger.Field("verify_check_source", source),
			logger.Field("verify_consume", behavior.Consume),
			logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped),
			logger.Field("legacy_scene_fallback_hit", false),
			logger.Field("type", normalizedReq.Type),
			logger.Field("method", normalizedReq.Method),
		)
		return resp, nil
	}

	expireTime := l.svcCtx.Config.VerifyCode.VerifyCodeExpireTime
	if expireTime <= 0 {
		expireTime = 900
	}

	for idx, scene := range scenes {
		cacheKey := buildVerifyCodeCacheKey(normalizedReq.Method, scene, normalizedReq.Account)

		value, redisErr := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result()
		if redisErr != nil {
			// redis.Nil only means "no code cached for this scene". Anything else
			// is an infrastructure problem that must not be mistaken for a wrong
			// code, so surface it in the logs while staying best-effort.
			if !errors.Is(redisErr, redis.Nil) {
				l.Errorw("[CheckVerificationCode] read verification code failed",
					logger.Field("error", redisErr.Error()),
					logger.Field("scene", scene),
					logger.Field("method", normalizedReq.Method),
				)
			}
			continue
		}
		if value == "" {
			continue
		}

		var payload CacheKeyPayload
		if decodeErr := json.Unmarshal([]byte(value), &payload); decodeErr != nil {
			l.Errorw("[CheckVerificationCode] decode cached verification code failed",
				logger.Field("error", decodeErr.Error()),
				logger.Field("scene", scene),
				logger.Field("method", normalizedReq.Method),
			)
			continue
		}

		// An empty cached code is never a valid match; without this guard a
		// request carrying an empty code would pass against such a payload.
		if payload.Code == "" || payload.Code != normalizedReq.Code {
			continue
		}
		if time.Now().Unix()-payload.LastAt > expireTime {
			continue
		}

		if behavior.Consume {
			consumed, consumeErr := consumeVerificationCodeAtomically(l.ctx, l.svcCtx.Redis, cacheKey, value)
			if consumeErr != nil {
				return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "consume verification code failed: %v", consumeErr)
			}
			if !consumed {
				// The entry changed or disappeared between the read and the
				// delete, so this scene no longer holds the code.
				continue
			}
		}

		// Any scene after the first one is a fallback scene.
		fallbackHit := idx > 0
		resp.Status = true
		resp.Exist = true
		l.Infow("[CheckVerificationCode] verify success",
			logger.Field("verify_check_source", source),
			logger.Field("verify_consume", behavior.Consume),
			logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped),
			logger.Field("legacy_scene_fallback_hit", fallbackHit),
			logger.Field("scene", scene),
			logger.Field("method", normalizedReq.Method),
		)
		return resp, nil
	}

	l.Infow("[CheckVerificationCode] verify failed",
		logger.Field("verify_check_source", source),
		logger.Field("verify_consume", behavior.Consume),
		logger.Field("legacy_type3_mapped", behavior.LegacyType3Mapped),
		logger.Field("legacy_scene_fallback_hit", false),
		logger.Field("type", normalizedReq.Type),
		logger.Field("method", normalizedReq.Method),
	)
	return resp, nil
}

// normalizeCheckVerificationCodeRequest returns a trimmed, validated copy of req.
//
// The method is lower-cased and must be authmethod.Email or authmethod.Mobile.
// Email accounts are lower-cased; mobile accounts must pass phone.CheckPhone.
// It fails with xerr.InvalidParams for a nil request, an unknown method or an
// empty account, and with xerr.TelephoneError for a rejected phone number.
func normalizeCheckVerificationCodeRequest(req *types.CheckVerificationCodeRequest) (*types.CheckVerificationCodeRequest, error) {
	if req == nil {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "empty request")
	}

	method := strings.ToLower(strings.TrimSpace(req.Method))
	account := strings.TrimSpace(req.Account)
	code := strings.TrimSpace(req.Code)

	switch method {
	case authmethod.Email:
		account = strings.ToLower(account)
	case authmethod.Mobile:
		if !phone.CheckPhone(account) {
			return nil, errors.Wrapf(xerr.NewErrCode(xerr.TelephoneError), "Invalid phone number")
		}
	default:
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "invalid method")
	}

	if account == "" {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidParams), "account is required")
	}

	return &types.CheckVerificationCodeRequest{
		Method:  method,
		Account: account,
		Code:    code,
		Type:    req.Type,
	}, nil
}

// resolveVerifyScenes lists the cache scenes to probe for verifyType, primary
// scene first. With allowFallback a second scene is appended: Security for
// Register and DeleteAccount, Register for Security. Unsupported types yield nil.
func resolveVerifyScenes(verifyType constant.VerifyType, allowFallback bool) []string {
	switch verifyType {
	case constant.Register:
		if allowFallback {
			return []string{constant.Register.String(), constant.Security.String()}
		}
		return []string{constant.Register.String()}
	case constant.Security:
		if allowFallback {
			return []string{constant.Security.String(), constant.Register.String()}
		}
		return []string{constant.Security.String()}
	case constant.DeleteAccount:
		if allowFallback {
			return []string{constant.DeleteAccount.String(), constant.Security.String()}
		}
		return []string{constant.DeleteAccount.String()}
	default:
		return nil
	}
}

// buildVerifyCodeCacheKey builds the Redis key under which a code is cached for
// the given method, scene and account. Mobile keys use the telephone key prefix
// and put a "+" in front of the account; every other method uses the generic
// auth code prefix.
func buildVerifyCodeCacheKey(method string, scene string, account string) string {
	if method == authmethod.Mobile {
		return fmt.Sprintf("%s:%s:+%s", config.AuthCodeTelephoneCacheKey, scene, account)
	}
	return fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, scene, account)
}

// consumeVerificationCodeAtomically deletes cacheKey only if its current value
// still equals expectedValue, using consumeVerifyCodeScript. It reports true when
// the key was deleted and false when the key was missing or held another value.
func consumeVerificationCodeAtomically(ctx context.Context, redisClient *redis.Client, cacheKey string, expectedValue string) (bool, error) {
	result, err := consumeVerifyCodeScript.Run(ctx, redisClient, []string{cacheKey}, expectedValue).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}