diff --git a/internal/logic/auth/bindDeviceLogic.go b/internal/logic/auth/bindDeviceLogic.go index 40110a3..391c080 100644 --- a/internal/logic/auth/bindDeviceLogic.go +++ b/internal/logic/auth/bindDeviceLogic.go @@ -107,8 +107,9 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string, // Create device record deviceInfo := &user.Device{ - Ip: ip, - UserId: userId, + Ip: ip, + UserId: userId, + UserAgent: userAgent, Identifier: identifier, Enabled: true, @@ -145,11 +146,57 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string, func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error { oldUserId := deviceInfo.UserId + var users []*user.User + err := l.svcCtx.DB.Where("id in (?)", []int64{oldUserId, newUserId}).Find(&users).Error + if err != nil { + l.Errorw("failed to query users for rebinding", + logger.Field("old_user_id", oldUserId), + logger.Field("new_user_id", newUserId), + logger.Field("error", err.Error()), + ) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query users failed: %v", err) + } + err = l.svcCtx.UserModel.Transaction(l.ctx, func(tx *gorm.DB) error { + //检查旧设备是否存在认证方式 + var authMethod user.AuthMethods + err := tx.Where("auth_type = ? AND auth_identifier = ?", "device", deviceInfo.Identifier).Find(&authMethod).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("failed to query device auth method", + logger.Field("identifier", deviceInfo.Identifier), + logger.Field("error", err.Error()), + ) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query device auth method failed: %v", err) + } - err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { - // Check if old user has other auth methods besides device + //未找到设备认证方式信息,创建新的设备认证方式 + if err != nil { + authMethod = user.AuthMethods{ + UserId: newUserId, + AuthType: "device", + AuthIdentifier: deviceInfo.Identifier, + Verified: true, + } + logger.Infof("create auth method: %v", authMethod) + if err := tx.Create(&authMethod).Error; err != nil { + l.Errorw("failed to create device auth method", logger.Field("new_user_id", newUserId), + logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device auth method failed: %v", err) + } + } else { + //更新设备认证方式的用户ID为新用户ID + authMethod.UserId = newUserId + if err := tx.Save(&authMethod).Error; err != nil { + l.Errorw("failed to update device auth method", + logger.Field("identifier", deviceInfo.Identifier), + logger.Field("error", err.Error()), + ) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device auth method failed: %v", err) + } + } + + //检查旧用户是否还有其他认证方式 var count int64 - if err := db.Where("user_id = ? and auth_identifier != ?", oldUserId, deviceInfo.Identifier).Count(&count).Error; err != nil { + if err := tx.Model(&user.AuthMethods{}).Where("user_id = ?", oldUserId).Count(&count).Error; err != nil { l.Errorw("failed to query auth methods for old user", logger.Field("old_user_id", oldUserId), logger.Field("error", err.Error()), @@ -157,50 +204,33 @@ func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, use return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query auth methods failed: %v", err) } - // Only disable old user if they have no other auth methods - if count == 0 { - if err := db.Model(&user.User{}).Where("id = ?", oldUserId).Delete(&user.User{}).Error; err != nil { + //如果没有其他认证方式,禁用旧用户账号 + if count < 1 { + if err := tx.Model(&user.User{}).Where("id = ?", oldUserId).Delete(&user.User{}).Error; err != nil { l.Errorw("failed to disable old user", logger.Field("old_user_id", oldUserId), logger.Field("error", err.Error()), ) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "disable old user failed: %v", err) } - l.Infow("disabled old user (no other auth methods)", - logger.Field("old_user_id", oldUserId), - ) - } else { - l.Infow("old user has other auth methods, not disabling", - logger.Field("old_user_id", oldUserId), - logger.Field("non_device_auth_count", count), - ) } - // Update device auth method to new user - if err := db.Model(&user.AuthMethods{}). - Where("auth_type = ? AND auth_identifier = ?", "device", deviceInfo.Identifier). - Update("user_id", newUserId).Error; err != nil { - l.Errorw("failed to update device auth method", - logger.Field("identifier", deviceInfo.Identifier), - logger.Field("error", err.Error()), - ) - return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device auth method failed: %v", err) - } + l.Infow("disabled old user (no other auth methods)", + logger.Field("old_user_id", oldUserId), + ) - // Update device record + // 更新设备绑定的用户id deviceInfo.UserId = newUserId deviceInfo.Ip = ip deviceInfo.UserAgent = userAgent deviceInfo.Enabled = true - - if err := db.Save(deviceInfo).Error; err != nil { + if err := tx.Save(deviceInfo).Error; err != nil { l.Errorw("failed to update device", logger.Field("identifier", deviceInfo.Identifier), logger.Field("error", err.Error()), ) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err) } - return nil }) @@ -214,6 +244,15 @@ func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, use return err } + err = l.svcCtx.UserModel.ClearUserCache(l.ctx, users...) + if err != nil { + l.Errorw("failed to clear user cache after rebinding", + logger.Field("old_user_id", oldUserId), + logger.Field("new_user_id", newUserId), + logger.Field("error", err.Error()), + ) + } + l.Infow("device rebound successfully", logger.Field("identifier", deviceInfo.Identifier), logger.Field("old_user_id", oldUserId),