feat(auth): 实现用户会话数限制功能
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m32s
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m32s
添加用户会话数限制功能,当超过最大会话数时自动移除最旧的会话 - 在config中添加UserSessionsKeyPrefix常量 - 在JwtAuth配置中新增MaxSessionsPerUser字段 - 在ServiceContext中实现EnforceUserSessionLimit方法 - 在所有登录逻辑中调用会话限制检查
This commit is contained in:
parent
4ad384b01a
commit
1d5d361ae8
92
.trae/documents/实现按用户Token并发登录上限N并超限逐出.md
Normal file
92
.trae/documents/实现按用户Token并发登录上限N并超限逐出.md
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
## 修复目标
|
||||||
|
|
||||||
|
* 解决首次设备登录时在 `internal/logic/auth/deviceLoginLogic.go:99` 对 `deviceInfo` 赋值导致的空指针崩溃,确保接口稳定返回。
|
||||||
|
|
||||||
|
## 根因定位
|
||||||
|
|
||||||
|
* 设备不存在分支仅创建用户与设备记录,但未为局部变量 `deviceInfo` 赋值;随后在 `internal/logic/auth/deviceLoginLogic.go:99-100` 使用 `deviceInfo` 导致 `nil` 解引用。
|
||||||
|
|
||||||
|
* 参考位置:
|
||||||
|
|
||||||
|
* 赋值处:`internal/logic/auth/deviceLoginLogic.go:99-101`
|
||||||
|
|
||||||
|
* 设备存在分支赋值:`internal/logic/auth/deviceLoginLogic.go:88-95`
|
||||||
|
|
||||||
|
* 设备不存在分支未赋值:`internal/logic/auth/deviceLoginLogic.go:74-79`
|
||||||
|
|
||||||
|
* `UpdateDevice` 需要有效设备 `Id`:`internal/model/user/device.go:58-69`
|
||||||
|
|
||||||
|
## 修改方案
|
||||||
|
|
||||||
|
1. 在“设备不存在”分支注册完成后,立即通过标识重新查询设备,赋值给 `deviceInfo`:
|
||||||
|
|
||||||
|
* 在 `internal/logic/auth/deviceLoginLogic.go` 的 `if errors.Is(err, gorm.ErrRecordNotFound)` 分支中,`userInfo, err = l.registerUserAndDevice(req)` 之后追加:
|
||||||
|
|
||||||
|
* `deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)`
|
||||||
|
|
||||||
|
* 如果查询失败则返回数据库查询错误(与现有风格一致)。
|
||||||
|
2. 在更新设备 UA 前增加空指针保护,并不再忽略更新错误:
|
||||||
|
|
||||||
|
* 将 `internal/logic/auth/deviceLoginLogic.go:99-101` 改为:
|
||||||
|
|
||||||
|
* 检查 `deviceInfo != nil`
|
||||||
|
|
||||||
|
* `deviceInfo.UserAgent = req.UserAgent`
|
||||||
|
|
||||||
|
* `if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil {` 记录错误并返回包装后的错误 `xerr.DatabaseUpdateError`。
|
||||||
|
3. 可选优化(减少二次查询):
|
||||||
|
|
||||||
|
* 将 `registerUserAndDevice(req)` 的返回值改为 `(*user.User, *user.Device, error)`,在注册时直接返回新建设备对象;调用点随之调整。若选择此方案,仍需在更新前做空指针保护。
|
||||||
|
|
||||||
|
## 代码示例(方案1,最小改动)
|
||||||
|
|
||||||
|
```go
|
||||||
|
// internal/logic/auth/deviceLoginLogic.go
|
||||||
|
// 设备不存在分支注册后追加一次设备查询
|
||||||
|
userInfo, err = l.registerUserAndDevice(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
deviceInfo, err = l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, req.Identifier)
|
||||||
|
if err != nil {
|
||||||
|
l.Errorw("query device after register failed",
|
||||||
|
logger.Field("identifier", req.Identifier),
|
||||||
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query device after register failed: %v", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新 UA,不忽略更新错误
|
||||||
|
if deviceInfo != nil {
|
||||||
|
deviceInfo.UserAgent = req.UserAgent
|
||||||
|
if err := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); err != nil {
|
||||||
|
l.Errorw("update device failed",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", req.Identifier),
|
||||||
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 测试用例与验证
|
||||||
|
|
||||||
|
* 用例1:首次设备标识登录(设备不存在)应成功返回 Token,日志包含注册与登录记录,无 500。
|
||||||
|
|
||||||
|
* 用例2:已存在设备标识登录(设备存在)应正常更新 UA 并返回 Token。
|
||||||
|
|
||||||
|
* 用例3:模拟数据库异常时应返回一致的业务错误码,不产生 `panic`。
|
||||||
|
|
||||||
|
## 风险与回滚
|
||||||
|
|
||||||
|
* 改动限定在登录逻辑,属最小范围;若出现异常,回滚为当前版本即可。
|
||||||
|
|
||||||
|
* 不改变数据结构与外部接口行为,兼容现有客户端。
|
||||||
|
|
||||||
|
## 后续优化(可选)
|
||||||
|
|
||||||
|
* 统一 `UpdateDevice` 错误处理路径,避免 `_ = ...` 静默失败。
|
||||||
|
|
||||||
|
* 为“首次设备登录”场景补充集成测试,保证不再回归。
|
||||||
|
|
||||||
@ -42,6 +42,9 @@ const SessionIdKey = "auth:session_id"
|
|||||||
// DeviceCacheKeyKey cache session key
|
// DeviceCacheKeyKey cache session key
|
||||||
const DeviceCacheKeyKey = "auth:device_identifier"
|
const DeviceCacheKeyKey = "auth:device_identifier"
|
||||||
|
|
||||||
|
// UserSessionsKeyPrefix per-user sessions zset key prefix
|
||||||
|
const UserSessionsKeyPrefix = "auth:user_sessions:"
|
||||||
|
|
||||||
// GlobalConfigKey Global Config Key
|
// GlobalConfigKey Global Config Key
|
||||||
const GlobalConfigKey = "system:global_config"
|
const GlobalConfigKey = "system:global_config"
|
||||||
|
|
||||||
|
|||||||
@ -44,6 +44,7 @@ type RedisConfig struct {
|
|||||||
type JwtAuth struct {
|
type JwtAuth struct {
|
||||||
AccessSecret string `yaml:"AccessSecret"`
|
AccessSecret string `yaml:"AccessSecret"`
|
||||||
AccessExpire int64 `yaml:"AccessExpire" default:"604800"`
|
AccessExpire int64 `yaml:"AccessExpire" default:"604800"`
|
||||||
|
MaxSessionsPerUser int64 `yaml:"MaxSessionsPerUser" default:"3"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Verify struct {
|
type Verify struct {
|
||||||
|
|||||||
@ -140,6 +140,9 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ
|
|||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
|
||||||
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
|
}
|
||||||
// Store session id in redis
|
// Store session id in redis
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
|
|||||||
@ -587,6 +587,10 @@ func (l *OAuthLoginGetTokenLogic) generateToken(userInfo *user.User, requestID s
|
|||||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err)
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
|
||||||
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
l.Errorw("failed to cache session id",
|
l.Errorw("failed to cache session id",
|
||||||
|
|||||||
@ -156,6 +156,9 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
|
|||||||
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
|
||||||
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
|
}
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
||||||
|
|||||||
@ -111,6 +111,9 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
|
|||||||
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.Config.JwtAuth.MaxSessionsPerUser); err != nil {
|
||||||
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
|
}
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
||||||
|
|||||||
@ -2,6 +2,8 @@ package svc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/perfect-panel/server/internal/model/client"
|
"github.com/perfect-panel/server/internal/model/client"
|
||||||
"github.com/perfect-panel/server/internal/model/node"
|
"github.com/perfect-panel/server/internal/model/node"
|
||||||
@ -112,3 +114,30 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
|||||||
return srv
|
return srv
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *ServiceContext) EnforceUserSessionLimit(ctx context.Context, userId int64, newSessionId string, max int64) error {
|
||||||
|
if max <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sessionsKey := fmt.Sprintf("%s%v", config.UserSessionsKeyPrefix, userId)
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if err := srv.Redis.ZAdd(ctx, sessionsKey, redis.Z{Score: float64(now), Member: newSessionId}).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
count, err := srv.Redis.ZCard(ctx, sessionsKey).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count > max {
|
||||||
|
popped, err := srv.Redis.ZPopMin(ctx, sessionsKey, count-max).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, z := range popped {
|
||||||
|
sid := fmt.Sprintf("%v", z.Member)
|
||||||
|
_ = srv.Redis.Del(ctx, fmt.Sprintf("%v:%v", config.SessionIdKey, sid)).Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = srv.Redis.Expire(ctx, sessionsKey, time.Duration(srv.Config.JwtAuth.AccessExpire)*time.Second).Err()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user