fix: resolve trial subscription cache issue on new user registration

When new users register with trial subscription enabled, the subscription
link fails to connect in Clash clients. This is caused by missing cache
invalidation after transaction commit.

Changes:
- Add cache clearing after successful trial subscription creation
- Clear user subscription cache, subscription details cache, and server cache
- Modify activeTrial functions to return subscription object for cache clearing
- Apply fix to all registration methods: email, phone, device, and OAuth

This ensures subscription links work immediately after registration without
requiring manual subscription reset.
This commit is contained in:
Rust 2026-01-22 23:43:49 +07:00
parent 5f55b1242e
commit 5f1a546bbe
4 changed files with 121 additions and 29 deletions

View File

@ -152,6 +152,7 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
) )
var userInfo *user.User var userInfo *user.User
var trialSubscribe *user.Subscribe
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Create new user // Create new user
userInfo = &user.User{ userInfo = &user.User{
@ -212,8 +213,10 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
// Activate trial if enabled // Activate trial if enabled
if l.svcCtx.Config.Register.EnableTrial { if l.svcCtx.Config.Register.EnableTrial {
if err := l.activeTrial(userInfo.Id, db); err != nil { var trialErr error
return err trialSubscribe, trialErr = l.activeTrial(userInfo.Id, db)
if trialErr != nil {
return trialErr
} }
} }
@ -228,6 +231,25 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
return nil, err return nil, err
} }
// Clear cache after transaction success
if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil {
// Clear user subscription cache
if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id))
// Don't return error, just log it
}
// Clear subscription cache
if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId))
// Don't return error, just log it
}
// Clear all server cache
if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
l.Errorf("ClearServerAllCache error: %v", err.Error())
// Don't return error, just log it
}
}
l.Infow("device registration completed successfully", l.Infow("device registration completed successfully",
logger.Field("user_id", userInfo.Id), logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier), logger.Field("identifier", req.Identifier),
@ -260,7 +282,7 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
return userInfo, nil return userInfo, nil
} }
func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) error { func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) (*user.Subscribe, error) {
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe) sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
if err != nil { if err != nil {
l.Errorw("failed to find trial subscription template", l.Errorw("failed to find trial subscription template",
@ -268,7 +290,7 @@ func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) error {
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe), logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
) )
return err return nil, err
} }
startTime := time.Now() startTime := time.Now()
@ -295,7 +317,7 @@ func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) error {
logger.Field("user_id", userId), logger.Field("user_id", userId),
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
) )
return err return nil, err
} }
l.Infow("trial subscription activated successfully", l.Infow("trial subscription activated successfully",
@ -305,8 +327,5 @@ func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) error {
logger.Field("traffic", sub.Traffic), logger.Field("traffic", sub.Traffic),
) )
if clearErr := l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); clearErr != nil { return userSub, nil
l.Errorf("ClearServerAllCache error: %v", clearErr.Error())
}
return nil
} }

View File

@ -341,6 +341,7 @@ func (l *OAuthLoginGetTokenLogic) register(email, avatar, method, openid, reques
} }
var userInfo *user.User var userInfo *user.User
var trialSubscribe *user.Subscribe
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
if email != "" { if email != "" {
l.Debugw("checking if email already exists", l.Debugw("checking if email already exists",
@ -397,8 +398,10 @@ func (l *OAuthLoginGetTokenLogic) register(email, avatar, method, openid, reques
logger.Field("request_id", requestID), logger.Field("request_id", requestID),
logger.Field("user_id", userInfo.Id), logger.Field("user_id", userInfo.Id),
) )
if err := l.activeTrial(userInfo.Id, requestID); err != nil { var trialErr error
return err trialSubscribe, trialErr = l.activeTrial(userInfo.Id, requestID)
if trialErr != nil {
return trialErr
} }
} }
@ -415,6 +418,25 @@ func (l *OAuthLoginGetTokenLogic) register(email, avatar, method, openid, reques
return userInfo, err return userInfo, err
} }
// Clear cache after transaction success
if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil {
// Clear user subscription cache
if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id))
// Don't return error, just log it
}
// Clear subscription cache
if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId))
// Don't return error, just log it
}
// Clear all server cache
if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
l.Errorf("ClearServerAllCache error: %v", err.Error())
// Don't return error, just log it
}
}
l.Infow("user registration completed successfully", l.Infow("user registration completed successfully",
logger.Field("request_id", requestID), logger.Field("request_id", requestID),
logger.Field("user_id", userInfo.Id), logger.Field("user_id", userInfo.Id),
@ -793,7 +815,7 @@ func (l *OAuthLoginGetTokenLogic) findOrRegisterUser(authType, openID, email, av
return userInfo, nil return userInfo, nil
} }
func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) error { func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) (*user.Subscribe, error) {
l.Debugw("fetching trial subscription template", l.Debugw("fetching trial subscription template",
logger.Field("request_id", requestID), logger.Field("request_id", requestID),
logger.Field("user_id", uid), logger.Field("user_id", uid),
@ -808,7 +830,7 @@ func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) error
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe), logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
) )
return err return nil, err
} }
startTime := time.Now() startTime := time.Now()
@ -848,7 +870,7 @@ func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) error
logger.Field("user_id", uid), logger.Field("user_id", uid),
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
) )
return err return nil, err
} }
l.Infow("trial subscription activated successfully", l.Infow("trial subscription activated successfully",
@ -858,5 +880,5 @@ func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) error
logger.Field("expire_time", expireTime), logger.Field("expire_time", expireTime),
logger.Field("traffic", sub.Traffic), logger.Field("traffic", sub.Traffic),
) )
return nil return userSub, nil
} }

