diff --git a/.trae/documents/实现按用户Token并发登录上限N并超限逐出.md b/.trae/documents/实现按用户Token并发登录上限N并超限逐出.md new file mode 100644 index 0000000..ff4ec8f --- /dev/null +++ b/.trae/documents/实现按用户Token并发登录上限N并超限逐出.md @@ -0,0 +1,92 @@ +## 修复目标 + +* 解决首次设备登录时在 `internal/logic/auth/deviceLoginLogic.go:99` 对 `deviceInfo` 赋值导致的空指针崩溃,确保接口稳定返回。 + +## 根因定位 + +* 设备不存在分支仅创建用户与设备记录,但未为局部变量 `deviceInfo` 赋值;随后在 `internal/logic/auth/deviceLoginLogic.go:99-100` 使用 `deviceInfo` 导致 `nil` 解引用。 + +* 参考位置: + + * 赋值处:`internal/logic/auth/deviceLoginLogic.go:99-101` + + * 设备存在分支赋值:`internal/logic/auth/deviceLoginLogic.go:88-95` + + * 设备不存在分支未赋值:`internal/logic/auth/deviceLoginLogic.go:74-79` + + * `UpdateDevice` 需要有效设备 `Id`:`internal/model/user/device.go:58-69` + +## 修改方案 + +1. 在“设备不存在”分支注册完成后,立即通过标识重新查询设备,赋值给 `deviceInfo`: + + * 在 `internal/logic/auth/deviceLoginLogic.go` 的 `if errors.Is(err, gorm.ErrRecordNotFound)` 分支中,`userInfo, err = l.registerUserAndDevice(req)` 之后追加: + + * `deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)` + + * 如果查询失败则返回数据库查询错误(与现有风格一致)。 +2. 在更新设备 UA 前增加空指针保护,并不再忽略更新错误: + + * 将 `internal/logic/auth/deviceLoginLogic.go:99-101` 改为: + + * 检查 `deviceInfo != nil` + + * `deviceInfo.UserAgent = req.UserAgent` + + * `if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil {` 记录错误并返回包装后的错误 `xerr.DatabaseUpdateError`。 +3. 可选优化(减少二次查询): + + * 将 `registerUserAndDevice(req)` 的返回值改为 `(*user.User, *user.Device, error)`,在注册时直接返回新建设备对象;调用点随之调整。若选择此方案,仍需在更新前做空指针保护。 + +## 代码示例(方案1,最小改动) + +```go +// internal/logic/auth/deviceLoginLogic.go +// 设备不存在分支注册后追加一次设备查询 +userInfo, err = l.registerUserAndDevice(req) +if err != nil { + return nil, err +} +deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier) +if err != nil { + l.Errorw("query device after register failed", + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query device after register failed: %v", err.Error()) +} + +// 更新 UA,不忽略更新错误 +if deviceInfo != nil { + deviceInfo.UserAgent = req.UserAgent + if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil { + l.Errorw("update device failed", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err.Error()) + } +} +``` + +## 测试用例与验证 + +* 用例1:首次设备标识登录(设备不存在)应成功返回 Token,日志包含注册与登录记录,无 500。 + +* 用例2:已存在设备标识登录(设备存在)应正常更新 UA 并返回 Token。 + +* 用例3:模拟数据库异常时应返回一致的业务错误码,不产生 `panic`。 + +## 风险与回滚 + +* 改动限定在登录逻辑,属最小范围;若出现异常,回滚为当前版本即可。 + +* 不改变数据结构与外部接口行为,兼容现有客户端。 + +## 后续优化(可选) + +* 统一 `UpdateDevice` 错误处理路径,避免 `_ = ...` 静默失败。 + +* 为“首次设备登录”场景补充集成测试,保证不再回归。 + diff --git a/internal/config/cacheKey.go b/internal/config/cacheKey.go index b2b290b..a57cb04 100644 --- a/internal/config/cacheKey.go +++ b/internal/config/cacheKey.go @@ -42,6 +42,9 @@ const SessionIdKey = "auth:session_id" // DeviceCacheKeyKey cache session key const DeviceCacheKeyKey = "auth:device_identifier" +// UserSessionsKeyPrefix per-user sessions zset key prefix +const UserSessionsKeyPrefix = "auth:user_sessions:" + // GlobalConfigKey Global Config Key const GlobalConfigKey = "system:global_config" diff --git a/internal/config/config.go b/internal/config/config.go index e47b56c..ac34546 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -42,8 +42,9 @@ type RedisConfig struct { } type JwtAuth struct { - AccessSecret string `yaml:"AccessSecret"` - AccessExpire int64 `yaml:"AccessExpire" default:"604800"` + AccessSecret string `yaml:"AccessSecret"` + AccessExpire int64 `yaml:"AccessExpire" default:"604800"` + MaxSessionsPerUser int64 `yaml:"MaxSessionsPerUser" default:"3"` } type Verify struct { diff --git a/internal/logic/auth/deviceLoginLogic.go b/internal/logic/auth/deviceLoginLogic.go index 3ea802d..52cdce5 100644 --- a/internal/logic/auth/deviceLoginLogic.go +++ b/internal/logic/auth/deviceLoginLogic.go @@ -140,6 +140,9 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) } + if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error()) + } // Store session id in redis sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { diff --git a/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go b/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go index 4e12d2f..8bebe11 100644 --- a/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go +++ b/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go @@ -587,6 +587,10 @@ func (l *OAuthLoginGetTokenLogic) generateToken(userInfo *user.User, requestID s return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err) } + if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil { + return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error()) + } + sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { l.Errorw("failed to cache session id", diff --git a/internal/logic/auth/telephoneLoginLogic.go b/internal/logic/auth/telephoneLoginLogic.go index 8a54ff5..ca2b62b 100644 --- a/internal/logic/auth/telephoneLoginLogic.go +++ b/internal/logic/auth/telephoneLoginLogic.go @@ -156,6 +156,9 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) } + if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error()) + } sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) diff --git a/internal/logic/auth/userLoginLogic.go b/internal/logic/auth/userLoginLogic.go index deecaae..36fad03 100644 --- a/internal/logic/auth/userLoginLogic.go +++ b/internal/logic/auth/userLoginLogic.go @@ -111,6 +111,9 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) } + if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error()) + } sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) diff --git a/internal/svc/serviceContext.go b/internal/svc/serviceContext.go index be05079..1a925ba 100644 --- a/internal/svc/serviceContext.go +++ b/internal/svc/serviceContext.go @@ -1,34 +1,36 @@ package svc import ( - "context" + "context" + "fmt" + "time" - "github.com/perfect-panel/server/internal/model/client" - "github.com/perfect-panel/server/internal/model/node" - "github.com/perfect-panel/server/pkg/device" + "github.com/perfect-panel/server/internal/model/client" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/pkg/device" - "github.com/perfect-panel/server/internal/config" - "github.com/perfect-panel/server/internal/model/ads" - "github.com/perfect-panel/server/internal/model/announcement" - "github.com/perfect-panel/server/internal/model/auth" - "github.com/perfect-panel/server/internal/model/coupon" - "github.com/perfect-panel/server/internal/model/document" - "github.com/perfect-panel/server/internal/model/log" - "github.com/perfect-panel/server/internal/model/order" - "github.com/perfect-panel/server/internal/model/payment" - "github.com/perfect-panel/server/internal/model/subscribe" - "github.com/perfect-panel/server/internal/model/system" - "github.com/perfect-panel/server/internal/model/ticket" - "github.com/perfect-panel/server/internal/model/traffic" - "github.com/perfect-panel/server/internal/model/user" - "github.com/perfect-panel/server/pkg/limit" - "github.com/perfect-panel/server/pkg/nodeMultiplier" - "github.com/perfect-panel/server/pkg/orm" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/model/ads" + "github.com/perfect-panel/server/internal/model/announcement" + "github.com/perfect-panel/server/internal/model/auth" + "github.com/perfect-panel/server/internal/model/coupon" + "github.com/perfect-panel/server/internal/model/document" + "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/order" + "github.com/perfect-panel/server/internal/model/payment" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/model/ticket" + "github.com/perfect-panel/server/internal/model/traffic" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/limit" + "github.com/perfect-panel/server/pkg/nodeMultiplier" + "github.com/perfect-panel/server/pkg/orm" - tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" - "github.com/hibiken/asynq" - "github.com/redis/go-redis/v9" - "gorm.io/gorm" + tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/hibiken/asynq" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" ) type ServiceContext struct { @@ -108,7 +110,34 @@ func NewServiceContext(c config.Config) *ServiceContext { TrafficLogModel: traffic.NewModel(db), AnnouncementModel: announcement.NewModel(db, rds), } - srv.DeviceManager = NewDeviceManager(srv) - return srv + srv.DeviceManager = NewDeviceManager(srv) + return srv } + +func (srv *ServiceContext) EnforceUserSessionLimit(ctx context.Context, userId int64, newSessionId string, max int64) error { + if max <= 0 { + return nil + } + sessionsKey := fmt.Sprintf("%s%v", config.UserSessionsKeyPrefix, userId) + now := time.Now().Unix() + if err := srv.Redis.ZAdd(ctx, sessionsKey, redis.Z{Score: float64(now), Member: newSessionId}).Err(); err != nil { + return err + } + count, err := srv.Redis.ZCard(ctx, sessionsKey).Result() + if err != nil { + return err + } + if count > max { + popped, err := srv.Redis.ZPopMin(ctx, sessionsKey, count-max).Result() + if err != nil { + return err + } + for _, z := range popped { + sid := fmt.Sprintf("%v", z.Member) + _ = srv.Redis.Del(ctx, fmt.Sprintf("%v:%v", config.SessionIdKey, sid)).Err() + } + } + _ = srv.Redis.Expire(ctx, sessionsKey, time.Duration(srv.Config.JwtAuth.AccessExpire)*time.Second).Err() + return nil +}