package user import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "reflect" "strings" "sync" "testing" "time" "unsafe" "github.com/alicebob/miniredis/v2" "github.com/gorilla/websocket" "github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/device" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "gorm.io/gorm" ) type MockEmailModel struct { MockUserModel } func (m *MockEmailModel) FindUserAuthMethods(ctx context.Context, userId int64) ([]*user.AuthMethods, error) { return []*user.AuthMethods{ {UserId: userId, AuthType: "device", AuthIdentifier: "device-1", Verified: true}, }, nil } func (m *MockEmailModel) FindUserAuthMethodByOpenID(ctx context.Context, method, openID string) (*user.AuthMethods, error) { if openID == "test@example.com" { // 返回已存在的用户(不同的UserId) return &user.AuthMethods{Id: 99, UserId: 2, AuthType: "email", AuthIdentifier: openID}, nil } return nil, gorm.ErrRecordNotFound } func (m *MockEmailModel) QueryDeviceList(ctx context.Context, userId int64) ([]*user.Device, int64, error) { // 模拟当前用户(User 1)持有设备 device-1 if userId == 1 { return []*user.Device{ {Id: 10, UserId: 1, Identifier: "device-1", Enabled: true}, }, 1, nil } return nil, 0, nil } func (m *MockEmailModel) UpdateDevice(ctx context.Context, data *user.Device, tx ...*gorm.DB) error { return nil } // 模拟 Transaction 失败,以便在 KickDevice 后停止 func (m *MockEmailModel) Transaction(ctx context.Context, fn func(db *gorm.DB) error) error { return fmt.Errorf("stop testing here") } func (m *MockEmailModel) FindOne(ctx context.Context, id int64) (*user.User, error) { return &user.User{Id: id}, nil } func TestBindEmailWithVerification_KickDevice(t *testing.T) { // 1. Redis Mock mr, err := miniredis.Run() assert.NoError(t, err) defer mr.Close() rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) // 准备验证码数据 email := "test@example.com" code := "123456" payload := map[string]interface{}{ "code": code, "lastAt": time.Now().Unix(), } bytes, _ := json.Marshal(payload) cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Register.String(), email) rdb.Set(context.Background(), cacheKey, string(bytes), time.Minute) // 2. DeviceManager Mock // 启动 WebSocket 服务器以获取真实连接 var serverConn *websocket.Conn connDone := make(chan struct{}) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{} c, _ := upgrader.Upgrade(w, r, nil) serverConn = c close(connDone) // 保持连接直到测试结束 (read loop) for { if _, _, err := c.ReadMessage(); err != nil { break } } })) defer s.Close() // 客户端连接 wsURL := "ws" + strings.TrimPrefix(s.URL, "http") clientConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) assert.NoError(t, err) defer clientConn.Close() <-connDone // 等待服务端获取连接 dm := device.NewDeviceManager(10, 10) // 注入设备 (UserId=1, DeviceId="device-1") dev := &device.Device{ Session: "session-1", DeviceID: "device-1", Conn: serverConn, } // 使用反射注入 v := reflect.ValueOf(dm).Elem() f := v.FieldByName("userDevices") // 直接获取指针 userDevicesMap := (*sync.Map)(unsafe.Pointer(f.UnsafeAddr())) userDevicesMap.Store(int64(1), []*device.Device{dev}) // 3. User Mock mockModel := &MockEmailModel{} // 初始化内部 map,虽然这里只用到 override 的方法 mockModel.users = make(map[int64]*user.User) svcCtx := &svc.ServiceContext{ UserModel: mockModel, Redis: rdb, DeviceManager: dm, Config: config.Config{ VerifyCode: config.VerifyCode{ExpireTime: 900}, // Correct type JwtAuth: config.JwtAuth{MaxSessionsPerUser: 10}, }, } // 4. Run Logic currentUser := &user.User{Id: 1} // 当前用户 ctx := context.WithValue(context.Background(), constant.CtxKeyUser, currentUser) l := NewBindEmailWithVerificationLogic(ctx, svcCtx) req := &types.BindEmailWithVerificationRequest{ Email: email, Code: code, } // 执行 _, err = l.BindEmailWithVerification(req) // 我们预期这里会返回错误 ("stop testing here") assert.Error(t, err) assert.Contains(t, err.Error(), "stop testing here") // 5. Verify // 验证设备是否被移除 (KickDevice 会从 userDevices 中移除被踢出的设备) val, ok := userDevicesMap.Load(int64(1)) if ok { // 如果 key 还在,检查列表是否为空 devices := val.([]*device.Device) assert.Empty(t, devices, "设备列表应为空 (KickDevice 应该移除设备)") } else { // key 不存在,说明已移除,符合预期 } }