Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
220 lines
7.5 KiB
Go
220 lines
7.5 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"log"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
/*
|
||
* 设备复用 Session 测试工具
|
||
*
|
||
* 这个测试工具用于验证设备复用 session 的逻辑是否正确
|
||
* 模拟场景:
|
||
* 1. 设备A第一次登录 - 创建新 session
|
||
* 2. 设备A再次登录 - 应该复用旧 session
|
||
* 3. 设备A的session过期 - 应该创建新 session
|
||
*/
|
||
|
||
const (
|
||
SessionIdKey = "auth:session_id"
|
||
DeviceCacheKeyKey = "auth:device"
|
||
UserSessionsKeyPrefix = "auth:user_sessions:"
|
||
)
|
||
|
||
func main() {
|
||
// 连接 Redis
|
||
rds := redis.NewClient(&redis.Options{
|
||
Addr: "localhost:6379", // 修改为你的 Redis 地址
|
||
Password: "", // 修改为你的 Redis 密码
|
||
DB: 0,
|
||
})
|
||
|
||
ctx := context.Background()
|
||
|
||
// 检查 Redis 连接
|
||
if err := rds.Ping(ctx).Err(); err != nil {
|
||
log.Fatalf("❌ 连接 Redis 失败: %v", err)
|
||
}
|
||
fmt.Println("✅ Redis 连接成功")
|
||
|
||
// 测试参数
|
||
testDeviceID := "test-device-12345"
|
||
testUserID := int64(9999)
|
||
sessionExpire := 10 * time.Second // 测试用,设置较短的过期时间
|
||
|
||
fmt.Println("\n========== 开始测试 ==========")
|
||
|
||
// 清理测试数据
|
||
cleanup(ctx, rds, testDeviceID, testUserID)
|
||
|
||
// 测试1: 第一次登录 - 应该创建新 session
|
||
fmt.Println("\n📋 测试1: 第一次登录")
|
||
sessionId1, isReuse1 := simulateLogin(ctx, rds, testDeviceID, testUserID, sessionExpire)
|
||
if isReuse1 {
|
||
fmt.Println("❌ 测试1失败: 第一次登录不应该复用 session")
|
||
} else {
|
||
fmt.Printf("✅ 测试1通过: 创建了新 session: %s\n", sessionId1)
|
||
}
|
||
|
||
// 检查 session 数量
|
||
count1 := getSessionCount(ctx, rds, testUserID)
|
||
fmt.Printf(" 当前 session 数量: %d\n", count1)
|
||
|
||
// 测试2: 再次登录(session 有效)- 应该复用 session
|
||
fmt.Println("\n📋 测试2: 再次登录(session 有效)")
|
||
sessionId2, isReuse2 := simulateLogin(ctx, rds, testDeviceID, testUserID, sessionExpire)
|
||
if !isReuse2 {
|
||
fmt.Println("❌ 测试2失败: 应该复用旧 session")
|
||
} else if sessionId1 != sessionId2 {
|
||
fmt.Printf("❌ 测试2失败: sessionId 不一致 (%s vs %s)\n", sessionId1, sessionId2)
|
||
} else {
|
||
fmt.Printf("✅ 测试2通过: 复用了旧 session: %s\n", sessionId2)
|
||
}
|
||
|
||
// 检查 session 数量 - 应该仍然是1
|
||
count2 := getSessionCount(ctx, rds, testUserID)
|
||
fmt.Printf(" 当前 session 数量: %d (预期: 1)\n", count2)
|
||
if count2 != 1 {
|
||
fmt.Println("❌ session 数量不正确!")
|
||
}
|
||
|
||
// 测试3: 模拟多设备登录
|
||
fmt.Println("\n📋 测试3: 多设备登录")
|
||
testDeviceID2 := "test-device-67890"
|
||
sessionId3, isReuse3 := simulateLogin(ctx, rds, testDeviceID2, testUserID, sessionExpire)
|
||
if isReuse3 {
|
||
fmt.Println("❌ 测试3失败: 新设备不应该复用 session")
|
||
} else {
|
||
fmt.Printf("✅ 测试3通过: 设备B创建了新 session: %s\n", sessionId3)
|
||
}
|
||
|
||
// 检查 session 数量 - 应该是2
|
||
count3 := getSessionCount(ctx, rds, testUserID)
|
||
fmt.Printf(" 当前 session 数量: %d (预期: 2)\n", count3)
|
||
|
||
// 测试4: 设备A再次登录 - 仍然应该复用
|
||
fmt.Println("\n📋 测试4: 设备A再次登录")
|
||
sessionId4, isReuse4 := simulateLogin(ctx, rds, testDeviceID, testUserID, sessionExpire)
|
||
if !isReuse4 {
|
||
fmt.Println("❌ 测试4失败: 应该复用设备A的旧 session")
|
||
} else if sessionId1 != sessionId4 {
|
||
fmt.Printf("❌ 测试4失败: sessionId 不一致 (%s vs %s)\n", sessionId1, sessionId4)
|
||
} else {
|
||
fmt.Printf("✅ 测试4通过: 设备A复用了旧 session: %s\n", sessionId4)
|
||
}
|
||
|
||
// 检查 session 数量 - 仍然应该是2
|
||
count4 := getSessionCount(ctx, rds, testUserID)
|
||
fmt.Printf(" 当前 session 数量: %d (预期: 2)\n", count4)
|
||
|
||
// 测试5: 等待 session 过期后再登录
|
||
fmt.Println("\n📋 测试5: 等待 session 过期后再登录")
|
||
fmt.Printf(" 等待 %v ...\n", sessionExpire+time.Second)
|
||
time.Sleep(sessionExpire + time.Second)
|
||
|
||
sessionId5, isReuse5 := simulateLogin(ctx, rds, testDeviceID, testUserID, sessionExpire)
|
||
if isReuse5 {
|
||
fmt.Println("❌ 测试5失败: session 过期后不应该复用")
|
||
} else {
|
||
fmt.Printf("✅ 测试5通过: 创建了新 session: %s\n", sessionId5)
|
||
}
|
||
|
||
// 测试6: 设备转移场景(关键安全测试)
|
||
fmt.Println("\n📋 测试6: 设备转移场景(用户A的设备被用户B使用)")
|
||
testDeviceID3 := "test-device-transfer"
|
||
testUserA := int64(1001)
|
||
testUserB := int64(1002)
|
||
|
||
// 用户A用设备登录
|
||
cleanup(ctx, rds, testDeviceID3, testUserA)
|
||
cleanup(ctx, rds, testDeviceID3, testUserB)
|
||
sessionA, _ := simulateLogin(ctx, rds, testDeviceID3, testUserA, sessionExpire)
|
||
fmt.Printf(" 用户A登录,session: %s\n", sessionA)
|
||
|
||
// 用户B用同一设备登录(设备转移场景)
|
||
sessionB, isReuseB := simulateLogin(ctx, rds, testDeviceID3, testUserB, sessionExpire)
|
||
if isReuseB {
|
||
fmt.Println("❌ 测试6失败: 用户B不应该复用用户A的session!(安全漏洞)")
|
||
} else {
|
||
fmt.Printf("✅ 测试6通过: 用户B创建了新 session: %s\n", sessionB)
|
||
}
|
||
|
||
// 验证用户A和B的session不同
|
||
if sessionA == sessionB {
|
||
fmt.Println("❌ 安全问题: 两个用户使用了相同的session!")
|
||
} else {
|
||
fmt.Println("✅ 安全验证通过: 两个用户使用不同的session")
|
||
}
|
||
cleanup(ctx, rds, testDeviceID, testUserID)
|
||
cleanup(ctx, rds, testDeviceID2, testUserID)
|
||
|
||
fmt.Println("\n========== 测试完成 ==========")
|
||
}
|
||
|
||
// simulateLogin 模拟登录逻辑
|
||
// 返回: sessionId, isReuse (是否复用了旧 session)
|
||
func simulateLogin(ctx context.Context, rds *redis.Client, deviceID string, userID int64, expire time.Duration) (string, bool) {
|
||
var sessionId string
|
||
var reuseSession bool
|
||
|
||
deviceCacheKey := fmt.Sprintf("%v:%v", DeviceCacheKeyKey, deviceID)
|
||
|
||
// 检查设备是否有旧的有效 session
|
||
if oldSid, getErr := rds.Get(ctx, deviceCacheKey).Result(); getErr == nil && oldSid != "" {
|
||
// 检查旧 session 是否仍然有效 AND 属于当前用户
|
||
oldSessionKey := fmt.Sprintf("%v:%v", SessionIdKey, oldSid)
|
||
if uidStr, existErr := rds.Get(ctx, oldSessionKey).Result(); existErr == nil && uidStr != "" {
|
||
// 验证 session 属于当前用户 (防止设备转移后复用其他用户的session)
|
||
if uidStr == fmt.Sprintf("%d", userID) {
|
||
sessionId = oldSid
|
||
reuseSession = true
|
||
}
|
||
}
|
||
}
|
||
|
||
if !reuseSession {
|
||
// 生成新的 sessionId
|
||
sessionId = fmt.Sprintf("session-%d-%d", userID, time.Now().UnixNano())
|
||
|
||
// 添加到用户的 session 集合
|
||
sessionsKey := fmt.Sprintf("%s%v", UserSessionsKeyPrefix, userID)
|
||
rds.ZAdd(ctx, sessionsKey, redis.Z{Score: float64(time.Now().Unix()), Member: sessionId})
|
||
rds.Expire(ctx, sessionsKey, expire)
|
||
}
|
||
|
||
// 存储/刷新 session
|
||
sessionIdCacheKey := fmt.Sprintf("%v:%v", SessionIdKey, sessionId)
|
||
rds.Set(ctx, sessionIdCacheKey, userID, expire)
|
||
|
||
// 存储/刷新设备到session的映射
|
||
rds.Set(ctx, deviceCacheKey, sessionId, expire)
|
||
|
||
return sessionId, reuseSession
|
||
}
|
||
|
||
// getSessionCount 获取用户的 session 数量
|
||
func getSessionCount(ctx context.Context, rds *redis.Client, userID int64) int64 {
|
||
sessionsKey := fmt.Sprintf("%s%v", UserSessionsKeyPrefix, userID)
|
||
count, _ := rds.ZCard(ctx, sessionsKey).Result()
|
||
return count
|
||
}
|
||
|
||
// cleanup 清理测试数据
|
||
func cleanup(ctx context.Context, rds *redis.Client, deviceID string, userID int64) {
|
||
deviceCacheKey := fmt.Sprintf("%v:%v", DeviceCacheKeyKey, deviceID)
|
||
sessionsKey := fmt.Sprintf("%s%v", UserSessionsKeyPrefix, userID)
|
||
|
||
// 获取设备的 sessionId
|
||
if sid, err := rds.Get(ctx, deviceCacheKey).Result(); err == nil {
|
||
sessionIdCacheKey := fmt.Sprintf("%v:%v", SessionIdKey, sid)
|
||
rds.Del(ctx, sessionIdCacheKey)
|
||
}
|
||
|
||
rds.Del(ctx, deviceCacheKey)
|
||
rds.Del(ctx, sessionsKey)
|
||
}
|