hi-server/internal/logic/auth/telephoneUserRegisterLogic.go
Rust 5f1a546bbe fix: resolve trial subscription cache issue on new user registration
When a new user registers while the trial subscription is enabled, the
subscription link fails to connect in Clash clients. This is caused by
missing cache invalidation after the registration transaction commits.

Changes:
- Add cache clearing after successful trial subscription creation
- Clear user subscription cache, subscription details cache, and server cache
- Modify activeTrial functions to return subscription object for cache clearing
- Apply fix to all registration methods: email, phone, device, and OAuth

This ensures subscription links work immediately after registration without
requiring a manual subscription reset.
2026-01-22 23:57:15 +07:00

283 lines
9.5 KiB
Go

package auth
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/internal/config"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/jwt"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/phone"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/uuidx"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
"gorm.io/gorm"
)
// CacheKeyPayload is the JSON document stored in Redis under a verification-code
// cache key. TelephoneUserRegister decodes it and compares Code with the code
// supplied by the client.
type CacheKeyPayload struct {
	// Code is the verification code that was issued.
	Code string `json:"code"`
	// LastAt is serialized as "lastAt"; it is not read in this file.
	// NOTE(review): presumably the issue time used for resend throttling — confirm.
	LastAt int64 `json:"lastAt"`
}
// TelephoneUserRegisterLogic holds the per-request dependencies of the
// telephone (SMS code) registration flow.
type TelephoneUserRegisterLogic struct {
	// Embedded logger (set from logger.WithContext(ctx) by
	// NewTelephoneUserRegisterLogic); provides l.Errorw, l.Errorf and l.Logger.Error.
	logger.Logger
	// ctx is the request context passed to every Redis/DB call.
	ctx context.Context
	// svcCtx gives access to config, Redis and the data models.
	svcCtx *svc.ServiceContext
}
// NewTelephoneUserRegisterLogic returns a TelephoneUserRegisterLogic bound to ctx.
// Its embedded logger is derived from ctx, and svcCtx supplies config, Redis and models.
func NewTelephoneUserRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TelephoneUserRegisterLogic {
	l := new(TelephoneUserRegisterLogic)
	l.Logger = logger.WithContext(ctx)
	l.ctx = ctx
	l.svcCtx = svcCtx
	return l
}
// TelephoneUserRegister registers a new user identified by a mobile phone number
// and returns a LoginResponse carrying a freshly issued JWT.
//
// Steps:
//  1. Reject when registration is stopped, the number is invalid, or SMS is disabled.
//  2. Verify the SMS code stored in Redis, then delete it.
//  3. Reject numbers that already have a "mobile" auth method.
//  4. Resolve the optional invite code (required when Invite.ForcedInvite is set).
//  5. Enforce the per-IP registration limit.
//  6. In one DB transaction, create the user and set its refer code. When
//     Register.EnableTrial is set, also activate the trial subscription.
//  7. After commit, invalidate caches touched by the trial subscription and
//     optionally bind the device. Then issue the JWT and store its session id
//     in Redis.
//  8. Write login and register system logs (best-effort).
//
// Errors are wrapped xerr codes (StopRegister, TelephoneError, SmsNotEnabled,
// VerifyCodeError, DatabaseQueryError, UserExist, InviteCodeError,
// RegisterIPLimit, ERROR). A failed transaction returns its error unwrapped.
func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneRegisterRequest) (resp *types.LoginResponse, err error) {
	c := l.svcCtx.Config.Register
	var trialSubscribe *user.Subscribe

	// Check if the registration is stopped
	if c.StopRegister {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register")
	}
	if !phone.Check(req.TelephoneAreaCode, req.Telephone) {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.TelephoneError), "telephone number error")
	}
	if !l.svcCtx.Config.Mobile.Enable {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.SmsNotEnabled), "sms login is not enabled")
	}
	phoneNumber, err := phone.FormatToE164(req.TelephoneAreaCode, req.Telephone)
	if err != nil {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.TelephoneError), "Invalid phone number")
	}

	// The SMS verification code is always required. It is stored as a JSON
	// CacheKeyPayload under "<AuthCodeTelephoneCacheKey>:<register verify type>:<E.164 number>".
	cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeTelephoneCacheKey, constant.ParseVerifyType(uint8(constant.Register)), phoneNumber)
	value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result()
	if err != nil {
		l.Errorw("Redis Error", logger.Field("error", err.Error()), logger.Field("cacheKey", cacheKey))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error")
	}
	var payload CacheKeyPayload
	if err = json.Unmarshal([]byte(value), &payload); err != nil {
		l.Errorw("Unmarshal Error", logger.Field("error", err.Error()), logger.Field("value", value))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error")
	}
	if payload.Code != req.Code {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error")
	}
	// Delete the code so it cannot be reused. This is best-effort: the Del
	// result is not checked.
	l.svcCtx.Redis.Del(l.ctx, cacheKey)

	// Check if the user exists
	_, err = l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "mobile", phoneNumber)
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		l.Errorw("FindOneByTelephone Error", logger.Field("error", err))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user info failed: %v", err.Error())
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		// err == nil here: an auth method for this number was found.
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserExist), "telephone already exists")
	}

	var referer *user.User
	if req.Invite == "" {
		if l.svcCtx.Config.Invite.ForcedInvite {
			return nil, errors.Wrapf(xerr.NewErrCode(xerr.InviteCodeError), "invite code is required")
		}
	} else {
		// Check if the invite code is valid
		referer, err = l.svcCtx.UserModel.FindOneByReferCode(l.ctx, req.Invite)
		if err != nil {
			l.Errorw("FindOneByReferCode Error", logger.Field("error", err))
			return nil, errors.Wrapf(xerr.NewErrCode(xerr.InviteCodeError), "invite code is invalid")
		}
	}

	if !registerIpLimit(l.svcCtx, l.ctx, req.IP, "mobile", phoneNumber) {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.RegisterIPLimit), "register ip limit: %v", req.IP)
	}

	// Generate password
	pwd := tool.EncodePassWord(req.Password)
	userInfo := &user.User{
		Password:          pwd,
		Algo:              "default",
		OnlyFirstPurchase: &l.svcCtx.Config.Invite.OnlyFirstPurchase,
		AuthMethods: []user.AuthMethods{
			{
				AuthType:       "mobile",
				AuthIdentifier: phoneNumber,
				Verified:       true,
			},
		},
	}
	if referer != nil {
		userInfo.RefererId = referer.Id
	}

	err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
		// Save user information
		if err := db.Create(userInfo).Error; err != nil {
			return err
		}
		// The refer code is derived from the auto-assigned id, so it can only be
		// set after the insert.
		userInfo.ReferCode = uuidx.UserInviteCode(userInfo.Id)
		if err := db.Model(&user.User{}).Where("id = ?", userInfo.Id).Update("refer_code", userInfo.ReferCode).Error; err != nil {
			return err
		}
		if l.svcCtx.Config.Register.EnableTrial {
			// NOTE(review): activeTrial writes through UserModel.InsertSubscribe,
			// not through db. Presumably the trial row is therefore not part of
			// this transaction. Confirm whether InsertSubscribe can accept db.
			var trialErr error
			trialSubscribe, trialErr = l.activeTrial(userInfo.Id)
			if trialErr != nil {
				return trialErr
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	// Clear caches only after the transaction has committed. trialSubscribe is
	// non-nil only if the trial was activated inside the transaction.
	if trialSubscribe != nil {
		l.clearTrialSubscribeCache(trialSubscribe)
	}

	// Bind device to user if identifier is provided
	if req.Identifier != "" {
		bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
		if bindErr := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); bindErr != nil {
			// Don't fail register if device binding fails, just log the error
			l.Errorw("failed to bind device to user",
				logger.Field("user_id", userInfo.Id),
				logger.Field("identifier", req.Identifier),
				logger.Field("error", bindErr.Error()),
			)
		}
	}

	// Use the comma-ok form so a non-string context value cannot panic the request.
	if loginType, ok := l.ctx.Value(constant.CtxLoginType).(string); ok {
		req.LoginType = loginType
	}

	// Generate session id
	sessionId := uuidx.NewUUID().String()
	// Generate token
	token, err := jwt.NewJwtToken(
		l.svcCtx.Config.JwtAuth.AccessSecret,
		time.Now().Unix(),
		l.svcCtx.Config.JwtAuth.AccessExpire,
		jwt.WithOption("UserId", userInfo.Id),
		jwt.WithOption("SessionId", sessionId),
		jwt.WithOption("identifier", req.Identifier),
		jwt.WithOption("CtxLoginType", req.LoginType),
	)
	if err != nil {
		l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
	}
	// The session key lives as long as the token (AccessExpire seconds).
	sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
	if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
	}

	if token != "" && userInfo.Id != 0 {
		l.recordRegisterLogs(req, phoneNumber, userInfo.Id)
	}
	return &types.LoginResponse{
		Token: token,
	}, nil
}

