hi-server/internal/logic/public/user/bindEmailLogic_test.go

175 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package user
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"sync"
"testing"
"time"
"unsafe"
"github.com/alicebob/miniredis/v2"
"github.com/gorilla/websocket"
"github.com/perfect-panel/server/internal/config"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/device"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
// MockEmailModel builds the email-binding scenario on top of MockUserModel:
// user 1 holds device "device-1", the address "test@example.com" is already
// bound to another user (UserId 2), and Transaction always fails so the logic
// under test stops right after the device-kick step.
type MockEmailModel struct {
	MockUserModel
}

// FindUserAuthMethods reports a single verified "device" auth method with
// identifier "device-1", owned by whichever user ID is requested.
func (em *MockEmailModel) FindUserAuthMethods(_ context.Context, uid int64) ([]*user.AuthMethods, error) {
	deviceAuth := &user.AuthMethods{
		UserId:         uid,
		AuthType:       "device",
		AuthIdentifier: "device-1",
		Verified:       true,
	}
	return []*user.AuthMethods{deviceAuth}, nil
}

// FindUserAuthMethodByOpenID treats "test@example.com" as an email already
// bound to user 2; any other identifier is reported as not found. The method
// argument is not consulted.
func (em *MockEmailModel) FindUserAuthMethodByOpenID(_ context.Context, _, openID string) (*user.AuthMethods, error) {
	switch openID {
	case "test@example.com":
		// Existing binding that belongs to a different user than the caller.
		return &user.AuthMethods{Id: 99, UserId: 2, AuthType: "email", AuthIdentifier: openID}, nil
	default:
		return nil, gorm.ErrRecordNotFound
	}
}

// QueryDeviceList returns device "device-1" (Id 10) for user 1 and an empty
// result for every other user.
func (em *MockEmailModel) QueryDeviceList(_ context.Context, uid int64) ([]*user.Device, int64, error) {
	if uid != 1 {
		return nil, 0, nil
	}
	owned := []*user.Device{
		{Id: 10, UserId: 1, Identifier: "device-1", Enabled: true},
	}
	return owned, int64(len(owned)), nil
}

// UpdateDevice accepts every update without doing anything.
func (em *MockEmailModel) UpdateDevice(context.Context, *user.Device, ...*gorm.DB) error {
	return nil
}

// Transaction never runs the callback; it always fails so that the logic under
// test returns right after it has kicked the conflicting device.
func (em *MockEmailModel) Transaction(context.Context, func(db *gorm.DB) error) error {
	return fmt.Errorf("stop testing here")
}

// FindOne returns a bare user carrying only the requested ID.
func (em *MockEmailModel) FindOne(_ context.Context, id int64) (*user.User, error) {
	found := &user.User{Id: id}
	return found, nil
}
// TestBindEmailWithVerification_KickDevice binds "test@example.com", which the
// mock reports as already bound to user 2, while acting as user 1, and checks
// that user 1's registered device is removed from the DeviceManager
// (KickDevice) before the flow is aborted.
//
// MockEmailModel.Transaction always fails with "stop testing here", so that
// error is the expected return value of the logic.
func TestBindEmailWithVerification_KickDevice(t *testing.T) {
	// 1. Redis: in-memory miniredis seeded with a valid verification code.
	mr, err := miniredis.Run()
	if err != nil {
		t.Fatalf("start miniredis: %v", err)
	}
	defer mr.Close()
	rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
	defer rdb.Close()

	email := "test@example.com"
	code := "123456"
	payload, err := json.Marshal(map[string]interface{}{
		"code":   code,
		"lastAt": time.Now().Unix(),
	})
	if err != nil {
		t.Fatalf("marshal verification payload: %v", err)
	}
	cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Register.String(), email)
	if err := rdb.Set(context.Background(), cacheKey, string(payload), time.Minute).Err(); err != nil {
		t.Fatalf("seed verification code: %v", err)
	}

	// 2. DeviceManager: start a WebSocket server to obtain a real server-side
	// connection for the injected device.
	var serverConn *websocket.Conn
	connDone := make(chan struct{})
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upgrader := websocket.Upgrader{}
		c, upErr := upgrader.Upgrade(w, r, nil)
		if upErr != nil {
			// Upgrade has already written an HTTP error response; report it
			// and unblock the test instead of using a nil connection.
			t.Errorf("websocket upgrade: %v", upErr)
			close(connDone)
			return
		}
		defer c.Close()
		serverConn = c
		close(connDone)
		// Keep the connection open until the peer side closes it.
		for {
			if _, _, rErr := c.ReadMessage(); rErr != nil {
				return
			}
		}
	}))
	defer s.Close()

	wsURL := "ws" + strings.TrimPrefix(s.URL, "http")
	clientConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		// Fatal: without a connection the handler never signals connDone.
		t.Fatalf("dial test websocket server: %v", err)
	}
	defer clientConn.Close()
	<-connDone // wait until the handler has captured its side of the connection
	if serverConn == nil {
		t.Fatal("server side of the websocket connection was not established")
	}

	dm := device.NewDeviceManager(10, 10)
	// Device owned by user 1 with DeviceID "device-1".
	dev := &device.Device{
		Session:  "session-1",
		DeviceID: "device-1",
		Conn:     serverConn,
	}
	// Inject the device through the unexported userDevices field. Check the
	// field's presence and type first so that a refactor of DeviceManager
	// fails the test loudly instead of corrupting memory through unsafe.
	f := reflect.ValueOf(dm).Elem().FieldByName("userDevices")
	if !f.IsValid() || f.Type() != reflect.TypeOf((*sync.Map)(nil)).Elem() {
		t.Fatal("device.DeviceManager.userDevices is missing or not a sync.Map; update this test")
	}
	userDevicesMap := (*sync.Map)(unsafe.Pointer(f.UnsafeAddr()))
	userDevicesMap.Store(int64(1), []*device.Device{dev})

	// 3. User model: MockEmailModel overrides the methods this flow needs; the
	// embedded MockUserModel's users map is initialized all the same.
	mockModel := &MockEmailModel{}
	mockModel.users = make(map[int64]*user.User)
	svcCtx := &svc.ServiceContext{
		UserModel:     mockModel,
		Redis:         rdb,
		DeviceManager: dm,
		Config: config.Config{
			VerifyCode: config.VerifyCode{ExpireTime: 900},
			JwtAuth:    config.JwtAuth{MaxSessionsPerUser: 10},
		},
	}

	// 4. Run the logic as user 1.
	currentUser := &user.User{Id: 1}
	ctx := context.WithValue(context.Background(), constant.CtxKeyUser, currentUser)
	l := NewBindEmailWithVerificationLogic(ctx, svcCtx)
	req := &types.BindEmailWithVerificationRequest{
		Email: email,
		Code:  code,
	}
	_, err = l.BindEmailWithVerification(req)
	// The mocked Transaction aborts the flow with "stop testing here". Only
	// inspect the message when an error exists, so a nil error is reported as
	// a failure rather than a nil-pointer panic.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "stop testing here")
	}

	// 5. Verify: the kicked device must no longer be registered for user 1,
	// either because the key is gone or because its device list is empty.
	if val, ok := userDevicesMap.Load(int64(1)); ok {
		devices, isList := val.([]*device.Device)
		if !isList {
			t.Fatalf("unexpected userDevices value type %T", val)
		}
		assert.Empty(t, devices, "设备列表应为空 (KickDevice 应该移除设备)")
	}
}