View File

@ -45,6 +45,7 @@ func NewTelephoneUserRegisterLogic(ctx context.Context, svcCtx *svc.ServiceConte
func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneRegisterRequest) (resp *types.LoginResponse, err error) { func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneRegisterRequest) (resp *types.LoginResponse, err error) {
c := l.svcCtx.Config.Register c := l.svcCtx.Config.Register
var trialSubscribe *user.Subscribe
// Check if the registration is stopped // Check if the registration is stopped
if c.StopRegister { if c.StopRegister {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register") return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register")
@ -135,12 +136,36 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR
} }
if l.svcCtx.Config.Register.EnableTrial { if l.svcCtx.Config.Register.EnableTrial {
// Active trial // Active trial
if err = l.activeTrial(userInfo.Id); err != nil { var trialErr error
return err trialSubscribe, trialErr = l.activeTrial(userInfo.Id)
if trialErr != nil {
return trialErr
} }
} }
return nil return nil
}) })
if err != nil {
return nil, err
}
// Clear cache after transaction success
if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil {
// Clear user subscription cache
if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id))
// Don't return error, just log it
}
// Clear subscription cache
if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId))
// Don't return error, just log it
}
// Clear all server cache
if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
l.Errorf("ClearServerAllCache error: %v", err.Error())
// Don't return error, just log it
}
}
// Bind device to user if identifier is provided // Bind device to user if identifier is provided
if req.Identifier != "" { if req.Identifier != "" {
@ -229,10 +254,10 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR
}, nil }, nil
} }
func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) error { func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) {
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe) sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
if err != nil { if err != nil {
return err return nil, err
} }
userSub := &user.Subscribe{ userSub := &user.Subscribe{
Id: 0, Id: 0,
@ -250,10 +275,8 @@ func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) error {
} }
err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub) err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub)
if err != nil { if err != nil {
return err return nil, err
} }
if clearErr := l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); clearErr != nil { return userSub, nil
l.Errorf("ClearServerAllCache error: %v", clearErr.Error())
}
return err
} }

View File

@ -42,6 +42,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
c := l.svcCtx.Config.Register c := l.svcCtx.Config.Register
email := l.svcCtx.Config.Email email := l.svcCtx.Config.Email
var referer *user.User var referer *user.User
var trialSubscribe *user.Subscribe
// Check if the registration is stopped // Check if the registration is stopped
if c.StopRegister { if c.StopRegister {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register") return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register")
@ -127,12 +128,36 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
if l.svcCtx.Config.Register.EnableTrial { if l.svcCtx.Config.Register.EnableTrial {
// Active trial // Active trial
if err = l.activeTrial(userInfo.Id); err != nil { var trialErr error
return err trialSubscribe, trialErr = l.activeTrial(userInfo.Id)
if trialErr != nil {
return trialErr
} }
} }
return nil return nil
}) })
if err != nil {
return nil, err
}
// Clear cache after transaction success
if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil {
// Clear user subscription cache
if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id))
// Don't return error, just log it
}
// Clear subscription cache
if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId))
// Don't return error, just log it
}
// Clear all server cache
if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
l.Errorf("ClearServerAllCache error: %v", err.Error())
// Don't return error, just log it
}
}
// Bind device to user if identifier is provided // Bind device to user if identifier is provided
if req.Identifier != "" { if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
@ -220,10 +245,10 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
}, nil }, nil
} }
func (l *UserRegisterLogic) activeTrial(uid int64) error { func (l *UserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) {
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe) sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
if err != nil { if err != nil {
return err return nil, err
} }
userSub := &user.Subscribe{ userSub := &user.Subscribe{
UserId: uid, UserId: uid,
@ -238,5 +263,8 @@ func (l *UserRegisterLogic) activeTrial(uid int64) error {
UUID: uuidx.NewUUID().String(), UUID: uuidx.NewUUID().String(),
Status: 1, Status: 1,
} }
return l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub) if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub); err != nil {
return nil, err
}
return userSub, nil
} }