// clearTrialSubscribeCache invalidates the caches affected by a newly created
// trial subscription: the user-subscription cache for sub, the SubscribeModel
// cache for sub.SubscribeId, and all server caches. Failures are logged and
// otherwise ignored, so a committed registration is not failed by a cache error.
func (l *TelephoneUserRegisterLogic) clearTrialSubscribeCache(sub *user.Subscribe) {
	if err := l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, sub); err != nil {
		l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", sub.Id))
	}
	if err := l.svcCtx.SubscribeModel.ClearCache(l.ctx, sub.SubscribeId); err != nil {
		l.Errorw("SubscribeModel.ClearCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", sub.SubscribeId))
	}
	if err := l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
		l.Errorf("ClearServerAllCache error: %v", err.Error())
	}
}

// recordRegisterLogs inserts the login and register system-log entries for a
// newly registered user. It is called only after a token has been issued, so
// the login entry is recorded as successful. Insert failures are logged, not
// returned.
func (l *TelephoneUserRegisterLogic) recordRegisterLogs(req *types.TelephoneRegisterRequest, phoneNumber string, userID int64) {
	loginLog := log.Login{
		Method:    "mobile",
		LoginIP:   req.IP,
		UserAgent: req.UserAgent,
		Success:   true,
		Timestamp: time.Now().UnixMilli(),
	}
	// NOTE(review): Marshal errors are discarded; the entry is then inserted
	// with whatever content Marshal returned.
	content, _ := loginLog.Marshal()
	if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{
		Id:       0,
		Type:     log.TypeLogin.Uint8(),
		Date:     time.Now().Format("2006-01-02"),
		ObjectID: userID,
		Content:  string(content),
	}); err != nil {
		l.Errorw("failed to insert login log",
			logger.Field("user_id", userID),
			logger.Field("ip", req.IP),
			logger.Field("error", err.Error()),
		)
	}

	registerLog := log.Register{
		AuthMethod: "mobile",
		Identifier: phoneNumber,
		RegisterIP: req.IP,
		UserAgent:  req.UserAgent,
		Timestamp:  time.Now().UnixMilli(),
	}
	content, _ = registerLog.Marshal()
	if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{
		Type:     log.TypeRegister.Uint8(),
		ObjectID: userID,
		Date:     time.Now().Format("2006-01-02"),
		Content:  string(content),
	}); err != nil {
		l.Errorw("failed to insert register log",
			logger.Field("user_id", userID),
			logger.Field("ip", req.IP),
			logger.Field("error", err.Error()),
		)
	}
}
// activeTrial creates a trial subscription for user uid from the plan configured
// in Register.TrialSubscribe. It returns the inserted row so the caller can
// invalidate related caches once its transaction has committed.
//
// The subscription starts now. ExpireTime is tool.AddTime(TrialTimeUnit,
// TrialTime, now), and the traffic quota is copied from the plan. Both
// timestamps come from a single time.Now() reading, so StartTime and the base
// for ExpireTime agree exactly.
//
// It returns the error from SubscribeModel.FindOne or UserModel.InsertSubscribe
// unchanged.
//
// NOTE(review): the row is written via UserModel.InsertSubscribe with l.ctx, not
// through a caller-supplied transaction handle. Presumably it is not rolled back
// if the caller's transaction fails. Confirm.
func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) {
	cfg := l.svcCtx.Config.Register
	sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, cfg.TrialSubscribe)
	if err != nil {
		return nil, err
	}
	now := time.Now()
	userSub := &user.Subscribe{
		Id:          0,
		UserId:      uid,
		OrderId:     0,
		SubscribeId: sub.Id,
		StartTime:   now,
		ExpireTime:  tool.AddTime(cfg.TrialTimeUnit, cfg.TrialTime, now),
		Traffic:     sub.Traffic,
		Download:    0,
		Upload:      0,
		Token:       uuidx.SubscribeToken(fmt.Sprintf("Trial-%v", uid)),
		UUID:        uuidx.NewUUID().String(),
		Status:      1,
	}
	if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub); err != nil {
		return nil, err
	}
	return userSub, nil
}