feat(auth): 实现用户会话数限制功能
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m32s

添加用户会话数限制功能,当超过最大会话数时自动移除最旧的会话
- 在config中添加UserSessionsKeyPrefix常量
- 在JwtAuth配置中新增MaxSessionsPerUser字段
- 在ServiceContext中实现EnforceUserSessionLimit方法
- 在所有登录逻辑中调用会话限制检查
This commit is contained in:
shanshanzhong 2025-11-26 17:52:12 -08:00
parent 4ad384b01a
commit 1d5d361ae8
8 changed files with 167 additions and 29 deletions

View File

@ -0,0 +1,92 @@
## 修复目标
* 解决首次设备登录时在 `internal/logic/auth/deviceLoginLogic.go:99``deviceInfo` 赋值导致的空指针崩溃,确保接口稳定返回。
## 根因定位
* 设备不存在分支仅创建用户与设备记录,但未为局部变量 `deviceInfo` 赋值;随后在 `internal/logic/auth/deviceLoginLogic.go:99-100` 使用 `deviceInfo` 导致 `nil` 解引用。
* 参考位置:
* 赋值处:`internal/logic/auth/deviceLoginLogic.go:99-101`
* 设备存在分支赋值:`internal/logic/auth/deviceLoginLogic.go:88-95`
* 设备不存在分支未赋值:`internal/logic/auth/deviceLoginLogic.go:74-79`
* `UpdateDevice` 需要有效设备 `Id``internal/model/user/device.go:58-69`
## 修改方案
1. 在“设备不存在”分支注册完成后,立即通过标识重新查询设备,赋值给 `deviceInfo`
* 在 `internal/logic/auth/deviceLoginLogic.go``if errors.Is(err, gorm.ErrRecordNotFound)` 分支中,`userInfo, err = l.registerUserAndDevice(req)` 之后追加:
* `deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)`
* 如果查询失败则返回数据库查询错误(与现有风格一致)。
2. 在更新设备 UA 前增加空指针保护,并不再忽略更新错误:
* 将 `internal/logic/auth/deviceLoginLogic.go:99-101` 改为:
* 检查 `deviceInfo != nil`
* `deviceInfo.UserAgent = req.UserAgent`
* `if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil {` 记录错误并返回包装后的错误 `xerr.DatabaseUpdateError`
3. 可选优化(减少二次查询):
* 将 `registerUserAndDevice(req)` 的返回值改为 `(*user.User, *user.Device, error)`,在注册时直接返回新建设备对象;调用点随之调整。若选择此方案,仍需在更新前做空指针保护。
## 代码示例方案1最小改动
```go
// internal/logic/auth/deviceLoginLogic.go
// 设备不存在分支注册后追加一次设备查询
userInfo, err = l.registerUserAndDevice(req)
if err != nil {
return nil, err
}
deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)
if err != nil {
l.Errorw("query device after register failed",
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query device after register failed: %v", err.Error())
}
// 更新 UA不忽略更新错误
if deviceInfo != nil {
deviceInfo.UserAgent = req.UserAgent
if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil {
l.Errorw("update device failed",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err.Error())
}
}
```
## 测试用例与验证
* 用例1首次设备标识登录设备不存在应成功返回 Token日志包含注册与登录记录无 500。
* 用例2已存在设备标识登录设备存在应正常更新 UA 并返回 Token。
* 用例3模拟数据库异常时应返回一致的业务错误码不产生 `panic`
## 风险与回滚
* 改动限定在登录逻辑,属最小范围;若出现异常,回滚为当前版本即可。
* 不改变数据结构与外部接口行为,兼容现有客户端。
## 后续优化(可选)
* 统一 `UpdateDevice` 错误处理路径,避免 `_ = ...` 静默失败。
* 为“首次设备登录”场景补充集成测试,保证不再回归。

View File

@ -42,6 +42,9 @@ const SessionIdKey = "auth:session_id"
// DeviceCacheKeyKey cache session key // DeviceCacheKeyKey cache session key
const DeviceCacheKeyKey = "auth:device_identifier" const DeviceCacheKeyKey = "auth:device_identifier"
// UserSessionsKeyPrefix per-user sessions zset key prefix
const UserSessionsKeyPrefix = "auth:user_sessions:"
// GlobalConfigKey Global Config Key // GlobalConfigKey Global Config Key
const GlobalConfigKey = "system:global_config" const GlobalConfigKey = "system:global_config"

View File

@ -44,6 +44,7 @@ type RedisConfig struct {
type JwtAuth struct { type JwtAuth struct {
AccessSecret string `yaml:"AccessSecret"` AccessSecret string `yaml:"AccessSecret"`
AccessExpire int64 `yaml:"AccessExpire" default:"604800"` AccessExpire int64 `yaml:"AccessExpire" default:"604800"`
MaxSessionsPerUser int64 `yaml:"MaxSessionsPerUser" default:"3"`
} }
type Verify struct { type Verify struct {

View File

@ -140,6 +140,9 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
} }
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
// Store session id in redis // Store session id in redis
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {

View File

@ -587,6 +587,10 @@ func (l *OAuthLoginGetTokenLogic) generateToken(userInfo *user.User, requestID s
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err) return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err)
} }
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
l.Errorw("failed to cache session id", l.Errorw("failed to cache session id",

View File

@ -156,6 +156,9 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
} }
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())

View File

@ -111,6 +111,9 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error())) l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
} }
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())

View File

@ -2,6 +2,8 @@ package svc
import ( import (
"context" "context"
"fmt"
"time"
"github.com/perfect-panel/server/internal/model/client" "github.com/perfect-panel/server/internal/model/client"
"github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/model/node"
@ -112,3 +114,30 @@ func NewServiceContext(c config.Config) *ServiceContext {
return srv return srv
} }
func (srv *ServiceContext) EnforceUserSessionLimit(ctx context.Context, userId int64, newSessionId string, max int64) error {
if max <= 0 {
return nil
}
sessionsKey := fmt.Sprintf("%s%v", config.UserSessionsKeyPrefix, userId)
now := time.Now().Unix()
if err := srv.Redis.ZAdd(ctx, sessionsKey, redis.Z{Score: float64(now), Member: newSessionId}).Err(); err != nil {
return err
}
count, err := srv.Redis.ZCard(ctx, sessionsKey).Result()
if err != nil {
return err
}
if count > max {
popped, err := srv.Redis.ZPopMin(ctx, sessionsKey, count-max).Result()
if err != nil {
return err
}
for _, z := range popped {
sid := fmt.Sprintf("%v", z.Member)
_ = srv.Redis.Del(ctx, fmt.Sprintf("%v:%v", config.SessionIdKey, sid)).Err()
}
}
_ = srv.Redis.Expire(ctx, sessionsKey, time.Duration(srv.Config.JwtAuth.AccessExpire)*time.Second).Err()
return nil
}