feat(auth): 添加设备绑定数量限制检查
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m26s
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m26s
在设备绑定逻辑中添加对设备绑定数量的限制检查,当超过限制时返回特定错误码 同时在用户注册、登录等流程中处理设备绑定数量超限的错误情况
This commit is contained in:
parent
2442831cd7
commit
6afd6eb307
@ -88,7 +88,15 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
|
|||||||
logger.Field("user_id", userId),
|
logger.Field("user_id", userId),
|
||||||
)
|
)
|
||||||
|
|
||||||
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
// enforce device bind limit before creating
|
||||||
|
if limit := l.svcCtx.SessionLimit(); limit > 0 {
|
||||||
|
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, userId); err == nil {
|
||||||
|
if count >= limit {
|
||||||
|
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||||
// Create device auth method
|
// Create device auth method
|
||||||
authMethod := &user.AuthMethods{
|
authMethod := &user.AuthMethods{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
@ -123,8 +131,8 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
|
|||||||
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device failed: %v", err)
|
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("device creation failed",
|
l.Errorw("device creation failed",
|
||||||
@ -144,9 +152,17 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error {
|
func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error {
|
||||||
oldUserId := deviceInfo.UserId
|
oldUserId := deviceInfo.UserId
|
||||||
|
|
||||||
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
// enforce device bind limit before rebind
|
||||||
|
if limit := l.svcCtx.SessionLimit(); limit > 0 {
|
||||||
|
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, newUserId); err == nil {
|
||||||
|
if count >= limit {
|
||||||
|
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||||
// Check if old user has other auth methods besides device
|
// Check if old user has other auth methods besides device
|
||||||
var authMethods []user.AuthMethods
|
var authMethods []user.AuthMethods
|
||||||
if err := db.Where("user_id = ?", oldUserId).Find(&authMethods).Error; err != nil {
|
if err := db.Where("user_id = ?", oldUserId).Find(&authMethods).Error; err != nil {
|
||||||
@ -211,8 +227,8 @@ func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, use
|
|||||||
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err)
|
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("device rebinding failed",
|
l.Errorw("device rebinding failed",
|
||||||
|
|||||||
@ -110,17 +110,20 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Bind device to user if identifier is provided
|
// Bind device to user if identifier is provided
|
||||||
if req.Identifier != "" {
|
if req.Identifier != "" {
|
||||||
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
||||||
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
||||||
l.Errorw("failed to bind device to user",
|
var ce *xerr.CodeError
|
||||||
logger.Field("user_id", userInfo.Id),
|
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
|
||||||
logger.Field("identifier", req.Identifier),
|
return nil, ce
|
||||||
logger.Field("error", err.Error()),
|
}
|
||||||
)
|
l.Errorw("failed to bind device to user",
|
||||||
// Don't fail register if device binding fails, just log the error
|
logger.Field("user_id", userInfo.Id),
|
||||||
}
|
logger.Field("identifier", req.Identifier),
|
||||||
}
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
if l.ctx.Value(constant.LoginType) != nil {
|
if l.ctx.Value(constant.LoginType) != nil {
|
||||||
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -128,12 +128,15 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
|
|||||||
if req.Identifier != "" {
|
if req.Identifier != "" {
|
||||||
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
||||||
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
||||||
|
var ce *xerr.CodeError
|
||||||
|
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
|
||||||
|
return nil, ce
|
||||||
|
}
|
||||||
l.Errorw("failed to bind device to user",
|
l.Errorw("failed to bind device to user",
|
||||||
logger.Field("user_id", userInfo.Id),
|
logger.Field("user_id", userInfo.Id),
|
||||||
logger.Field("identifier", req.Identifier),
|
logger.Field("identifier", req.Identifier),
|
||||||
logger.Field("error", err.Error()),
|
logger.Field("error", err.Error()),
|
||||||
)
|
)
|
||||||
// Don't fail login if device binding fails, just log the error
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,16 +159,16 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r
|
|||||||
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
||||||
}
|
}
|
||||||
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
}
|
}
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
||||||
}
|
}
|
||||||
loginStatus = true
|
loginStatus = true
|
||||||
return &types.LoginResponse{
|
return &types.LoginResponse{
|
||||||
Token: token,
|
Token: token,
|
||||||
Limit: l.svcCtx.SessionLimit(),
|
Limit: l.svcCtx.SessionLimit(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -141,17 +141,20 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Bind device to user if identifier is provided
|
// Bind device to user if identifier is provided
|
||||||
if req.Identifier != "" {
|
if req.Identifier != "" {
|
||||||
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
||||||
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
||||||
l.Errorw("failed to bind device to user",
|
var ce *xerr.CodeError
|
||||||
logger.Field("user_id", userInfo.Id),
|
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
|
||||||
logger.Field("identifier", req.Identifier),
|
return nil, ce
|
||||||
logger.Field("error", err.Error()),
|
}
|
||||||
)
|
l.Errorw("failed to bind device to user",
|
||||||
// Don't fail register if device binding fails, just log the error
|
logger.Field("user_id", userInfo.Id),
|
||||||
}
|
logger.Field("identifier", req.Identifier),
|
||||||
}
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
if l.ctx.Value(constant.LoginType) != nil {
|
if l.ctx.Value(constant.LoginType) != nil {
|
||||||
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -85,12 +85,15 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
|
|||||||
if req.Identifier != "" {
|
if req.Identifier != "" {
|
||||||
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
||||||
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
||||||
|
var ce *xerr.CodeError
|
||||||
|
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
|
||||||
|
return nil, ce
|
||||||
|
}
|
||||||
l.Errorw("failed to bind device to user",
|
l.Errorw("failed to bind device to user",
|
||||||
logger.Field("user_id", userInfo.Id),
|
logger.Field("user_id", userInfo.Id),
|
||||||
logger.Field("identifier", req.Identifier),
|
logger.Field("identifier", req.Identifier),
|
||||||
logger.Field("error", err.Error()),
|
logger.Field("error", err.Error()),
|
||||||
)
|
)
|
||||||
// Don't fail login if device binding fails, just log the error
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if l.ctx.Value(constant.LoginType) != nil {
|
if l.ctx.Value(constant.LoginType) != nil {
|
||||||
@ -111,16 +114,16 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log
|
|||||||
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
l.Logger.Error("[UserLogin] token generate error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error())
|
||||||
}
|
}
|
||||||
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
|
if err = l.svcCtx.EnforceUserSessionLimit(l.ctx, userInfo.Id, sessionId, l.svcCtx.SessionLimit()); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "enforce session limit error: %v", err.Error())
|
||||||
}
|
}
|
||||||
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
|
||||||
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error())
|
||||||
}
|
}
|
||||||
loginStatus = true
|
loginStatus = true
|
||||||
return &types.LoginResponse{
|
return &types.LoginResponse{
|
||||||
Token: token,
|
Token: token,
|
||||||
Limit: l.svcCtx.SessionLimit(),
|
Limit: l.svcCtx.SessionLimit(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -127,17 +127,20 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
// Bind device to user if identifier is provided
|
// Bind device to user if identifier is provided
|
||||||
if req.Identifier != "" {
|
if req.Identifier != "" {
|
||||||
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx)
|
||||||
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil {
|
||||||
l.Errorw("failed to bind device to user",
|
var ce *xerr.CodeError
|
||||||
logger.Field("user_id", userInfo.Id),
|
if errors.As(err, &ce) && ce.GetErrCode() == xerr.DeviceBindLimitExceeded {
|
||||||
logger.Field("identifier", req.Identifier),
|
return nil, ce
|
||||||
logger.Field("error", err.Error()),
|
}
|
||||||
)
|
l.Errorw("failed to bind device to user",
|
||||||
// Don't fail register if device binding fails, just log the error
|
logger.Field("user_id", userInfo.Id),
|
||||||
}
|
logger.Field("identifier", req.Identifier),
|
||||||
}
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
if l.ctx.Value(constant.LoginType) != nil {
|
if l.ctx.Value(constant.LoginType) != nil {
|
||||||
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
req.LoginType = l.ctx.Value(constant.LoginType).(string)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -114,8 +114,9 @@ const (
|
|||||||
TelephoneError uint32 = 90014
|
TelephoneError uint32 = 90014
|
||||||
)
|
)
|
||||||
const (
|
const (
|
||||||
DeviceNotExist uint32 = 90017
|
DeviceNotExist uint32 = 90017
|
||||||
UseridNotMatch uint32 = 90018
|
UseridNotMatch uint32 = 90018
|
||||||
|
DeviceBindLimitExceeded uint32 = 90019
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user