feat(auth): 添加设备绑定数量限制检查
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m26s

在设备绑定逻辑中添加对设备绑定数量的限制检查,当超过限制时返回特定错误码
同时在用户注册、登录等流程中处理设备绑定数量超限的错误情况
This commit is contained in:
shanshanzhong 2025-11-27 23:58:10 -08:00
parent 2442831cd7
commit 6afd6eb307
7 changed files with 90 additions and 58 deletions

View File

@ -88,7 +88,15 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
logger.Field("user_id", userId),
)
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// enforce device bind limit before creating
if limit := l.svcCtx.SessionLimit(); limit > 0 {
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, userId); err == nil {
if count >= limit {
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
}
}
}
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Create device auth method
authMethod := &user.AuthMethods{
UserId: userId,
@ -123,8 +131,8 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device failed: %v", err)
}
return nil
})
return nil
})
if err != nil {
l.Errorw("device creation failed",
@ -144,9 +152,17 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
}
func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error {
oldUserId := deviceInfo.UserId
oldUserId := deviceInfo.UserId
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// enforce device bind limit before rebind
if limit := l.svcCtx.SessionLimit(); limit > 0 {
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, newUserId); err == nil {
if count >= limit {
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
}
}
}
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Check if old user has other auth methods besides device
var authMethods []user.AuthMethods
if err := db.Where("user_id = ?", oldUserId).Find(&authMethods).Error; err != nil {
@ -211,8 +227,8 @@ func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, use
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err)
}
return nil
})
return nil
})
if err != nil {
l.Errorw("device rebinding failed",

View File

@ -110,17 +110,20 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res
}
// Bind device to user if identifier is provided
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
// Don't fail register if device binding fails, just log the error
}
}
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
var ce *xerr.CodeError
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
return nil, ce
}
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
}
}
if l.ctx.Value(constant.LoginType) != nil {
req.LoginType = l.ctx.Value(constant.LoginType).(string)
}

View File

@ -128,12 +128,15 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
var ce *xerr.CodeError
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
return nil, ce
}
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
// Don't fail login if device binding fails, just log the error
}
}
@ -156,16 +159,16 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
}
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
}
loginStatus = true
return &types.LoginResponse{
Token: token,
Limit: l.svcCtx.SessionLimit(),
}, nil
return &types.LoginResponse{
Token: token,
Limit: l.svcCtx.SessionLimit(),
}, nil
}

View File

@ -141,17 +141,20 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR
})
// Bind device to user if identifier is provided
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
// Don't fail register if device binding fails, just log the error
}
}
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
var ce *xerr.CodeError
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
return nil, ce
}
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
}
}
if l.ctx.Value(constant.LoginType) != nil {
req.LoginType = l.ctx.Value(constant.LoginType).(string)
}

View File

@ -85,12 +85,15 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
var ce *xerr.CodeError
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
return nil, ce
}
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
// Don't fail login if device binding fails, just log the error
}
}
if l.ctx.Value(constant.LoginType) != nil {
@ -111,16 +114,16 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
}
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
}
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
}
loginStatus = true
return &types.LoginResponse{
Token: token,
Limit: l.svcCtx.SessionLimit(),
}, nil
return &types.LoginResponse{
Token: token,
Limit: l.svcCtx.SessionLimit(),
}, nil
}

View File

@ -127,17 +127,20 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
return nil
})
// Bind device to user if identifier is provided
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
// Don't fail register if device binding fails, just log the error
}
}
if req.Identifier != "" {
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
var ce *xerr.CodeError
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
return nil, ce
}
l.Errorw("failed to bind device to user",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", req.Identifier),
logger.Field("error", err.Error()),
)
}
}
if l.ctx.Value(constant.LoginType) != nil {
req.LoginType = l.ctx.Value(constant.LoginType).(string)
}

View File

@ -114,8 +114,9 @@ const (
TelephoneError uint32 = 90014
)
const (
DeviceNotExist uint32 = 90017
UseridNotMatch uint32 = 90018
DeviceNotExist uint32 = 90017
UseridNotMatch uint32 = 90018
DeviceBindLimitExceeded uint32 = 90019
)
const (