175 lines
4.9 KiB
Go
175 lines
4.9 KiB
Go
package user
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"reflect"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
"unsafe"
|
||
|
||
"github.com/alicebob/miniredis/v2"
|
||
"github.com/gorilla/websocket"
|
||
"github.com/perfect-panel/server/internal/config"
|
||
"github.com/perfect-panel/server/internal/model/user"
|
||
"github.com/perfect-panel/server/internal/svc"
|
||
"github.com/perfect-panel/server/internal/types"
|
||
"github.com/perfect-panel/server/pkg/constant"
|
||
"github.com/perfect-panel/server/pkg/device"
|
||
"github.com/redis/go-redis/v9"
|
||
"github.com/stretchr/testify/assert"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type MockEmailModel struct {
|
||
MockUserModel
|
||
}
|
||
|
||
// FindUserAuthMethods reports exactly one verified auth method of type
// "device" with identifier "device-1". It does this for whatever userId is
// requested and never returns an error.
func (m *MockEmailModel) FindUserAuthMethods(ctx context.Context, userId int64) ([]*user.AuthMethods, error) {
	deviceAuth := &user.AuthMethods{
		UserId:         userId,
		AuthType:       "device",
		AuthIdentifier: "device-1",
		Verified:       true,
	}
	return []*user.AuthMethods{deviceAuth}, nil
}
|
||
|
||
// FindUserAuthMethodByOpenID pretends that "test@example.com" is already
// bound to another account: auth record Id 99, owned by UserId 2. Any other
// openID yields gorm.ErrRecordNotFound. The method argument is ignored.
func (m *MockEmailModel) FindUserAuthMethodByOpenID(ctx context.Context, method, openID string) (*user.AuthMethods, error) {
	if openID != "test@example.com" {
		return nil, gorm.ErrRecordNotFound
	}
	// Return an already-existing owner of this email (a different UserId).
	return &user.AuthMethods{
		Id:             99,
		UserId:         2,
		AuthType:       "email",
		AuthIdentifier: openID,
	}, nil
}
|
||
|
||
// QueryDeviceList simulates the current user (user 1) holding a single
// enabled device, "device-1" (record Id 10), with a total count of 1. Every
// other user has no devices and a count of 0. It never returns an error.
func (m *MockEmailModel) QueryDeviceList(ctx context.Context, userId int64) ([]*user.Device, int64, error) {
	switch userId {
	case 1:
		devices := []*user.Device{
			{Id: 10, UserId: 1, Identifier: "device-1", Enabled: true},
		}
		return devices, int64(len(devices)), nil
	default:
		return nil, 0, nil
	}
}
|
||
|
||
// UpdateDevice is a no-op stub. It ignores the device and any transaction
// handle and always reports success.
func (m *MockEmailModel) UpdateDevice(ctx context.Context, data *user.Device, tx ...*gorm.DB) error {
	return nil
}
|
||
|
||
// Transaction never invokes fn. It always fails with the error
// "stop testing here". The KickDevice test relies on this so that
// BindEmailWithVerification stops at its transaction step, after the
// device kick the test wants to observe.
func (m *MockEmailModel) Transaction(ctx context.Context, fn func(db *gorm.DB) error) error {
	stopErr := fmt.Errorf("stop testing here")
	return stopErr
}
|
||
|
||
// FindOne returns a fresh user value whose only populated field is the
// requested id. It never reports an error.
func (m *MockEmailModel) FindOne(ctx context.Context, id int64) (*user.User, error) {
	found := &user.User{Id: id}
	return found, nil
}
|
||
|
||
// TestBindEmailWithVerification_KickDevice checks that binding an email makes
// the logic kick the caller's online device.
//
// Setup:
//   - MockEmailModel reports the email as already bound to user 2.
//   - The caller (user 1) holds device "device-1", registered in the
//     DeviceManager with a real server-side WebSocket connection.
//
// MockEmailModel.Transaction always fails with "stop testing here", so the
// flow is cut short at that point. The test asserts that error, then asserts
// that device-1 was removed from the DeviceManager.
func TestBindEmailWithVerification_KickDevice(t *testing.T) {
	// 1. Redis backed by an in-process miniredis instance.
	mr, err := miniredis.Run()
	if err != nil {
		t.Fatalf("start miniredis: %v", err)
	}
	defer mr.Close()

	rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})
	defer rdb.Close()

	// Seed the verification code that the logic looks up for this email.
	email := "test@example.com"
	code := "123456"
	payload := map[string]interface{}{
		"code":   code,
		"lastAt": time.Now().Unix(),
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal verification payload: %v", err)
	}
	cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Register.String(), email)
	if err := rdb.Set(context.Background(), cacheKey, string(raw), time.Minute).Err(); err != nil {
		t.Fatalf("seed verification code: %v", err)
	}

	// 2. DeviceManager holding one live device.
	// Start a WebSocket server so the injected device owns a real connection.
	var serverConn *websocket.Conn
	connDone := make(chan struct{})

	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upgrader := websocket.Upgrader{}
		c, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied with an HTTP error, which makes the
			// client Dial below fail and report it. Calling t.* from this
			// goroutine could race with the test finishing.
			return
		}
		// Hijacked connections are not closed by s.Close, so release it here.
		defer c.Close()
		serverConn = c
		close(connDone)
		// Keep the connection alive (read loop) until either side closes it.
		for {
			if _, _, err := c.ReadMessage(); err != nil {
				return
			}
		}
	}))
	defer s.Close()

	// Client side of the connection.
	wsURL := "ws" + strings.TrimPrefix(s.URL, "http")
	clientConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket server: %v", err)
	}
	defer clientConn.Close()

	// Wait for the handler to publish serverConn. The handler closes connDone
	// only after assigning it, so the read below is race-free.
	select {
	case <-connDone:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for the server-side websocket connection")
	}

	dm := device.NewDeviceManager(10, 10)

	// The device to inject: UserId=1, DeviceId="device-1".
	dev := &device.Device{
		Session:  "session-1",
		DeviceID: "device-1",
		Conn:     serverConn,
	}

	// Inject via reflect+unsafe into the unexported userDevices field, since
	// this test lives in a different package. Verify the field's shape first,
	// so a refactor of DeviceManager fails loudly instead of corrupting memory.
	f := reflect.ValueOf(dm).Elem().FieldByName("userDevices")
	if !f.IsValid() || f.Type() != reflect.TypeOf((*sync.Map)(nil)).Elem() {
		t.Fatal("device.DeviceManager.userDevices is missing or is no longer a sync.Map; update this test")
	}
	userDevicesMap := (*sync.Map)(unsafe.Pointer(f.UnsafeAddr()))
	userDevicesMap.Store(int64(1), []*device.Device{dev})

	// 3. User model mock.
	mockModel := &MockEmailModel{}
	// Initialise the embedded mock's internal map, even though only the
	// methods overridden on MockEmailModel are used here.
	mockModel.users = make(map[int64]*user.User)

	svcCtx := &svc.ServiceContext{
		UserModel:     mockModel,
		Redis:         rdb,
		DeviceManager: dm,
		Config: config.Config{
			VerifyCode: config.VerifyCode{ExpireTime: 900},
			JwtAuth:    config.JwtAuth{MaxSessionsPerUser: 10},
		},
	}

	// 4. Run the logic as the current user (user 1).
	currentUser := &user.User{Id: 1}
	ctx := context.WithValue(context.Background(), constant.CtxKeyUser, currentUser)
	l := NewBindEmailWithVerificationLogic(ctx, svcCtx)

	req := &types.BindEmailWithVerificationRequest{
		Email: email,
		Code:  code,
	}

	// Expect the error injected by MockEmailModel.Transaction ("stop testing here").
	// Guard the Contains check so a nil error fails the test instead of panicking.
	_, err = l.BindEmailWithVerification(req)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "stop testing here")
	}

	// 5. Verify that KickDevice removed the kicked device from userDevices.
	// Either the key is gone entirely, or it maps to an empty device list.
	if val, ok := userDevicesMap.Load(int64(1)); ok {
		devices, isList := val.([]*device.Device)
		if assert.True(t, isList, "unexpected userDevices value type %T", val) {
			// Message: "device list should be empty (KickDevice should remove the device)".
			assert.Empty(t, devices, "设备列表应为空 (KickDevice 应该移除设备)")
		}
	}
}
|