hi-server/internal/logic/auth/deviceLoginLogic.go
Rust 5f1a546bbe fix: resolve trial subscription cache issue on new user registration
When new users register with trial subscription enabled, the subscription
link fails to connect in Clash clients. This is caused by missing cache
invalidation after transaction commit.

Changes:
- Add cache clearing after successful trial subscription creation
- Clear user subscription cache, subscription details cache, and server cache
- Modify activeTrial functions to return subscription object for cache clearing
- Apply fix to all registration methods: email, phone, device, and OAuth

This ensures subscription links work immediately after registration without
requiring manual subscription reset.
2026-01-22 23:57:15 +07:00

332 lines
10 KiB
Go

package auth
import (
"context"
"fmt"
"time"
"github.com/perfect-panel/server/internal/config"
"github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/jwt"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/uuidx"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
"gorm.io/gorm"
)
type DeviceLoginLogic struct {
logger.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
// Device Login
func NewDeviceLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *DeviceLoginLogic {
return &DeviceLoginLogic{
Logger: logger.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *types.LoginResponse, err error) {
if !l.svcCtx.Config.Device.Enable {
return nil, xerr.NewErrMsg("Device login is disabled")
}
loginStatus := false
var userInfo *user.User
// Record login status
defer func() {
if userInfo != nil && userInfo.Id != 0 {
loginLog := log.Login{
Method: "device",
LoginIP: req.IP,
UserAgent: req.UserAgent,
Success: loginStatus,
Timestamp: time.Now().UnixMilli(),
}
content, _ := loginLog.Marshal()
if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{
Type: log.TypeLogin.Uint8(),
Date: time.Now().Format("2006-01-02"),
ObjectID: userInfo.Id,
Content: string(content),
}); err != nil {
l.Errorw("failed to insert login log",
logger.Field("user_id", userInfo.Id),
logger.Field("ip", req.IP),
logger.Field("error", err.Error()),
)
}
}
}()
// Check if device exists by identifier
deviceInfo, err := l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if !registerIpLimit(l.svcCtx, l.ctx, req.IP, "device", req.Identifier) {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.RegisterIPLimit), "register ip limit: %v", req.IP)
}
// Device not found, create new user and device
userInfo, err = l.registerUserAndDevice(req)
if err != nil {
return nil, err
}
} else {
l.Errorw("query device failed",
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query device failed: %v", err.Error())
}
} else {
// Device found, get user info
userInfo, err = l.svcCtx.UserModel.FindOne(l.ctx, deviceInfo.UserId)
if err != nil {
l.Errorw("query user failed",
logger.Field("user_id", deviceInfo.UserId),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user failed: %v", err.Error())
}
}
// Generate session id
sessionId := uuidx.NewUUID().String()
// Generate token
token, err := jwt.NewJwtToken(
l.svcCtx.Config.JwtAuth.AccessSecret,
time.Now().Unix(),
l.svcCtx.Config.JwtAuth.AccessExpire,
jwt.WithOption("UserId", userInfo.Id),
jwt.WithOption("SessionId", sessionId),
jwt.WithOption("LoginType", "device"),
)
if err != nil {
l.Errorw("token generate error",
logger.Field("user_id", userInfo.Id),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
}
// Store session id in redis
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
l.Errorw("set session id error",
logger.Field("user_id", userInfo.Id),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
}
// Store device id in redis
deviceCacheKey := fmt.Sprintf("%v:%v", config.DeviceCacheKeyKey, req.Identifier)
if err = l.svcCtx.Redis.Set(l.ctx, deviceCacheKey, sessionId, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
l.Errorw("set device id error",
logger.Field("user_id", userInfo.Id),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set device id error: %v", err.Error())
}
loginStatus = true
return &types.LoginResponse{
Token: token,
}, nil
}
func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest) (*user.User, error) {
l.Infow("device not found, creating new user and device",
logger.Field("identifier", req.Identifier),
logger.Field("ip", req.IP),
)
var userInfo *user.User
var trialSubscribe *user.Subscribe
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Create new user
userInfo = &user.User{
Salt: "default",
OnlyFirstPurchase: &l.svcCtx.Config.Invite.OnlyFirstPurchase,
}
if err := db.Create(userInfo).Error; err != nil {
l.Errorw("failed to create user",
logger.Field("error", err.Error()),
)
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create user failed: %v", err)
}
// Update refer code
userInfo.ReferCode = uuidx.UserInviteCode(userInfo.Id)
if err := db.Model(&user.User{}).Where("id = ?", userInfo.Id).Update("refer_code", userInfo.ReferCode).Error; err != nil {
l.Errorw("failed to update refer code",
logger.Field("user_id", userInfo.Id),
logger.Field("error", err.Error()),
)
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update refer code failed: %v", err)
}
// Create device auth method
authMethod := &user.AuthMethods{
UserId: userInfo.Id,
AuthType: "device",
AuthIdentifier: req.Identifier,
Verified: true,
}
if err := db.Create(authMethod).Error; err != nil {
l.Errorw("failed to create device auth method",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device auth method failed: %v", err)
}
// Insert device record
deviceInfo := &user.Device{
Ip: req.IP,
UserId: userInfo.Id,
UserAgent: req.UserAgent,
Identifier: req.Identifier,
ShortCode: req.ShortCode,
Enabled: true,
Online: false,
}
if err := db.Create(deviceInfo).Error; err != nil {
l.Errorw("failed to insert device",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert device failed: %v", err)
}
// Activate trial if enabled
if l.svcCtx.Config.Register.EnableTrial {
var trialErr error
trialSubscribe, trialErr = l.activeTrial(userInfo.Id, db)
if trialErr != nil {
return trialErr
}
}
return nil
})
if err != nil {
l.Errorw("device registration failed",
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return nil, err
}
// Clear cache after transaction success
if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil {
// Clear user subscription cache
if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id))
// Don't return error, just log it
}
// Clear subscription cache
if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil {
l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId))
// Don't return error, just log it
}
// Clear all server cache
if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil {
l.Errorf("ClearServerAllCache error: %v", err.Error())
// Don't return error, just log it
}
}
l.Infow("device registration completed successfully",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("refer_code", userInfo.ReferCode),
)
// Register log
registerLog := log.Register{
AuthMethod: "device",
Identifier: req.Identifier,
RegisterIP: req.IP,
UserAgent: req.UserAgent,
Timestamp: time.Now().UnixMilli(),
}
content, _ := registerLog.Marshal()
if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{
Type: log.TypeRegister.Uint8(),
Date: time.Now().Format("2006-01-02"),
ObjectID: userInfo.Id,
Content: string(content),
}); err != nil {
l.Errorw("failed to insert register log",
logger.Field("user_id", userInfo.Id),
logger.Field("ip", req.IP),
logger.Field("error", err.Error()),
)
}
return userInfo, nil
}
func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) (*user.Subscribe, error) {
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
if err != nil {
l.Errorw("failed to find trial subscription template",
logger.Field("user_id", userId),
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
logger.Field("error", err.Error()),
)
return nil, err
}
startTime := time.Now()
expireTime := tool.AddTime(l.svcCtx.Config.Register.TrialTimeUnit, l.svcCtx.Config.Register.TrialTime, startTime)
subscribeToken := uuidx.SubscribeToken(fmt.Sprintf("Trial-%v", userId))
subscribeUUID := uuidx.NewUUID().String()
userSub := &user.Subscribe{
UserId: userId,
OrderId: 0,
SubscribeId: sub.Id,
StartTime: startTime,
ExpireTime: expireTime,
Traffic: sub.Traffic,
Download: 0,
Upload: 0,
Token: subscribeToken,
UUID: subscribeUUID,
Status: 1,
}
if err := db.Create(userSub).Error; err != nil {
l.Errorw("failed to insert trial subscription",
logger.Field("user_id", userId),
logger.Field("error", err.Error()),
)
return nil, err
}
l.Infow("trial subscription activated successfully",
logger.Field("user_id", userId),
logger.Field("subscribe_id", sub.Id),
logger.Field("expire_time", expireTime),
logger.Field("traffic", sub.Traffic),
)
return userSub, nil
}