From 3b3ed7b3c15a11ae70593a4c3c52e07689a60088 Mon Sep 17 00:00:00 2001 From: shanshanzhong Date: Wed, 29 Apr 2026 23:00:18 -0700 Subject: [PATCH] test(auth): add HTTP device no-trial check Co-Authored-By: claude-flow --- scripts/test_device_login_no_trial.go | 215 ++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 scripts/test_device_login_no_trial.go diff --git a/scripts/test_device_login_no_trial.go b/scripts/test_device_login_no_trial.go new file mode 100644 index 0000000..c75b31a --- /dev/null +++ b/scripts/test_device_login_no_trial.go @@ -0,0 +1,215 @@ +//go:build ignore + +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/perfect-panel/server/internal/config" + modelLog "github.com/perfect-panel/server/internal/model/log" + modelUser "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/conf" + "github.com/perfect-panel/server/pkg/orm" + "github.com/perfect-panel/server/pkg/tool" + "gorm.io/gorm" +) + +func main() { + var ( + configPath = flag.String("config", "etc/ppanel.yaml", "config file path") + dsn = flag.String("dsn", "", "optional MySQL DSN override") + baseURL = flag.String("base-url", "", "server base URL, for example http://154.12.35.103") + identifier = flag.String("identifier", "", "optional device identifier") + ip = flag.String("ip", "198.18.77.77", "X-Forwarded-For test IP") + userAgent = flag.String("user-agent", "CodexDeviceNoTrialHTTP/1.0", "device user agent") + cleanup = flag.Bool("cleanup", false, "delete test user/device/log rows after verification") + ) + flag.Parse() + + if strings.TrimSpace(*baseURL) == "" { + fail("base-url is required") + } + if *identifier == "" { + *identifier = fmt.Sprintf("codex-device-no-trial-%d", time.Now().UnixNano()) + } + + ctx := context.Background() + cfg := loadConfig(*configPath, *dsn) + db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL}) + must(err) + + fmt.Println("== HTTP device login no-trial test ==") + fmt.Printf("base_url=%s\n", strings.TrimRight(*baseURL, "/")) + fmt.Printf("mysql=%s/%s\n", cfg.MySQL.Addr, cfg.MySQL.Dbname) + fmt.Printf("identifier=%s ip=%s user_agent=%s\n", *identifier, *ip, *userAgent) + + if err = ensureIdentifierUnused(ctx, db, *identifier); err != nil { + fail("%v", err) + } + + status, body, err := postDeviceLogin(ctx, *baseURL, *identifier, *ip, *userAgent) + if err != nil { + fail("device login request failed: %v", err) + } + fmt.Printf("http_status=%d body=%s\n", status, truncate(body, 500)) + if status < 200 || status >= 300 { + fail("device login returned non-2xx status: %d", status) + } + + device, err := findDevice(ctx, db, *identifier) + if err != nil { + fail("created device not found in DB: %v", err) + } + fmt.Printf("device: id=%d sn=%s user_id=%d created_at=%s\n", + device.Id, + tool.DeviceIdToHash(device.Id), + device.UserId, + device.CreatedAt.Format(time.RFC3339), + ) + + var subs []modelUser.Subscribe + if err = db.WithContext(ctx). + Where("user_id = ?", device.UserId). + Order("id ASC"). + Find(&subs).Error; err != nil { + fail("query user_subscribe failed: %v", err) + } + if len(subs) == 0 { + fmt.Printf("user_subscribe rows: 0\n") + } + for i := range subs { + sub := &subs[i] + fmt.Printf("subscribe: id=%d order_id=%d subscribe_id=%d status=%d start=%s expire=%s\n", + sub.Id, + sub.OrderId, + sub.SubscribeId, + sub.Status, + sub.StartTime.Format(time.RFC3339), + sub.ExpireTime.Format(time.RFC3339), + ) + fail("FAIL: HTTP device login created subscription user_subscribe_id=%d user_id=%d", sub.Id, device.UserId) + } + + fmt.Printf("PASS: HTTP device login created no subscription for user_id=%d\n", device.UserId) + + if *cleanup { + if err = cleanupTestRows(ctx, db, device.UserId); err != nil { + fail("cleanup failed: %v", err) + } + fmt.Printf("cleanup: hard-deleted test rows for user_id=%d\n", device.UserId) + } +} + +func postDeviceLogin(ctx context.Context, baseURL, identifier, ip, userAgent string) (int, string, error) { + payload := map[string]string{ + "identifier": identifier, + "user_agent": userAgent, + } + data, err := json.Marshal(payload) + if err != nil { + return 0, "", err + } + url := strings.TrimRight(baseURL, "/") + "/v1/auth/login/device" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return 0, "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Forwarded-For", ip) + req.Header.Set("X-Real-IP", ip) + + client := &http.Client{Timeout: 20 * time.Second} + resp, err := client.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, "", err + } + return resp.StatusCode, string(body), nil +} + +func loadConfig(path, dsn string) config.Config { + var cfg config.Config + conf.MustLoad(path, &cfg) + if dsn != "" { + parsed := orm.ParseDSN(dsn) + if parsed == nil { + fail("invalid dsn") + } + cfg.MySQL = *parsed + } + return cfg +} + +func ensureIdentifierUnused(ctx context.Context, db *gorm.DB, identifier string) error { + var count int64 + if err := db.WithContext(ctx). + Model(&modelUser.Device{}). + Where("identifier = ?", identifier). + Count(&count).Error; err != nil { + return err + } + if count > 0 { + return fmt.Errorf("identifier already exists: %s", identifier) + } + return nil +} + +func findDevice(ctx context.Context, db *gorm.DB, identifier string) (*modelUser.Device, error) { + var device modelUser.Device + err := db.WithContext(ctx). + Where("identifier = ?", identifier). + First(&device).Error + return &device, err +} + +func cleanupTestRows(ctx context.Context, db *gorm.DB, userID int64) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Where("object_id = ?", userID).Delete(&modelLog.SystemLog{}).Error; err != nil { + return err + } + if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Subscribe{}).Error; err != nil { + return err + } + if err := tx.Where("user_id = ?", userID).Delete(&modelUser.AuthMethods{}).Error; err != nil { + return err + } + if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Device{}).Error; err != nil { + return err + } + return tx.Unscoped().Where("id = ?", userID).Delete(&modelUser.User{}).Error + }) +} + +func truncate(s string, n int) string { + s = strings.TrimSpace(s) + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +func must(err error) { + if err != nil { + fail("%v", err) + } +} + +func fail(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, format+"\n", args...) + os.Exit(1) +}