diff --git a/internal/logic/auth/trialEmailWhitelist.go b/internal/logic/auth/trialEmailWhitelist.go index 2021b2e..019f995 100644 --- a/internal/logic/auth/trialEmailWhitelist.go +++ b/internal/logic/auth/trialEmailWhitelist.go @@ -2,6 +2,7 @@ package auth import ( "context" + "net/mail" "strings" "github.com/perfect-panel/server/internal/config" @@ -15,11 +16,10 @@ func IsEmailDomainWhitelisted(email, whitelistCSV string) bool { if whitelistCSV == "" { return false } - parts := strings.SplitN(email, "@", 2) - if len(parts) != 2 { + _, domain, ok := parseStrictEmail(email) + if !ok { return false } - domain := strings.ToLower(strings.TrimSpace(parts[1])) for _, d := range strings.Split(whitelistCSV, ",") { if strings.ToLower(strings.TrimSpace(d)) == domain { return true @@ -32,10 +32,16 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool if !register.EnableTrial { return false } + if !IsValidTrialEmail(email) { + return false + } // 无论白名单是否启用,泛域名邮箱(含 + 别名或 Gmail 点号)始终拒绝赠送 if IsDisposableAlias(email) { return false } + if isConfusableGmailDomain(emailDomain(email)) { + return false + } if !register.EnableTrialEmailWhitelist { return true } @@ -52,11 +58,10 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool // For Gmail-like domains, local part containing "." or "+" is rejected. // For all other domains, only "+" alias is rejected. func IsDisposableAlias(email string) bool { - parts := strings.SplitN(strings.ToLower(strings.TrimSpace(email)), "@", 2) - if len(parts) != 2 { + local, domain, ok := parseStrictEmail(email) + if !ok { return false } - local, domain := parts[0], parts[1] // All domains: reject + alias if strings.ContainsRune(local, '+') { @@ -74,11 +79,10 @@ func IsDisposableAlias(email string) bool { // Removes dots from local part for Gmail-like providers (gmail.com, googlemail.com). func NormalizeEmail(email string) string { email = strings.ToLower(strings.TrimSpace(email)) - parts := strings.SplitN(email, "@", 2) - if len(parts) != 2 { + local, domain, ok := parseStrictEmail(email) + if !ok { return email } - local, domain := parts[0], parts[1] // Strip + alias if idx := strings.IndexByte(local, '+'); idx != -1 { @@ -101,6 +105,51 @@ func isGmailLikeDomain(domain string) bool { return false } +func IsValidTrialEmail(email string) bool { + local, domain, ok := parseStrictEmail(email) + if !ok { + return false + } + return local != "" && domain != "" +} + +func parseStrictEmail(email string) (local, domain string, ok bool) { + email = strings.ToLower(strings.TrimSpace(email)) + if email == "" || strings.ContainsAny(email, " \t\r\n") { + return "", "", false + } + addr, err := mail.ParseAddress(email) + if err != nil || addr.Address != email || addr.Name != "" { + return "", "", false + } + parts := strings.Split(addr.Address, "@") + if len(parts) != 2 { + return "", "", false + } + local = strings.TrimSpace(parts[0]) + domain = strings.Trim(strings.TrimSpace(parts[1]), ".") + if local == "" || domain == "" || strings.Contains(domain, "..") || !strings.Contains(domain, ".") { + return "", "", false + } + return local, domain, true +} + +func emailDomain(email string) string { + _, domain, ok := parseStrictEmail(email) + if !ok { + return "" + } + return domain +} + +func isConfusableGmailDomain(domain string) bool { + switch strings.ToLower(strings.TrimSpace(domain)) { + case "gmaial.com", "gmial.com", "gmai.com", "gamil.com", "gmal.com", "gmail.co", "gmail.con": + return true + } + return false +} + // NormalizedEmailHasTrial returns true if any user with the same normalized email // already holds a trial subscription. Only performs the cross-user DB check when // normalization actually changes the email (i.e., dots removed or + alias stripped). diff --git a/internal/logic/auth/trialEmailWhitelist_test.go b/internal/logic/auth/trialEmailWhitelist_test.go new file mode 100644 index 0000000..b5844b3 --- /dev/null +++ b/internal/logic/auth/trialEmailWhitelist_test.go @@ -0,0 +1,188 @@ +package auth + +import ( + "testing" + + "github.com/perfect-panel/server/internal/config" +) + +func TestNormalizeEmail(t *testing.T) { + tests := []struct { + input string + want string + }{ + // Gmail dot trick + {"a.v.x.xx@gmail.com", "avxxx@gmail.com"}, + {"john.doe@gmail.com", "johndoe@gmail.com"}, + {"a.b.c.d.e@gmail.com", "abcde@gmail.com"}, + // Gmail + alias + {"user+tag@gmail.com", "user@gmail.com"}, + {"a.b+tag@gmail.com", "ab@gmail.com"}, + // Googlemail + {"a.b@googlemail.com", "ab@googlemail.com"}, + // Non-Gmail: dots preserved + {"john.doe@outlook.com", "john.doe@outlook.com"}, + {"john.doe@qq.com", "john.doe@qq.com"}, + // + alias stripped for all providers + {"user+spam@outlook.com", "user@outlook.com"}, + {"user+spam@qq.com", "user@qq.com"}, + // Case insensitive + {"User@Gmail.COM", "user@gmail.com"}, + {"A.B@Gmail.com", "ab@gmail.com"}, + // No change for normal non-gmail email + {"abc@163.com", "abc@163.com"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := NormalizeEmail(tt.input) + if got != tt.want { + t.Errorf("NormalizeEmail(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNormalizeEmail_NoChangeSkipsCheck(t *testing.T) { + // These emails should NOT trigger cross-user check (normalized == original) + noChangeCases := []string{ + "abc@163.com", + "john.doe@outlook.com", + "user@qq.com", + } + for _, email := range noChangeCases { + normalized := NormalizeEmail(email) + lower := email + if normalized == lower { + // correct: no normalization change, NormalizedEmailHasTrial would return false early + } + } +} + +func TestShouldGrantTrialForEmail(t *testing.T) { + // 模拟线上配置:白名单开启,gmail.com 也在名单里 + rcWithGmail := config.RegisterConfig{ + EnableTrial: true, + EnableTrialEmailWhitelist: true, + TrialEmailDomainWhitelist: "hifastapp.com,hifastvpn.com,126.com,139.com,163.com,gmail.com", + } + // 白名单关闭 + rcNoWhitelist := config.RegisterConfig{ + EnableTrial: true, + EnableTrialEmailWhitelist: false, + } + + tests := []struct { + name string + rc config.RegisterConfig + email string + want bool + reason string + }{ + { + name: "gmail dot trick - blocked even if gmail.com in whitelist", + rc: rcWithGmail, + email: "s.m.s.n.fsmbt.d.ndny@gmail.com", + want: false, + reason: "gmail 泛域名(含点号)应拒绝", + }, + { + name: "gmail plus alias - blocked", + rc: rcWithGmail, + email: "user+tag@gmail.com", + want: false, + reason: "gmail +别名应拒绝", + }, + { + name: "clean gmail - allowed", + rc: rcWithGmail, + email: "normaluser@gmail.com", + want: true, + reason: "干净的 gmail 应放行", + }, + { + name: "163 with dot - allowed (non-gmail dot is ok)", + rc: rcWithGmail, + email: "s.m.s.n@163.com", + want: true, + reason: "非 gmail 域点号不拦截", + }, + { + name: "163 plus alias - blocked", + rc: rcWithGmail, + email: "user+spam@163.com", + want: false, + reason: "所有域名的 +别名都拦截", + }, + { + name: "gmail typo squatting domain - blocked even if accidentally whitelisted", + rc: config.RegisterConfig{EnableTrial: true, EnableTrialEmailWhitelist: true, TrialEmailDomainWhitelist: "gmail.com,gmaial.com"}, + email: "1.2.3.4xxx@gmaial.com", + want: false, + reason: "易混淆 Gmail 域名不应发放试用", + }, + { + name: "invalid empty local - blocked", + rc: rcWithGmail, + email: "@gmail.com", + want: false, + reason: "邮箱 local 为空应拒绝", + }, + { + name: "subdomain spoof - blocked", + rc: rcWithGmail, + email: "user@fake.gmail.com", + want: false, + reason: "白名单必须精确匹配域名,不匹配子域", + }, + { + name: "whitelist disabled - gmail dot trick still blocked", + rc: rcNoWhitelist, + email: "s.m.s.n.fsmbt.d.ndny@gmail.com", + want: false, + reason: "白名单未启用,但泛域名仍应拒绝", + }, + { + name: "trial disabled - always blocked", + rc: config.RegisterConfig{EnableTrial: false}, + email: "user@163.com", + want: false, + reason: "试用未开启", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ShouldGrantTrialForEmail(tt.rc, tt.email) + if got != tt.want { + t.Errorf("ShouldGrantTrialForEmail(%q) = %v, want %v | reason: %s", + tt.email, got, tt.want, tt.reason) + } + }) + } +} + +func TestIsEmailDomainWhitelisted(t *testing.T) { + whitelist := "gmail.com,edu.cn,outlook.com" + tests := []struct { + email string + want bool + }{ + {"user@gmail.com", true}, + {"user@edu.cn", true}, + {"User@Gmail.COM", true}, + {"user@yahoo.com", false}, + {"user@fake.gmail.com", false}, // subdomain not matched + {"user@", false}, + {"notanemail", false}, + {"@gmail.com", false}, + } + for _, tt := range tests { + t.Run(tt.email, func(t *testing.T) { + got := IsEmailDomainWhitelisted(tt.email, whitelist) + if got != tt.want { + t.Errorf("IsEmailDomainWhitelisted(%q) = %v, want %v", tt.email, got, tt.want) + } + }) + } +} diff --git a/internal/logic/common/newUserEligibility.go b/internal/logic/common/newUserEligibility.go index 73a8cd7..f2f2235 100644 --- a/internal/logic/common/newUserEligibility.go +++ b/internal/logic/common/newUserEligibility.go @@ -76,7 +76,10 @@ func CountScopedSubscribePurchaseOrders( var count int64 query := db.WithContext(ctx). Model(&modelOrder.Order{}). - Where("user_id IN ? AND subscribe_id = ? AND type IN ? AND amount > 0", scopeUserIDs, subscribeID, []int64{1, 2}) + Where("user_id IN ? AND type IN ?", scopeUserIDs, []int64{1, 2}) + if subscribeID > 0 { + query = query.Where("subscribe_id = ?", subscribeID) + } if len(statuses) > 0 { query = query.Where("status IN ?", statuses) } diff --git a/internal/logic/common/subscribeModeRoute.go b/internal/logic/common/subscribeModeRoute.go index 54dff89..82384b1 100644 --- a/internal/logic/common/subscribeModeRoute.go +++ b/internal/logic/common/subscribeModeRoute.go @@ -52,12 +52,8 @@ func ResolvePurchaseRoute( return decision, nil } - if requestedSubscribeID != anchorSub.SubscribeId { - return nil, ErrSingleModePlanMismatch - } - decision.Route = PurchaseRoutePurchaseToRenewal - decision.ResolvedSubscribeID = anchorSub.SubscribeId + decision.ResolvedSubscribeID = requestedSubscribeID decision.Anchor = anchorSub return decision, nil } diff --git a/internal/logic/common/subscribeModeRoute_test.go b/internal/logic/common/subscribeModeRoute_test.go new file mode 100644 index 0000000..333d256 --- /dev/null +++ b/internal/logic/common/subscribeModeRoute_test.go @@ -0,0 +1,36 @@ +package common + +import ( + "context" + "testing" + "time" + + "github.com/perfect-panel/server/internal/model/user" + "github.com/stretchr/testify/require" +) + +func TestResolvePurchaseRoute_AllowsPlanChangeForExistingSubscription(t *testing.T) { + anchor := &user.Subscribe{ + Id: 10, + UserId: 20, + OrderId: 30, + SubscribeId: 1, + Token: "existing-token", + ExpireTime: time.Now().Add(time.Hour), + } + + decision, err := ResolvePurchaseRoute( + context.Background(), + true, + anchor.UserId, + 2, + func(context.Context, int64) (*user.Subscribe, error) { + return anchor, nil + }, + ) + + require.NoError(t, err) + require.Equal(t, PurchaseRoutePurchaseToRenewal, decision.Route) + require.Equal(t, int64(2), decision.ResolvedSubscribeID) + require.Equal(t, anchor, decision.Anchor) +} diff --git a/internal/logic/public/order/getDiscount.go b/internal/logic/public/order/getDiscount.go index 87dc722..1c0b3ac 100644 --- a/internal/logic/public/order/getDiscount.go +++ b/internal/logic/public/order/getDiscount.go @@ -12,6 +12,9 @@ func getDiscount(discounts []types.SubscribeDiscount, inputMonths int64, isNewUs if d.Quantity != inputMonths || d.Discount <= 0 || d.Discount >= 100 { continue } + if d.NewUserOnly && !isNewUser { + continue + } if isNewUser { // lowest discount value = biggest saving if best < 0 || d.Discount < best { @@ -50,4 +53,3 @@ func isNewUserOnlyForQuantity(discounts []types.SubscribeDiscount, inputQuantity } return hasNewUserOnly && !hasFallback } - diff --git a/internal/logic/public/order/getDiscount_test.go b/internal/logic/public/order/getDiscount_test.go new file mode 100644 index 0000000..b348531 --- /dev/null +++ b/internal/logic/public/order/getDiscount_test.go @@ -0,0 +1,16 @@ +package order + +import ( + "testing" + + "github.com/perfect-panel/server/internal/types" + "github.com/stretchr/testify/require" +) + +func TestGetDiscount_SkipsNewUserOnlyTierForExistingUser(t *testing.T) { + discount := getDiscount([]types.SubscribeDiscount{ + {Quantity: 1, Discount: 90, NewUserOnly: true}, + }, 1, false) + + require.Equal(t, float64(1), discount) +} diff --git a/internal/logic/public/order/newUserDiscountEligibility.go b/internal/logic/public/order/newUserDiscountEligibility.go index 2d08da2..92c23be 100644 --- a/internal/logic/public/order/newUserDiscountEligibility.go +++ b/internal/logic/public/order/newUserDiscountEligibility.go @@ -47,7 +47,7 @@ func resolveNewUserDiscountEligibility( ctx, db, eligibility.ScopeUserIDs, - subscribeID, + 0, []int64{2, 5}, "", ) diff --git a/internal/logic/public/order/purchaseLogic.go b/internal/logic/public/order/purchaseLogic.go index a220e18..b41604d 100644 --- a/internal/logic/public/order/purchaseLogic.go +++ b/internal/logic/public/order/purchaseLogic.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "math" + "strings" "time" "github.com/google/uuid" @@ -111,18 +112,21 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P } } - // 非单订阅模式下,若用户已有同套餐订阅(含过期),提前路由为续费,防止创建重复订阅 + // 全局单订阅口径:若用户已有任意付费订阅(含过期),提前路由为续费/换套餐, + // 防止不同套餐购买创建第二条订阅。 if !l.svcCtx.Config.Subscribe.SingleModel && orderType == 1 { var existSub user.Subscribe if e := l.svcCtx.DB.WithContext(l.ctx). Model(&user.Subscribe{}). - Where("user_id = ? AND subscribe_id = ?", entitlement.EffectiveUserID, targetSubscribeID). + Where("user_id = ? AND token != '' AND (order_id > 0 OR token LIKE 'iap:%')", entitlement.EffectiveUserID). Order("expire_time DESC"). + Order("updated_at DESC"). + Order("id DESC"). First(&existSub).Error; e == nil && existSub.Id > 0 && existSub.Token != "" { orderType = 2 parentOrderID = existSub.OrderId subscribeToken = existSub.Token - l.Infow("[Purchase] non-single mode purchase routed to renewal (existing subscription found)", + l.Infow("[Purchase] purchase routed to renewal/change plan (existing subscription found)", logger.Field("existing_subscribe_id", existSub.Id), logger.Field("existing_status", existSub.Status), logger.Field("user_id", u.Id), @@ -311,13 +315,13 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P // check subscribe plan quota limit inside transaction to prevent race condition if orderInfo.Type == 1 && sub.Quota > 0 { var currentUserSub []user.Subscribe - if e := db.Model(&user.Subscribe{}).Where("user_id = ?", u.Id).Find(¤tUserSub).Error; e != nil { + if e := db.Model(&user.Subscribe{}).Where("user_id = ?", entitlement.EffectiveUserID).Find(¤tUserSub).Error; e != nil { l.Errorw("[Purchase] Database query error", logger.Field("error", e.Error()), logger.Field("user_id", u.Id)) return e } var count int64 for _, v := range currentUserSub { - if v.SubscribeId == targetSubscribeID { + if v.OrderId > 0 || strings.HasPrefix(v.Token, "iap:") { count++ } } diff --git a/internal/logic/public/order/purchaseNewUserOnly_test.go b/internal/logic/public/order/purchaseNewUserOnly_test.go new file mode 100644 index 0000000..b976eea --- /dev/null +++ b/internal/logic/public/order/purchaseNewUserOnly_test.go @@ -0,0 +1,766 @@ +package order + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/hibiken/asynq" + modelOrder "github.com/perfect-panel/server/internal/model/order" + "github.com/perfect-panel/server/internal/model/payment" + subModel "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// setupNewUserOnlyDB 创建带必要表的 SQLite 内存数据库 +func setupNewUserOnlyDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err, "failed to open in-memory SQLite") + db.Exec("PRAGMA foreign_keys = OFF") + + sqls := []string{ + `CREATE TABLE IF NOT EXISTS "subscribe" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL DEFAULT '', + language VARCHAR(255) NOT NULL DEFAULT '', + description TEXT, + unit_price INTEGER NOT NULL DEFAULT 0, + unit_time VARCHAR(255) NOT NULL DEFAULT '', + discount TEXT, + replacement INTEGER NOT NULL DEFAULT 0, + inventory INTEGER NOT NULL DEFAULT -1, + traffic INTEGER NOT NULL DEFAULT 0, + speed_limit INTEGER NOT NULL DEFAULT 0, + device_limit INTEGER NOT NULL DEFAULT 0, + quota INTEGER NOT NULL DEFAULT 0, + new_user_only TINYINT DEFAULT 0, + nodes VARCHAR(255), + node_tags VARCHAR(255), + show TINYINT NOT NULL DEFAULT 0, + sell TINYINT NOT NULL DEFAULT 1, + sort INTEGER NOT NULL DEFAULT 0, + deduction_ratio INTEGER DEFAULT 0, + allow_deduction TINYINT DEFAULT 1, + reset_cycle INTEGER DEFAULT 0, + renewal_reset TINYINT DEFAULT 0, + show_original_price TINYINT NOT NULL DEFAULT 1, + created_at DATETIME, + updated_at DATETIME + )`, + `CREATE TABLE IF NOT EXISTS "order" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + parent_id INTEGER DEFAULT NULL, + user_id INTEGER NOT NULL DEFAULT 0, + subscription_user_id INTEGER NOT NULL DEFAULT 0, + order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE, + type TINYINT NOT NULL DEFAULT 1, + quantity INTEGER NOT NULL DEFAULT 1, + price INTEGER NOT NULL DEFAULT 0, + amount INTEGER NOT NULL DEFAULT 0, + gift_amount INTEGER NOT NULL DEFAULT 0, + discount INTEGER NOT NULL DEFAULT 0, + coupon VARCHAR(255) DEFAULT NULL, + coupon_discount INTEGER NOT NULL DEFAULT 0, + commission INTEGER NOT NULL DEFAULT 0, + payment_id INTEGER NOT NULL DEFAULT 0, + method VARCHAR(255) NOT NULL DEFAULT '', + fee_amount INTEGER NOT NULL DEFAULT 0, + trade_no VARCHAR(255) DEFAULT NULL, + app_account_token VARCHAR(255) DEFAULT NULL, + status TINYINT NOT NULL DEFAULT 1, + subscribe_id INTEGER NOT NULL DEFAULT 0, + subscribe_token VARCHAR(255) DEFAULT NULL, + is_new TINYINT NOT NULL DEFAULT 0, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "user" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + password VARCHAR(100) NOT NULL DEFAULT '', + algo VARCHAR(20) DEFAULT 'default', + salt VARCHAR(20) DEFAULT NULL, + avatar TEXT, + balance INTEGER DEFAULT 0, + refer_code VARCHAR(20) DEFAULT '', + referer_id INTEGER DEFAULT 0, + commission INTEGER DEFAULT 0, + referral_percentage INTEGER DEFAULT 0, + only_first_purchase TINYINT DEFAULT 1, + gift_amount INTEGER DEFAULT 0, + enable TINYINT DEFAULT 1, + is_admin TINYINT DEFAULT 0, + enable_balance_notify TINYINT DEFAULT 0, + enable_login_notify TINYINT DEFAULT 0, + enable_subscribe_notify TINYINT DEFAULT 0, + enable_trade_notify TINYINT DEFAULT 0, + rules TEXT, + member_status VARCHAR(20) DEFAULT '', + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "payment" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(100) NOT NULL DEFAULT '', + platform VARCHAR(100) NOT NULL DEFAULT '', + icon VARCHAR(255) DEFAULT '', + domain VARCHAR(255) DEFAULT '', + config TEXT NOT NULL DEFAULT '{}', + description TEXT, + fee_mode TINYINT NOT NULL DEFAULT 0, + fee_percent INTEGER DEFAULT 0, + fee_amount INTEGER DEFAULT 0, + enable TINYINT NOT NULL DEFAULT 1, + token VARCHAR(255) NOT NULL DEFAULT '' UNIQUE + )`, + `CREATE TABLE IF NOT EXISTS "user_subscribe" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL DEFAULT 0, + order_id INTEGER NOT NULL DEFAULT 0, + subscribe_id INTEGER NOT NULL DEFAULT 0, + start_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + expire_time DATETIME DEFAULT NULL, + finished_at DATETIME DEFAULT NULL, + traffic INTEGER DEFAULT 0, + download INTEGER DEFAULT 0, + upload INTEGER DEFAULT 0, + token VARCHAR(255) DEFAULT '' UNIQUE, + uuid VARCHAR(255) DEFAULT '' UNIQUE, + status TINYINT DEFAULT 0, + note VARCHAR(500) DEFAULT '', + created_at DATETIME, + updated_at DATETIME + )`, + `CREATE TABLE IF NOT EXISTS "user_device" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ip VARCHAR(255) NOT NULL DEFAULT '', + user_id INTEGER NOT NULL DEFAULT 0, + user_agent TEXT, + identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE, + short_code VARCHAR(255) NOT NULL DEFAULT '', + online TINYINT NOT NULL DEFAULT 0, + enabled TINYINT NOT NULL DEFAULT 1, + created_at DATETIME, + updated_at DATETIME + )`, + `CREATE TABLE IF NOT EXISTS "user_family" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + owner_user_id INTEGER NOT NULL DEFAULT 0, + max_members INTEGER NOT NULL DEFAULT 2, + status TINYINT DEFAULT 0, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "user_family_member" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + family_id INTEGER NOT NULL DEFAULT 0, + user_id INTEGER NOT NULL DEFAULT 0, + role TINYINT DEFAULT 0, + status TINYINT DEFAULT 0, + join_source VARCHAR(32) NOT NULL DEFAULT '', + joined_at DATETIME, + left_at DATETIME DEFAULT NULL, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME DEFAULT NULL + )`, + } + for _, sql := range sqls { + require.NoError(t, db.Exec(sql).Error) + } + return db +} + +// setupNewUserOnlyRedis 启动 miniredis,返回 redis.Client 和 miniredis 句柄 +func setupNewUserOnlyRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) { + t.Helper() + mr, err := miniredis.Run() + require.NoError(t, err) + t.Cleanup(mr.Close) + rds := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return rds, mr +} + +// buildNewUserOnlySvcCtx 组装最小 ServiceContext(含 asynq Queue 使用 miniredis) +func buildNewUserOnlySvcCtx(db *gorm.DB, rds *redis.Client, mr *miniredis.Miniredis) *svc.ServiceContext { + queue := asynq.NewClient(asynq.RedisClientOpt{Addr: mr.Addr()}) + return &svc.ServiceContext{ + DB: db, + Redis: rds, + UserModel: user.NewModel(db, rds), + OrderModel: modelOrder.NewModel(db, rds), + SubscribeModel: subModel.NewModel(db, rds), + PaymentModel: payment.NewModel(db, rds), + Queue: queue, + } +} + +// insertTestSubscribe 直接用 SQL 插入 subscribe 行(绕过 GORM hook 的 MySQL 方言) +// new_user_only=true 时同时写入 discount JSON,使代码里的 discount 检查生效 +func insertTestSubscribe(t *testing.T, db *gorm.DB, id int64, newUserOnly bool) { + t.Helper() + nuOnly := 0 + discount := "" + if newUserOnly { + nuOnly = 1 + // discount JSON 包含一个 new_user_only=true 的 tier,匹配 quantity=1 + discount = `[{"quantity":1,"discount":90,"new_user_only":true}]` + } + err := db.Exec(`INSERT INTO "subscribe" + (id, name, unit_price, inventory, sell, sort, new_user_only, discount, created_at, updated_at) + VALUES (?, 'Test Plan', 1000, -1, 1, ?, ?, ?, datetime('now'), datetime('now'))`, + id, id, nuOnly, discount).Error + require.NoError(t, err) +} + +// insertTestPayment 插入支付方式行 +func insertTestPayment(t *testing.T, db *gorm.DB, id int64) { + t.Helper() + err := db.Exec(`INSERT INTO "payment" + (id, name, platform, config, enable, fee_mode, token) + VALUES (?, 'Balance', 'balance', '{}', 1, 0, ?)`, + id, "test-token").Error + require.NoError(t, err) +} + +// insertTestUser 插入用户行,createdAt 可控 +func insertTestUser(t *testing.T, db *gorm.DB, id int64, createdAt time.Time) *user.User { + t.Helper() + err := db.Exec(`INSERT INTO "user" + (id, password, balance, gift_amount, enable, created_at, updated_at) + VALUES (?, '', 0, 0, 1, ?, datetime('now'))`, + id, createdAt.UTC().Format("2006-01-02 15:04:05")).Error + require.NoError(t, err) + return &user.User{ + Id: id, + GiftAmount: 0, + CreatedAt: createdAt, + } +} + +func insertTestDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) { + t.Helper() + err := db.Exec(`INSERT INTO "user_device" + (user_id, ip, user_agent, identifier, short_code, online, enabled, created_at, updated_at) + VALUES (?, '127.0.0.1', 'test-agent', ?, '', 0, 1, ?, datetime('now'))`, + userID, + identifier, + createdAt.UTC().Format("2006-01-02 15:04:05"), + ).Error + require.NoError(t, err) +} + +func insertTestFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) { + t.Helper() + err := db.Exec(`INSERT INTO "user_family" + (id, owner_user_id, max_members, status, created_at, updated_at) + VALUES (?, ?, 3, 1, datetime('now'), datetime('now'))`, + familyID, + ownerUserID, + ).Error + require.NoError(t, err) +} + +func insertTestFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) { + t.Helper() + err := db.Exec(`INSERT INTO "user_family_member" + (family_id, user_id, role, status, join_source, joined_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'), datetime('now'))`, + familyID, + userID, + role, + status, + joinSource, + ).Error + require.NoError(t, err) +} + +// insertTestOrder 插入一条历史订单(status=2 表示已支付) +func insertTestOrder(t *testing.T, db *gorm.DB, userID, subscribeID int64, status uint8) { + t.Helper() + err := db.Exec(`INSERT INTO "order" + (user_id, order_no, type, status, subscribe_id, created_at, updated_at) + VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`, + userID, "existing-order-no", status, subscribeID).Error + require.NoError(t, err) +} + +func insertScopedTestOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) { + t.Helper() + err := db.Exec(`INSERT INTO "order" + (user_id, order_no, type, status, subscribe_id, created_at, updated_at) + VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`, + userID, orderNo, status, subscribeID).Error + require.NoError(t, err) +} + +// buildPurchaseCtx 把 user 放入 context(模拟中间件行为) +func buildPurchaseCtx(u *user.User) context.Context { + return context.WithValue(context.Background(), constant.CtxKeyUser, u) +} + +// TestPurchase_NewUserOnly_UserTooOld 验证:new_user_only=true,用户注册超过 24h → 返回 SubscribeNewUserOnly +func TestPurchase_NewUserOnly_UserTooOld(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(1) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, true) // new_user_only = true + insertTestPayment(t, db, payID) + + // 用户注册 48 小时前 → 超出 24h 限制 + u := insertTestUser(t, db, 100, time.Now().Add(-48*time.Hour)) + ctx := buildPurchaseCtx(u) + + logic := NewPurchaseLogic(ctx, svcCtx) + _, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.Error(t, err) + var errCode *xerr.CodeError + require.ErrorAs(t, err, &errCode) + assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode(), + "注册超过24h应返回 SubscribeNewUserOnly 错误码") + + // 验证订单未被创建 + var count int64 + db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count) + assert.Equal(t, int64(0), count, "用户注册超时,订单不应被创建") +} + +// TestPurchase_NewUserOnly_AlreadyPurchased 验证:new_user_only=true,用户是新用户但已购买过 +// → 允许下单(不拦截),但不享受新人折扣 +func TestPurchase_NewUserOnly_AlreadyPurchased(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(2) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + // 用户刚注册(2h前)→ 满足时间条件 + u := insertTestUser(t, db, 200, time.Now().Add(-2*time.Hour)) + + // 但已有一条 status=2 的历史订单(已支付) + insertTestOrder(t, db, u.Id, subID, 2) + + ctx := buildPurchaseCtx(u) + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + // 不应被拦截,允许下单 + require.NoError(t, err, "24h内已购用户应允许继续下单,不应返回错误") + require.NotNil(t, resp) + assert.NotEmpty(t, resp.OrderNo) + + // 历史订单 +1(新增了一条) + var count int64 + db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count) + assert.Equal(t, int64(2), count, "应新增一条订单") + + // 新订单无折扣:Amount=Price=1000 + var newOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error) + assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣,Amount 应等于 Price") + assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0") +} + +// TestPurchase_NewUserOnly_Success 验证:new_user_only=true,新用户首次购买 → 成功创建订单 +func TestPurchase_NewUserOnly_Success(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(3) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + // 用户 1 小时前注册(新用户),且没有历史订单 + u := insertTestUser(t, db, 300, time.Now().Add(-1*time.Hour)) + ctx := buildPurchaseCtx(u) + + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.NotEmpty(t, resp.OrderNo, "新用户首次购买应成功,返回订单号") + + // 验证订单已写入数据库 + var o modelOrder.Order + err = db.Where("order_no = ?", resp.OrderNo).First(&o).Error + require.NoError(t, err) + assert.Equal(t, u.Id, o.UserId) + assert.Equal(t, subID, o.SubscribeId) +} + +// TestPurchase_NewUserOnly_Disabled 验证:new_user_only=false 时,老用户也能正常购买 +func TestPurchase_NewUserOnly_Disabled(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(4) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, false) // new_user_only = false + insertTestPayment(t, db, payID) + + // 注册 30 天的老用户 + u := insertTestUser(t, db, 400, time.Now().Add(-30*24*time.Hour)) + ctx := buildPurchaseCtx(u) + + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.NotEmpty(t, resp.OrderNo, "new_user_only=false时老用户应能正常购买") +} + +// TestPurchase_SingleMode_PendingOldOrderCancelled 验证:单订阅模式下,已有 pending 订单时 +// 第二次下单应关闭旧单并创建新单(而非复用旧单) +func TestPurchase_SingleMode_PendingOldOrderCancelled(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + svcCtx.Config.Subscribe.SingleModel = true + + const subID = int64(5) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, false) + insertTestPayment(t, db, payID) + + u := insertTestUser(t, db, 500, time.Now().Add(-1*time.Hour)) + ctx := buildPurchaseCtx(u) + + // 第一次下单(pending) + logic := NewPurchaseLogic(ctx, svcCtx) + resp1, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + require.NoError(t, err) + require.NotNil(t, resp1) + firstOrderNo := resp1.OrderNo + assert.NotEmpty(t, firstOrderNo) + + // 确认第一单 pending + var firstOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&firstOrder).Error) + assert.Equal(t, uint8(1), firstOrder.Status, "第一单应为 pending") + + // 第二次下单(不同 quantity) + logic2 := NewPurchaseLogic(ctx, svcCtx) + resp2, err := logic2.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 3, + }) + require.NoError(t, err) + require.NotNil(t, resp2) + secondOrderNo := resp2.OrderNo + + // 新单与旧单不同 + assert.NotEqual(t, firstOrderNo, secondOrderNo, "第二次下单应创建新订单,不复用旧单") + + // 旧单应被关闭(status=3) + var closedOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&closedOrder).Error) + assert.Equal(t, uint8(3), closedOrder.Status, "旧 pending 单应被关闭") + + // 新单的 quantity 应为 3 + var newOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", secondOrderNo).First(&newOrder).Error) + assert.Equal(t, int64(3), newOrder.Quantity, "新单 quantity 应为 3") + assert.Equal(t, uint8(1), newOrder.Status, "新单应为 pending 状态") +} + +// TestPurchase_SingleMode_NoPendingOrder 验证:单订阅模式下,没有旧 pending 单时正常创建 +func TestPurchase_SingleMode_NoPendingOrder(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + svcCtx.Config.Subscribe.SingleModel = true + + const subID = int64(6) + const payID = int64(1) + + insertTestSubscribe(t, db, subID, false) + insertTestPayment(t, db, payID) + + u := insertTestUser(t, db, 600, time.Now().Add(-1*time.Hour)) + ctx := buildPurchaseCtx(u) + + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 2, + }) + require.NoError(t, err) + require.NotNil(t, resp) + assert.NotEmpty(t, resp.OrderNo, "无旧 pending 单时应正常创建新单") + + var o modelOrder.Order + require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&o).Error) + assert.Equal(t, int64(2), o.Quantity) + assert.Equal(t, uint8(1), o.Status) +} + +// TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock 验证:new_user_only=true 套餐, +// 24小时内但已购买过 → 允许下单,但不享受新人折扣(Discount=0,Amount=Price) +func TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(7) + const payID = int64(1) + + // 套餐:unit_price=1000,discount=[{quantity:1,discount:80,new_user_only:true}] + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + // 用户 1 小时前注册(新用户),但已有一条成功订单(status=2) + u := insertTestUser(t, db, 700, time.Now().Add(-1*time.Hour)) + insertTestOrder(t, db, u.Id, subID, 2) + + ctx := buildPurchaseCtx(u) + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + // 不应被拦截 + require.NoError(t, err, "24h内已购用户不应被拦截,应允许下单") + require.NotNil(t, resp) + assert.NotEmpty(t, resp.OrderNo) + + // 验证订单金额:无折扣,Amount=Price=1000,Discount=0 + var newOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error) + assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000") + assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣,Amount 应等于 Price") + assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0") +} + +// TestPurchase_NewUserOnly_FirstPurchase_HasDiscount 验证:new_user_only=true 套餐, +// 24小时内首次购买 → 允许下单且享受新人折扣 +func TestPurchase_NewUserOnly_FirstPurchase_HasDiscount(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const subID = int64(8) + const payID = int64(1) + + // 套餐:unit_price=1000,discount=[{quantity:1,discount:80,new_user_only:true}](8折) + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + // 用户 1 小时前注册,无历史订单 + u := insertTestUser(t, db, 800, time.Now().Add(-1*time.Hour)) + ctx := buildPurchaseCtx(u) + + logic := NewPurchaseLogic(ctx, svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + + var newOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error) + assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000") + assert.Equal(t, int64(900), newOrder.Amount, "首次购买应享受9折,Amount=900") + assert.Equal(t, int64(100), newOrder.Discount, "折扣金额应为 100") +} + +func TestPurchase_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const ( + subID = int64(9) + payID = int64(1) + ownerUserID = int64(901) + memberUserID = int64(902) + familyID = int64(99) + ) + + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertTestUser(t, db, memberUserID, time.Now().Add(-72*time.Hour)) + insertTestDevice(t, db, memberUserID, "device-eligibility-old", time.Now().Add(-72*time.Hour)) + insertTestFamily(t, db, familyID, ownerUserID) + insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification") + + logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx) + _, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.Error(t, err) + var errCode *xerr.CodeError + require.ErrorAs(t, err, &errCode) + assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode()) +} + +func TestPurchase_NewUserOnly_BindEmailScopeSharesHistory(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const ( + subID = int64(10) + payID = int64(1) + ownerUserID = int64(1001) + memberUserID = int64(1002) + familyID = int64(109) + ) + + insertTestSubscribe(t, db, subID, true) + insertTestPayment(t, db, payID) + + owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertTestUser(t, db, memberUserID, time.Now().Add(-2*time.Hour)) + insertTestDevice(t, db, memberUserID, "device-eligibility-shared", time.Now().Add(-2*time.Hour)) + insertTestFamily(t, db, familyID, ownerUserID) + insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification") + insertScopedTestOrder(t, db, "existing-scope-order", memberUserID, subID, 2) + + logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx) + resp, err := logic.Purchase(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Payment: payID, + Quantity: 1, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + + var newOrder modelOrder.Order + require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error) + assert.Equal(t, int64(1000), newOrder.Amount) + assert.Equal(t, int64(0), newOrder.Discount) +} + +func TestPreCreateOrder_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const ( + subID = int64(11) + ownerUserID = int64(1101) + memberUserID = int64(1102) + familyID = int64(119) + ) + + insertTestSubscribe(t, db, subID, true) + + owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour)) + insertTestDevice(t, db, memberUserID, "device-precreate-old", time.Now().Add(-96*time.Hour)) + insertTestFamily(t, db, familyID, ownerUserID) + insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification") + + logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx) + _, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Quantity: 1, + }) + + require.Error(t, err) + var errCode *xerr.CodeError + require.ErrorAs(t, err, &errCode) + assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode()) +} + +func TestPreCreateOrder_NewUserOnly_OrdinaryFamilyMemberDoesNotAffectEligibility(t *testing.T) { + db := setupNewUserOnlyDB(t) + rds, mr := setupNewUserOnlyRedis(t) + svcCtx := buildNewUserOnlySvcCtx(db, rds, mr) + + const ( + subID = int64(12) + ownerUserID = int64(1201) + memberUserID = int64(1202) + familyID = int64(129) + ) + + insertTestSubscribe(t, db, subID, true) + + owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour)) + insertTestDevice(t, db, memberUserID, "device-precreate-ordinary", time.Now().Add(-96*time.Hour)) + insertTestFamily(t, db, familyID, ownerUserID) + insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "manual_invite") + + logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx) + resp, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{ + SubscribeId: subID, + Quantity: 1, + }) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, int64(900), resp.Amount) + assert.Equal(t, int64(100), resp.Discount) +} diff --git a/queue/logic/order/activateOrderLogic.go b/queue/logic/order/activateOrderLogic.go index 0dec9bd..4577e8e 100644 --- a/queue/logic/order/activateOrderLogic.go +++ b/queue/logic/order/activateOrderLogic.go @@ -10,6 +10,7 @@ import ( "time" "github.com/perfect-panel/server/internal/logic/admin/group" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" @@ -44,6 +45,7 @@ const ( OrderStatusPaid = 2 // Order paid and ready for processing OrderStatusClose = 3 // Order closed/cancelled OrderStatusFailed = 4 // Order processing failed + OrderStatusClaimed = 4 // Internal transient claim while a worker processes the order OrderStatusFinished = 5 // Order successfully completed ) @@ -82,7 +84,7 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) logger.WithContext(ctx).Info("[ActivateOrderLogic] 正在验证订单", logger.Field("order_no", payload.OrderNo)) - orderInfo, err := l.validateAndGetOrder(ctx, payload.OrderNo) + orderInfo, err := l.claimAndGetOrder(ctx, payload.OrderNo) if err != nil { // 如果订单不存在或状态不对,不重试 if errors.Is(err, ErrInvalidOrderStatus) { @@ -108,6 +110,7 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) logger.Field("user_id", orderInfo.UserId)) if err = l.processOrderByType(ctx, orderInfo, payload.IAPExpireAt); err != nil { + l.releaseClaim(ctx, orderInfo.OrderNo) logger.WithContext(ctx).Error("[ActivateOrderLogic] 处理订单失败,将重试", logger.Field("order_no", orderInfo.OrderNo), logger.Field("order_type", orderInfo.Type), @@ -137,10 +140,11 @@ func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) ( return &p, nil } -// validateAndGetOrder retrieves an order by order number and validates its status +// claimAndGetOrder retrieves an order by order number and atomically claims paid orders. // Returns error if order is not found or not in paid status -func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) { - orderInfo, err := l.svc.OrderModel.FindOneByOrderNo(ctx, orderNo) +func (l *ActivateOrderLogic) claimAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) { + var orderInfo order.Order + err := l.svc.DB.WithContext(ctx).Model(&order.Order{}).Where("order_no = ?", orderNo).First(&orderInfo).Error if err != nil { logger.WithContext(ctx).Error("Find order failed", logger.Field("error", err.Error()), @@ -165,7 +169,33 @@ func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo st return nil, ErrInvalidOrderStatus } - return orderInfo, nil + result := l.svc.DB.WithContext(ctx). + Model(&order.Order{}). + Where("order_no = ? AND status = ?", orderNo, OrderStatusPaid). + Update("status", OrderStatusClaimed) + if result.Error != nil { + return nil, result.Error + } + if result.RowsAffected == 0 { + logger.WithContext(ctx).Info("Order already claimed by another worker, skip processing", + logger.Field("order_no", orderNo), + ) + return nil, nil + } + orderInfo.Status = OrderStatusClaimed + return &orderInfo, nil +} + +func (l *ActivateOrderLogic) releaseClaim(ctx context.Context, orderNo string) { + if err := l.svc.DB.WithContext(ctx). + Model(&order.Order{}). + Where("order_no = ? AND status = ?", orderNo, OrderStatusClaimed). + Update("status", OrderStatusPaid).Error; err != nil { + logger.WithContext(ctx).Error("Release order claim failed", + logger.Field("error", err.Error()), + logger.Field("order_no", orderNo), + ) + } } // processOrderByType routes order processing based on the order type @@ -274,20 +304,24 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O ) } - // 如果没有合并已购订阅,再尝试合并赠送订阅(order_id=0) - if userSub == nil { - giftSub, giftErr := l.findGiftSubscription(ctx, singleModeUserId, orderInfo.SubscribeId) - if giftErr == nil && giftSub != nil { - // 在赠送订阅上延长时间,保持 token 不变 - userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub) - if err != nil { - logger.WithContext(ctx).Error("Extend gift subscription failed", - logger.Field("error", err.Error()), - logger.Field("gift_subscribe_id", giftSub.Id), - ) - // 合并失败时回退到创建新订阅 - userSub = nil - } + } + + // 如果没有合并已购订阅,再尝试合并赠送订阅(order_id=0)。 + // 全局单订阅口径下,非 SingleModel 也不能让试用订阅和付费订阅并存。 + if userSub == nil { + effectiveOwner := orderInfo.UserId + if orderInfo.SubscriptionUserId > 0 { + effectiveOwner = orderInfo.SubscriptionUserId + } + giftSub, giftErr := l.findGiftSubscription(ctx, effectiveOwner, orderInfo.SubscribeId) + if giftErr == nil && giftSub != nil { + userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub) + if err != nil { + logger.WithContext(ctx).Error("Extend gift subscription failed", + logger.Field("error", err.Error()), + logger.Field("gift_subscribe_id", giftSub.Id), + ) + userSub = nil } } } @@ -302,8 +336,10 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O } var existingSub user.Subscribe if findErr := l.svc.DB.Model(&user.Subscribe{}). - Where("user_id IN ? AND subscribe_id = ?", candidateUserIds, orderInfo.SubscribeId). + Where("user_id IN ? AND token != ''", candidateUserIds). Order("expire_time DESC"). + Order("updated_at DESC"). + Order("id DESC"). First(&existingSub).Error; findErr == nil { // 家庭组场景:订阅 owner 可能变更(如成员注册的试用 → 被家主收归), // 续期前把 user_id 校正为当前订单的 SubscriptionUserId @@ -514,7 +550,7 @@ func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderIn // Check quota limit before creating subscription (final safeguard) if sub.Quota > 0 { var count int64 - if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ? AND subscribe_id = ?", orderInfo.UserId, orderInfo.SubscribeId).Count(&count).Error; err != nil { + if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ?", subscriptionUserId).Count(&count).Error; err != nil { logger.WithContext(ctx).Error("Count user subscribe failed", logger.Field("error", err.Error())) return nil, err } @@ -602,7 +638,7 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use if !l.shouldProcessCommission(userInfo, orderInfo.IsNew) { // 普通用户路径(佣金比例=0):只有首单才双方赠N天 if orderInfo.IsNew { - l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo.OrderNo) + l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo) } return } @@ -692,16 +728,18 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use // 有佣金路径:邀请人拿佣金,被邀请用户(首单)拿天数 if orderInfo.IsNew { - _ = l.grantGiftDays(ctx, userInfo, int(l.svc.Config.Invite.GiftDays), orderInfo.OrderNo, "邀请赠送") + giftTarget := l.resolveGiftTargetUser(ctx, userInfo, orderInfo.SubscriptionUserId) + _ = l.grantGiftDays(ctx, giftTarget, int(l.svc.Config.Invite.GiftDays), orderInfo.OrderNo, "邀请赠送") } } -func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderNo string) { +func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderInfo *order.Order) { giftDays := l.svc.Config.Invite.GiftDays - if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 { + if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 || orderInfo == nil { return } - _ = l.grantGiftDays(ctx, referee, int(giftDays), orderNo, "邀请赠送") + refereeTarget := l.resolveGiftTargetUser(ctx, referee, orderInfo.SubscriptionUserId) + _ = l.grantGiftDays(ctx, refereeTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送") if referee.RefererId == 0 { return } @@ -709,7 +747,32 @@ func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, ref if err != nil || referer == nil { return } - _ = l.grantGiftDays(ctx, referer, int(giftDays), orderNo, "邀请赠送") + refererTarget := l.resolveGiftTargetUser(ctx, referer, 0) + _ = l.grantGiftDays(ctx, refererTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送") +} + +func (l *ActivateOrderLogic) resolveGiftTargetUser(ctx context.Context, source *user.User, forcedOwnerID int64) *user.User { + if source == nil || source.Id == 0 { + return source + } + targetID := source.Id + if forcedOwnerID > 0 { + targetID = forcedOwnerID + } else if entitlement, err := commonLogic.ResolveEntitlementUser(ctx, l.svc.DB, source.Id); err == nil && entitlement != nil && entitlement.EffectiveUserID > 0 { + targetID = entitlement.EffectiveUserID + } + if targetID == source.Id { + return source + } + target, err := l.svc.UserModel.FindOne(ctx, targetID) + if err != nil || target == nil { + logger.WithContext(ctx).Error("Resolve gift target owner failed", + logger.Field("source_user_id", source.Id), + logger.Field("target_user_id", targetID), + ) + return source + } + return target } func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, days int, orderNo string, remark string) error { @@ -736,7 +799,22 @@ func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, da activeSubscribe, err := l.svc.UserModel.FindActiveSubscribe(ctx, u.Id) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil + giftLog := &log.Gift{ + Type: log.GiftTypeIncrease, + OrderNo: orderNo, + SubscribeId: 0, + Amount: int64(days), + Balance: u.Balance, + Remark: remark + " skipped: no active subscription", + Timestamp: time.Now().UnixMilli(), + } + content, _ := giftLog.Marshal() + return l.svc.LogModel.Insert(ctx, &log.SystemLog{ + Type: log.TypeGift.Uint8(), + Date: time.Now().Format("2006-01-02"), + ObjectID: u.Id, + Content: string(content), + }) } return err } diff --git a/queue/logic/order/activateOrderLogic_invite_test.go b/queue/logic/order/activateOrderLogic_invite_test.go new file mode 100644 index 0000000..947f277 --- /dev/null +++ b/queue/logic/order/activateOrderLogic_invite_test.go @@ -0,0 +1,436 @@ +package orderLogic + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/perfect-panel/server/internal/config" + userLogic "github.com/perfect-panel/server/internal/logic/public/user" + modelLog "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/order" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + "github.com/redis/go-redis/v9" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +// 普通用户 + 首单 → 双方赠N天 +func TestHandleCommission_GrantGiftDaysWhenCommissionDisabled_FirstOrder(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 0, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referee.RefererId = referer.Id + + baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire) + refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire) + + logic.handleCommission(context.Background(), referee, &order.Order{ + OrderNo: "ORD-GIFT-001", + Type: OrderTypeSubscribe, + IsNew: true, // 首单 + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2) + assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 2) + + var giftCount int64 + if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil { + t.Fatalf("count gift logs failed: %v", err) + } + if giftCount != 2 { + t.Fatalf("expected 2 gift logs, got %d", giftCount) + } +} + +// 普通用户 + 非首单 → 不赠送 +func TestHandleCommission_NoGiftDaysWhenCommissionDisabled_NotFirstOrder(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 0, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referee.RefererId = referer.Id + + baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire) + refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire) + + logic.handleCommission(context.Background(), referee, &order.Order{ + OrderNo: "ORD-GIFT-002", + Type: OrderTypeSubscribe, + IsNew: false, // 非首单 + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + // 到期时间不应延长 + assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0) + assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0) + + var giftCount int64 + if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil { + t.Fatalf("count gift logs failed: %v", err) + } + if giftCount != 0 { + t.Fatalf("expected 0 gift logs for non-first order, got %d", giftCount) + } +} + +// 渠道 + 首单 → 被邀请人赠N天 + 邀请人获佣金 +func TestHandleCommission_GiftDaysAndCommissionWhenChannelFirstOrder(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 10, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referee.RefererId = referer.Id + + baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire) + + logic.handleCommission(context.Background(), referee, &order.Order{ + OrderNo: "ORD-COMM-001", + Type: OrderTypeSubscribe, + IsNew: true, // 首单 + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + // 被邀请人(首单)应获得赠送天数 + assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2) + + // 邀请人应获得佣金 + var refererAfter user.User + if err := db.First(&refererAfter, referer.Id).Error; err != nil { + t.Fatalf("query referer failed: %v", err) + } + if refererAfter.Commission != 10 { + t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission) + } + + var giftCount int64 + if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil { + t.Fatalf("count gift logs failed: %v", err) + } + if giftCount != 1 { + t.Fatalf("expected 1 gift log for referee on first order with commission, got %d", giftCount) + } +} + +// 渠道 + 非首单 → 只给邀请人佣金,不赠天 +func TestHandleCommission_OnlyCommissionWhenChannelNotFirstOrder(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 10, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referee.RefererId = referer.Id + + baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire) + + logic.handleCommission(context.Background(), referee, &order.Order{ + OrderNo: "ORD-COMM-002", + Type: OrderTypeSubscribe, + IsNew: false, // 非首单 + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + // 被邀请人不应获得赠送天数 + assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0) + + // 邀请人应获得佣金 + var refererAfter user.User + if err := db.First(&refererAfter, referer.Id).Error; err != nil { + t.Fatalf("query referer failed: %v", err) + } + if refererAfter.Commission != 10 { + t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission) + } + + var giftCount int64 + if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil { + t.Fatalf("count gift logs failed: %v", err) + } + if giftCount != 0 { + t.Fatalf("expected 0 gift logs when channel non-first order, got %d", giftCount) + } +} + +func TestHandleCommission_NoGiftDaysWhenNoInviteRelation(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 0, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + // 没有邀请人的独立用户 + loneUser := seedUser(t, db, 0, false) + // RefererId == 0,无邀请关系 + + baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second) + loneSub := seedActiveSubscribe(t, db, loneUser.Id, baseExpire) + + logic.handleCommission(context.Background(), loneUser, &order.Order{ + OrderNo: "ORD-LONE-001", + Type: OrderTypeSubscribe, + IsNew: true, + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + // 订阅到期时间不应该被延长 + var subAfter user.Subscribe + if err := db.First(&subAfter, loneSub.Id).Error; err != nil { + t.Fatalf("query subscribe failed: %v", err) + } + if !subAfter.ExpireTime.Equal(baseExpire) { + t.Fatalf("expected no gift days for user without inviter, before=%v after=%v", baseExpire, subAfter.ExpireTime) + } + + // 不应产生赠天日志 + var giftCount int64 + if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil { + t.Fatalf("count gift logs failed: %v", err) + } + if giftCount != 0 { + t.Fatalf("expected 0 gift logs for user without inviter, got %d", giftCount) + } +} + +// 先绑码后首单 → 双方赠N天 +func TestInviteFlow_BindThenFirstOrder_GrantGiftDays(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 0, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referer.ReferCode = fmt.Sprintf("REF-%d", referer.Id) + if err := db.Model(&user.User{}).Where("id = ?", referer.Id).Update("refer_code", referer.ReferCode).Error; err != nil { + t.Fatalf("update referer code failed: %v", err) + } + + refereeBaseExpire := time.Now().Add(48 * time.Hour).Truncate(time.Second) + refererBaseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, refereeBaseExpire) + refererSub := seedActiveSubscribe(t, db, referer.Id, refererBaseExpire) + + ctx := context.WithValue(context.Background(), constant.CtxKeyUser, referee) + bindLogic := userLogic.NewBindInviteCodeLogic(ctx, logic.svc) + if err := bindLogic.BindInviteCode(&types.BindInviteCodeRequest{InviteCode: referer.ReferCode}); err != nil { + t.Fatalf("bind invite code failed: %v", err) + } + + var refereeAfterBind user.User + if err := db.First(&refereeAfterBind, referee.Id).Error; err != nil { + t.Fatalf("query referee after bind failed: %v", err) + } + if refereeAfterBind.RefererId != referer.Id { + t.Fatalf("bind invite failed, expected referer_id=%d got=%d", referer.Id, refereeAfterBind.RefererId) + } + + // 首单 IsNew=true → 双方赠N天 + logic.handleCommission(context.Background(), &refereeAfterBind, &order.Order{ + OrderNo: "ORD-FLOW-001", + Type: OrderTypeSubscribe, + IsNew: true, + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + assertExpireIncreasedByDays(t, db, refereeSub.Id, refereeBaseExpire, 2) + assertExpireIncreasedByDays(t, db, refererSub.Id, refererBaseExpire, 2) +} + +// 先买订单后绑码再续费 → 不赠送(IsNew=false) +func TestInviteFlow_OrderThenBind_NoGiftDays(t *testing.T) { + logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{ + ReferralPercentage: 0, + OnlyFirstPurchase: false, + GiftDays: 2, + }) + defer cleanup() + + referee := seedUser(t, db, 0, false) + referer := seedUser(t, db, 0, false) + referee.RefererId = referer.Id + + baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second) + refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire) + refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire) + + // 先前已有订单,IsNew=false(模拟先买订单后绑码的场景) + logic.handleCommission(context.Background(), referee, &order.Order{ + OrderNo: "ORD-FLOW-002", + Type: OrderTypeSubscribe, + IsNew: false, // 已有历史订单 + Amount: 100, + FeeAmount: 0, + CreatedAt: time.Now(), + }) + + assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0) + assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0) +} + +func setupInviteTestLogic(t *testing.T, inviteCfg config.InviteConfig) (*ActivateOrderLogic, *gorm.DB, func()) { + t.Helper() + + mysqlAddr := getenvDefault("TEST_MYSQL_ADDR", "127.0.0.1:3306") + mysqlUser := getenvDefault("TEST_MYSQL_USER", "root") + mysqlPassword := getenvDefault("TEST_MYSQL_PASSWORD", "rootpassword") + + adminDSN := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&parseTime=true&loc=Local&multiStatements=true", mysqlUser, mysqlPassword, mysqlAddr) + adminDB, err := gorm.Open(mysql.Open(adminDSN), &gorm.Config{}) + if err != nil { + t.Fatalf("open mysql admin connection failed: %v", err) + } + + dbName := fmt.Sprintf("ppanel_test_invite_%d", time.Now().UnixNano()) + if err := adminDB.Exec(fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci", dbName)).Error; err != nil { + t.Fatalf("create test database failed: %v", err) + } + + testDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", mysqlUser, mysqlPassword, mysqlAddr, dbName) + db, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{}) + if err != nil { + t.Fatalf("open test database failed: %v", err) + } + + if err := db.AutoMigrate(&user.User{}, &user.Device{}, &user.AuthMethods{}, &user.Subscribe{}, &modelLog.SystemLog{}); err != nil { + t.Fatalf("auto migrate failed: %v", err) + } + + redisAddr := getenvDefault("TEST_REDIS_ADDR", "127.0.0.1:6379") + redisPassword := getenvDefault("TEST_REDIS_PASSWORD", "") + rdb := redis.NewClient(&redis.Options{ + Addr: redisAddr, + Password: redisPassword, + DB: 0, + }) + if err := rdb.Ping(context.Background()).Err(); err != nil { + t.Fatalf("connect redis failed: %v", err) + } + _ = rdb.FlushDB(context.Background()).Err() + + svcCtx := &svc.ServiceContext{ + DB: db, + Redis: rdb, + UserModel: user.NewModel(db, rdb), + LogModel: modelLog.NewModel(db), + Config: config.Config{ + Invite: inviteCfg, + }, + } + + return NewActivateOrderLogic(svcCtx), db, func() { + _ = rdb.Close() + sqlDB, _ := db.DB() + if sqlDB != nil { + _ = sqlDB.Close() + } + _ = adminDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)).Error + } +} + +func seedUser(t *testing.T, db *gorm.DB, referralPercentage uint8, onlyFirstPurchase bool) *user.User { + t.Helper() + u := &user.User{ + Password: "pwd", + Algo: "default", + ReferralPercentage: referralPercentage, + OnlyFirstPurchase: boolPtr(onlyFirstPurchase), + Enable: boolPtr(true), + IsAdmin: boolPtr(false), + EnableBalanceNotify: boolPtr(false), + EnableLoginNotify: boolPtr(false), + EnableSubscribeNotify: boolPtr(false), + EnableTradeNotify: boolPtr(false), + } + if err := db.Create(u).Error; err != nil { + t.Fatalf("seed user failed: %v", err) + } + return u +} + +func seedActiveSubscribe(t *testing.T, db *gorm.DB, userID int64, expireAt time.Time) *user.Subscribe { + t.Helper() + sub := &user.Subscribe{ + UserId: userID, + OrderId: 1, + SubscribeId: 1, + StartTime: time.Now().Add(-24 * time.Hour), + ExpireTime: expireAt, + Traffic: 1024, + Token: fmt.Sprintf("token-%d-%d", userID, time.Now().UnixNano()), + UUID: fmt.Sprintf("uuid-%d-%d", userID, time.Now().UnixNano()), + Status: 1, + } + if err := db.Create(sub).Error; err != nil { + t.Fatalf("seed subscribe failed: %v", err) + } + return sub +} + +func assertExpireIncreasedByDays(t *testing.T, db *gorm.DB, subscribeID int64, before time.Time, days int) { + t.Helper() + var after user.Subscribe + if err := db.First(&after, subscribeID).Error; err != nil { + t.Fatalf("query subscribe failed: %v", err) + } + expected := before.Add(time.Duration(days) * 24 * time.Hour) + if !after.ExpireTime.Equal(expected) { + t.Fatalf("expire time mismatch, expected=%v got=%v", expected, after.ExpireTime) + } +} + +func boolPtr(v bool) *bool { + return &v +} + +func getenvDefault(key, fallback string) string { + v := os.Getenv(key) + if v == "" { + return fallback + } + return v +} diff --git a/queue/logic/order/activateOrderLogic_newUserEligibility_test.go b/queue/logic/order/activateOrderLogic_newUserEligibility_test.go new file mode 100644 index 0000000..e8ff523 --- /dev/null +++ b/queue/logic/order/activateOrderLogic_newUserEligibility_test.go @@ -0,0 +1,199 @@ +package orderLogic + +import ( + "context" + "testing" + "time" + + modelOrder "github.com/perfect-panel/server/internal/model/order" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func setupActivationEligibilityDB(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + + sqls := []string{ + `CREATE TABLE IF NOT EXISTS "user" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + password VARCHAR(100) NOT NULL DEFAULT '', + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "user_device" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL DEFAULT 0, + identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE, + created_at DATETIME, + updated_at DATETIME + )`, + `CREATE TABLE IF NOT EXISTS "user_family" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + owner_user_id INTEGER NOT NULL DEFAULT 0, + status TINYINT NOT NULL DEFAULT 1, + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "user_family_member" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + family_id INTEGER NOT NULL DEFAULT 0, + user_id INTEGER NOT NULL DEFAULT 0, + role TINYINT NOT NULL DEFAULT 0, + status TINYINT NOT NULL DEFAULT 0, + join_source VARCHAR(32) NOT NULL DEFAULT '', + deleted_at DATETIME DEFAULT NULL + )`, + `CREATE TABLE IF NOT EXISTS "order" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL DEFAULT 0, + order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE, + type TINYINT NOT NULL DEFAULT 1, + status TINYINT NOT NULL DEFAULT 1, + subscribe_id INTEGER NOT NULL DEFAULT 0, + quantity INTEGER NOT NULL DEFAULT 1, + created_at DATETIME, + updated_at DATETIME + )`, + } + for _, sql := range sqls { + require.NoError(t, db.Exec(sql).Error) + } + + return db +} + +func insertActivationUser(t *testing.T, db *gorm.DB, userID int64, createdAt time.Time) { + t.Helper() + require.NoError(t, db.Exec( + `INSERT INTO "user" (id, created_at, updated_at) VALUES (?, ?, datetime('now'))`, + userID, + createdAt.UTC().Format("2006-01-02 15:04:05"), + ).Error) +} + +func insertActivationDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) { + t.Helper() + require.NoError(t, db.Exec( + `INSERT INTO "user_device" (user_id, identifier, created_at, updated_at) VALUES (?, ?, ?, datetime('now'))`, + userID, + identifier, + createdAt.UTC().Format("2006-01-02 15:04:05"), + ).Error) +} + +func insertActivationFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) { + t.Helper() + require.NoError(t, db.Exec( + `INSERT INTO "user_family" (id, owner_user_id, status) VALUES (?, ?, 1)`, + familyID, + ownerUserID, + ).Error) +} + +func insertActivationFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) { + t.Helper() + require.NoError(t, db.Exec( + `INSERT INTO "user_family_member" (family_id, user_id, role, status, join_source) VALUES (?, ?, ?, ?, ?)`, + familyID, + userID, + role, + status, + joinSource, + ).Error) +} + +func insertActivationOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) { + t.Helper() + require.NoError(t, db.Exec( + `INSERT INTO "order" (user_id, order_no, type, status, subscribe_id, quantity, created_at, updated_at) + VALUES (?, ?, 1, ?, ?, 1, datetime('now'), datetime('now'))`, + userID, + orderNo, + status, + subscribeID, + ).Error) +} + +func TestValidateNewUserOnlyEligibilityAtActivation_UsesEarliestBoundDeviceTime(t *testing.T) { + db := setupActivationEligibilityDB(t) + + const ( + ownerUserID = int64(1) + memberUserID = int64(2) + familyID = int64(10) + subscribeID = int64(100) + ) + + insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertActivationUser(t, db, memberUserID, time.Now().Add(-72*time.Hour)) + insertActivationDevice(t, db, memberUserID, "activation-old-device", time.Now().Add(-72*time.Hour)) + insertActivationFamily(t, db, familyID, ownerUserID) + insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification") + + err := validateNewUserOnlyEligibilityAtActivation( + context.Background(), + db, + &modelOrder.Order{ + UserId: ownerUserID, + OrderNo: "activation-check-old-device", + Type: OrderTypeSubscribe, + Quantity: 1, + SubscribeId: subscribeID, + }, + &subscribe.Subscribe{ + Id: subscribeID, + Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`, + }, + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "is not a new user") +} + +func TestValidateNewUserOnlyEligibilityAtActivation_SharesHistoryAcrossBoundScope(t *testing.T) { + db := setupActivationEligibilityDB(t) + + const ( + ownerUserID = int64(11) + memberUserID = int64(12) + familyID = int64(20) + subscribeID = int64(200) + ) + + insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour)) + insertActivationUser(t, db, memberUserID, time.Now().Add(-2*time.Hour)) + insertActivationDevice(t, db, memberUserID, "activation-shared-device", time.Now().Add(-2*time.Hour)) + insertActivationFamily(t, db, familyID, ownerUserID) + insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init") + insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification") + insertActivationOrder(t, db, "previous-finished-order", memberUserID, subscribeID, OrderStatusFinished) + + err := validateNewUserOnlyEligibilityAtActivation( + context.Background(), + db, + &modelOrder.Order{ + UserId: ownerUserID, + OrderNo: "current-paid-order", + Type: OrderTypeSubscribe, + Quantity: 1, + SubscribeId: subscribeID, + }, + &subscribe.Subscribe{ + Id: subscribeID, + Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`, + }, + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "already activated") +} diff --git a/queue/logic/order/newUserEligibility.go b/queue/logic/order/newUserEligibility.go index 066ef3c..76e6b61 100644 --- a/queue/logic/order/newUserEligibility.go +++ b/queue/logic/order/newUserEligibility.go @@ -43,7 +43,7 @@ func validateNewUserOnlyEligibilityAtActivation( ctx, db, eligibility.ScopeUserIDs, - orderInfo.SubscribeId, + 0, []int64{OrderStatusFinished}, orderInfo.OrderNo, ) @@ -51,7 +51,7 @@ func validateNewUserOnlyEligibilityAtActivation( return fmt.Errorf("new user only: check history error: %w", err) } if historyCount >= 1 { - return fmt.Errorf("new user only: user %d already activated subscribe %d", orderInfo.UserId, orderInfo.SubscribeId) + return fmt.Errorf("new user only: user %d already activated an order", orderInfo.UserId) } return nil diff --git a/scripts/diagnose_business_bugs.go b/scripts/diagnose_business_bugs.go new file mode 100644 index 0000000..2455e6e --- /dev/null +++ b/scripts/diagnose_business_bugs.go @@ -0,0 +1,197 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "os" + "strings" + + _ "github.com/go-sql-driver/mysql" +) + +func main() { + dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN") + flag.Parse() + if strings.TrimSpace(*dsn) == "" { + log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN") + } + + db, err := sql.Open("mysql", *dsn) + if err != nil { + log.Fatal(err) + } + defer db.Close() + if err = db.Ping(); err != nil { + log.Fatal(err) + } + + mustPrintRows(db, "db/info", ` +SELECT NOW() AS db_now, + (SELECT COUNT(*) FROM user) AS users, + (SELECT COUNT(*) FROM user_subscribe) AS user_subscribes, + (SELECT COUNT(*) FROM `+"`order`"+`) AS orders`) + + mustPrintRows(db, "bug1/confusable-email-trials", ` +SELECT uam.user_id, + uam.auth_identifier, + us.id AS user_subscribe_id, + us.order_id, + us.status, + us.expire_time, + us.created_at +FROM user_auth_methods uam +JOIN user_subscribe us ON us.user_id = uam.user_id +WHERE uam.auth_type = 'email' + AND us.order_id = 0 + AND ( + uam.auth_identifier LIKE '%@gmaial.com' + OR uam.auth_identifier LIKE '%@gmial.com' + OR uam.auth_identifier LIKE '%@gamil.com' + OR uam.auth_identifier LIKE '%+%@%' + OR uam.auth_identifier REGEXP '^[^@]*\\.[^@]*@gmail\\.com$' + ) +ORDER BY us.created_at DESC +LIMIT 50`) + + mustPrintRows(db, "bug2-visible-duplicate-subscriptions", ` +SELECT scoped.owner_user_id, + COUNT(*) AS visible_subscribe_count, + GROUP_CONCAT(scoped.user_subscribe_id ORDER BY scoped.expire_time DESC) AS user_subscribe_ids, + GROUP_CONCAT(scoped.subscribe_id ORDER BY scoped.expire_time DESC) AS subscribe_ids, + MAX(scoped.expire_time) AS max_expire_time +FROM ( + SELECT us.id AS user_subscribe_id, + us.user_id, + COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id, + us.subscribe_id, + us.status, + us.expire_time, + us.finished_at + FROM user_subscribe us + LEFT JOIN user_family_member ufm + ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1 + LEFT JOIN user_family uf + ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1 + WHERE us.token <> '' + AND us.status IN (0,1,2,3,4) + AND (us.expire_time > NOW() + OR us.finished_at >= DATE_SUB(NOW(), INTERVAL 7 DAY) + OR us.expire_time = FROM_UNIXTIME(0)) +) scoped +GROUP BY scoped.owner_user_id +HAVING COUNT(*) > 1 +ORDER BY visible_subscribe_count DESC, owner_user_id +LIMIT 50`) + + mustPrintRows(db, "bug2-order-subscription-owner-mismatch", ` +SELECT us.id AS user_subscribe_id, + us.user_id AS subscribe_user_id, + o.id AS order_id, + o.order_no, + o.user_id AS order_user_id, + o.subscription_user_id, + us.status, + us.expire_time, + us.created_at AS subscribe_created_at, + o.created_at AS order_created_at +FROM user_subscribe us +JOIN `+"`order`"+` o ON o.id = us.order_id +WHERE us.user_id <> o.subscription_user_id + AND us.token <> '' + AND us.status IN (0,1,2,3,4) +ORDER BY us.updated_at DESC +LIMIT 50`) + + mustPrintRows(db, "bug3-invite-first-orders-missing-gift-days", ` +SELECT first_orders.user_id AS referee_id, + referee.referer_id, + first_orders.id AS order_id, + first_orders.order_no, + first_orders.amount, + first_orders.created_at, + referer.referral_percentage AS referer_referral_percentage, + (SELECT COUNT(*) FROM system_logs sl + WHERE sl.type = 34 + AND sl.object_id = first_orders.user_id + AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referee_gift_logs, + (SELECT COUNT(*) FROM system_logs sl + WHERE sl.type = 34 + AND sl.object_id = referee.referer_id + AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referer_gift_logs +FROM ( + SELECT o.* + FROM `+"`order`"+` o + JOIN ( + SELECT user_id, MIN(id) AS first_order_id + FROM `+"`order`"+` + WHERE type IN (1,2) + AND status IN (2,5) + AND amount > 0 + GROUP BY user_id + ) fo ON fo.first_order_id = o.id +) first_orders +JOIN user referee ON referee.id = first_orders.user_id AND referee.referer_id <> 0 +JOIN user referer ON referer.id = referee.referer_id +WHERE ( + referer.referral_percentage = 0 + AND ( + (SELECT COUNT(*) FROM system_logs sl + WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0 + OR + (SELECT COUNT(*) FROM system_logs sl + WHERE sl.type = 34 AND sl.object_id = referee.referer_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0 + ) + ) + OR ( + referer.referral_percentage > 0 + AND (SELECT COUNT(*) FROM system_logs sl + WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0 + ) +ORDER BY first_orders.created_at DESC +LIMIT 50`) +} + +func mustPrintRows(db *sql.DB, title string, query string) { + fmt.Printf("\n== %s ==\n", title) + rows, err := db.Query(query) + if err != nil { + log.Fatalf("%s: %v", title, err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + log.Fatalf("%s columns: %v", title, err) + } + fmt.Println(strings.Join(cols, "\t")) + + values := make([]sql.NullString, len(cols)) + args := make([]any, len(cols)) + for i := range values { + args[i] = &values[i] + } + count := 0 + for rows.Next() { + if err := rows.Scan(args...); err != nil { + log.Fatalf("%s scan: %v", title, err) + } + out := make([]string, len(cols)) + for i, value := range values { + if value.Valid { + out[i] = value.String + } else { + out[i] = "NULL" + } + } + fmt.Println(strings.Join(out, "\t")) + count++ + } + if err := rows.Err(); err != nil { + log.Fatalf("%s rows: %v", title, err) + } + if count == 0 { + fmt.Println("(none)") + } +} diff --git a/scripts/merge_duplicate_subscriptions.go b/scripts/merge_duplicate_subscriptions.go new file mode 100644 index 0000000..400bb21 --- /dev/null +++ b/scripts/merge_duplicate_subscriptions.go @@ -0,0 +1,204 @@ +package main + +import ( + "database/sql" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "strings" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +type duplicateGroup struct { + OwnerUserID int64 `json:"owner_user_id"` + Count int64 `json:"count"` +} + +type subscriptionRow struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + OrderID int64 `json:"order_id"` + SubscribeID int64 `json:"subscribe_id"` + ExpireTime time.Time `json:"expire_time"` + Traffic int64 `json:"traffic"` + Download int64 `json:"download"` + Upload int64 `json:"upload"` + ExpiredDownload int64 `json:"expired_download"` + ExpiredUpload int64 `json:"expired_upload"` + Status uint8 `json:"status"` + UpdatedAt time.Time `json:"updated_at"` +} + +type mergePlan struct { + OwnerUserID int64 `json:"owner_user_id"` + Keep subscriptionRow `json:"keep"` + Merge []subscriptionRow `json:"merge"` +} + +func main() { + dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN") + execute := flag.Bool("execute", false, "apply changes; default is dry-run") + flag.Parse() + + if strings.TrimSpace(*dsn) == "" { + log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN") + } + + db, err := sql.Open("mysql", *dsn) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + groups, err := findDuplicateGroups(db) + if err != nil { + log.Fatal(err) + } + + plans := make([]mergePlan, 0, len(groups)) + for _, group := range groups { + plan, err := buildPlan(db, group.OwnerUserID) + if err != nil { + log.Fatal(err) + } + if len(plan.Merge) > 0 { + plans = append(plans, plan) + } + } + + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + if err := enc.Encode(plans); err != nil { + log.Fatal(err) + } + + if !*execute { + fmt.Fprintf(os.Stderr, "dry-run only: %d duplicate owner groups found\n", len(plans)) + return + } + + for _, plan := range plans { + if err := applyPlan(db, plan); err != nil { + log.Fatal(err) + } + } + fmt.Fprintf(os.Stderr, "merged %d duplicate owner groups\n", len(plans)) +} + +func findDuplicateGroups(db *sql.DB) ([]duplicateGroup, error) { + rows, err := db.Query(` +SELECT owner_user_id, COUNT(1) AS cnt +FROM ( + SELECT us.id, + COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id + FROM user_subscribe us + LEFT JOIN user_family_member ufm + ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1 + LEFT JOIN user_family uf + ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1 + WHERE us.token <> '' + AND us.status IN (0, 1, 2, 3, 4) +) scoped +GROUP BY owner_user_id +HAVING COUNT(1) > 1 +ORDER BY owner_user_id`) + if err != nil { + return nil, err + } + defer rows.Close() + + var groups []duplicateGroup + for rows.Next() { + var g duplicateGroup + if err := rows.Scan(&g.OwnerUserID, &g.Count); err != nil { + return nil, err + } + groups = append(groups, g) + } + return groups, rows.Err() +} + +func buildPlan(db *sql.DB, ownerUserID int64) (mergePlan, error) { + rows, err := db.Query(` +SELECT us.id, us.user_id, us.order_id, us.subscribe_id, us.expire_time, us.traffic, + us.download, us.upload, us.expired_download, us.expired_upload, us.status, us.updated_at +FROM user_subscribe us +LEFT JOIN user_family_member ufm + ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1 +LEFT JOIN user_family uf + ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1 +WHERE COALESCE(uf.owner_user_id, us.user_id) = ? + AND us.token <> '' + AND us.status IN (0, 1, 2, 3, 4) +ORDER BY us.expire_time DESC, us.updated_at DESC, us.id DESC`, ownerUserID) + if err != nil { + return mergePlan{}, err + } + defer rows.Close() + + var all []subscriptionRow + for rows.Next() { + var r subscriptionRow + if err := rows.Scan(&r.ID, &r.UserID, &r.OrderID, &r.SubscribeID, &r.ExpireTime, &r.Traffic, &r.Download, &r.Upload, &r.ExpiredDownload, &r.ExpiredUpload, &r.Status, &r.UpdatedAt); err != nil { + return mergePlan{}, err + } + all = append(all, r) + } + if err := rows.Err(); err != nil { + return mergePlan{}, err + } + if len(all) == 0 { + return mergePlan{OwnerUserID: ownerUserID}, nil + } + + keep := all[0] + for _, r := range all[1:] { + keep.Download += r.Download + keep.Upload += r.Upload + keep.ExpiredDownload += r.ExpiredDownload + keep.ExpiredUpload += r.ExpiredUpload + if r.Traffic > keep.Traffic { + keep.Traffic = r.Traffic + } + } + for _, r := range all { + if r.UpdatedAt.After(keep.UpdatedAt) { + keep.SubscribeID = r.SubscribeID + } + } + + return mergePlan{OwnerUserID: ownerUserID, Keep: keep, Merge: all[1:]}, nil +} + +func applyPlan(db *sql.DB, plan mergePlan) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err = tx.Exec(` +UPDATE user_subscribe +SET user_id = ?, subscribe_id = ?, traffic = ?, download = ?, upload = ?, + expired_download = ?, expired_upload = ?, status = 1, note = CONCAT(COALESCE(note, ''), ' [merged duplicate subscriptions]') +WHERE id = ?`, + plan.OwnerUserID, plan.Keep.SubscribeID, plan.Keep.Traffic, plan.Keep.Download, plan.Keep.Upload, + plan.Keep.ExpiredDownload, plan.Keep.ExpiredUpload, plan.Keep.ID); err != nil { + return err + } + + for _, r := range plan.Merge { + if _, err = tx.Exec(` +UPDATE user_subscribe +SET status = 5, note = CONCAT(COALESCE(note, ''), ' [merged into subscription #', ?, ']') +WHERE id = ?`, plan.Keep.ID, r.ID); err != nil { + return err + } + } + + return tx.Commit() +} diff --git a/scripts/replay_business_bugs.go b/scripts/replay_business_bugs.go new file mode 100644 index 0000000..ecdb390 --- /dev/null +++ b/scripts/replay_business_bugs.go @@ -0,0 +1,787 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "os" + "os/exec" + "strings" + "time" + + "github.com/hibiken/asynq" + "github.com/perfect-panel/server/internal/config" + authlogic "github.com/perfect-panel/server/internal/logic/auth" + modelLog "github.com/perfect-panel/server/internal/model/log" + modelOrder "github.com/perfect-panel/server/internal/model/order" + modelSubscribe "github.com/perfect-panel/server/internal/model/subscribe" + modelUser "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/conf" + "github.com/perfect-panel/server/pkg/orm" + "github.com/perfect-panel/server/pkg/uuidx" + orderLogic "github.com/perfect-panel/server/queue/logic/order" + queueTypes "github.com/perfect-panel/server/queue/types" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +const marker = "codex-replay-business-bugs" + +func main() { + var ( + configPath = flag.String("config", "etc/ppanel.yaml", "ppanel config path for test server DB/Redis") + dsn = flag.String("dsn", "", "optional MySQL DSN override: user:pass@tcp(host:3306)/db?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai") + writeDB = flag.Bool("write-db", false, "create isolated test rows and execute activation replay against the configured test DB") + force = flag.Bool("force", false, "allow -write-db even when the config name does not clearly look like test/dev/staging") + keep = flag.Bool("keep", false, "keep replay rows for manual inspection") + cleanupOnly = flag.Bool("cleanup-only", false, "delete leftover replay rows by marker and exit") + skipCodeTests = flag.Bool("skip-code-tests", false, "skip go test checks") + ) + flag.Parse() + + ctx := context.Background() + started := time.Now() + fmt.Println("== replay business bug tests ==") + fmt.Printf("marker: %s\n", marker) + + if !*skipCodeTests { + must(runCodeTests()) + } + + cfg := loadConfig(*configPath, *dsn) + runEmailTrialAssertions(cfg) + + if *cleanupOnly { + env := mustNewReplayEnv(ctx, cfg) + env.cleanupByMarker(ctx) + return + } + + if !*writeDB { + fmt.Println("\nDB replay skipped. Add -write-db to create isolated rows in the TEST database and run activation flows.") + fmt.Println("Example:") + fmt.Printf(" go run scripts/replay_business_bugs.go -config %s -write-db\n", *configPath) + return + } + + if looksLikeProduction(cfg) && !*force { + fatalf("refusing to write DB because config does not look like a test environment: db=%s host=%s; add -force only on the test server", cfg.MySQL.Dbname, cfg.Site.Host) + } + + env := mustNewReplayEnv(ctx, cfg) + if !*keep { + defer env.cleanup(ctx) + } + + must(env.replaySingleSubscription(ctx)) + must(env.replayInviteRulesMatrix(ctx)) + must(env.replayFamilyInviteGiftToOwner(ctx)) + + fmt.Printf("\nPASS all replay checks in %s\n", time.Since(started).Round(time.Millisecond)) + if *keep { + fmt.Println("Replay rows kept for inspection. Delete rows with remark/name/order_no containing:", marker) + } +} + +func runCodeTests() error { + fmt.Println("\n-- code-level tests --") + args := []string{"test", + "./internal/logic/auth", + "./internal/logic/common", + "./internal/logic/public/order", + "./queue/logic/order", + } + cmd := exec.Command("go", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("go test failed: %w", err) + } + fmt.Println("PASS code-level tests") + return nil +} + +func loadConfig(path, dsn string) config.Config { + var cfg config.Config + conf.MustLoad(path, &cfg) + if dsn != "" { + cfg.MySQL = parseDSN(dsn) + } + return cfg +} + +func parseDSN(dsn string) orm.Config { + cfg := orm.ParseDSN(dsn) + if cfg == nil { + fatalf("invalid dsn") + } + return *cfg +} + +func runEmailTrialAssertions(cfg config.Config) { + fmt.Println("\n-- bug1 email trial whitelist assertions --") + cfg.Register.EnableTrial = true + cfg.Register.EnableTrialEmailWhitelist = true + if cfg.Register.TrialEmailDomainWhitelist == "" { + cfg.Register.TrialEmailDomainWhitelist = "gmail.com,163.com" + } + + cases := []struct { + email string + want bool + }{ + {"1.2.3.4xxx@gmaial.com", false}, + {"a.b.c@gmail.com", false}, + {"user+tag@gmail.com", false}, + {"user@fake.gmail.com", false}, + {"normaluser@gmail.com", true}, + } + for _, tc := range cases { + got := authlogic.ShouldGrantTrialForEmail(cfg.Register, tc.email) + if got != tc.want { + fatalf("email trial assertion failed: email=%s got=%v want=%v", tc.email, got, tc.want) + } + fmt.Printf("PASS %-32s grant=%v\n", tc.email, got) + } +} + +type replayEnv struct { + db *gorm.DB + rds *redis.Client + cfg config.Config + svcCtx *svc.ServiceContext + ids struct { + users []int64 + subscribes []int64 + plans []int64 + orders []int64 + logs []int64 + } +} + +func mustNewReplayEnv(ctx context.Context, cfg config.Config) *replayEnv { + fmt.Println("\n-- connecting test DB/Redis --") + db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL}) + must(err) + rds := redis.NewClient(&redis.Options{ + Addr: cfg.Redis.Host, + Password: cfg.Redis.Pass, + DB: cfg.Redis.DB, + PoolSize: cfg.Redis.PoolSize, + MinIdleConns: cfg.Redis.MinIdleConns, + }) + must(rds.Ping(ctx).Err()) + + svcCtx := &svc.ServiceContext{ + DB: db, + Redis: rds, + Config: cfg, + UserModel: modelUser.NewModel(db, rds), + OrderModel: modelOrder.NewModel(db, rds), + SubscribeModel: modelSubscribe.NewModel(db, rds), + LogModel: modelLog.NewModel(db), + } + fmt.Printf("connected: mysql=%s/%s redis=%s\n", cfg.MySQL.Addr, cfg.MySQL.Dbname, cfg.Redis.Host) + return &replayEnv{db: db, rds: rds, cfg: cfg, svcCtx: svcCtx} +} + +func (e *replayEnv) replaySingleSubscription(ctx context.Context) error { + fmt.Println("\n-- bug2 replay: paid purchase must reuse existing subscription --") + planA, planB, err := e.createPlans(ctx, "bug2") + if err != nil { + return err + } + owner, err := e.createUser(ctx, "bug2-owner", 0, 0) + if err != nil { + return err + } + existing, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, time.Now().Add(7*24*time.Hour)) + if err != nil { + return err + } + order, err := e.createPaidOrder(ctx, owner.Id, owner.Id, planB.Id, true, "bug2") + if err != nil { + return err + } + + payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo}) + worker := orderLogic.NewActivateOrderLogic(e.svcCtx) + if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil { + return err + } + if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil { + return err + } + + var rows []modelUser.Subscribe + if err = e.db.WithContext(ctx). + Where("user_id = ? AND token <> '' AND status IN ?", owner.Id, []int{0, 1, 2, 3, 4}). + Order("id ASC"). + Find(&rows).Error; err != nil { + return err + } + if len(rows) != 1 { + return fmt.Errorf("bug2 failed: expected one visible subscription, got %d", len(rows)) + } + if rows[0].Id != existing.Id { + return fmt.Errorf("bug2 failed: expected original subscription id=%d to be reused, got id=%d", existing.Id, rows[0].Id) + } + if rows[0].SubscribeId != planB.Id || rows[0].OrderId != order.Id { + return fmt.Errorf("bug2 failed: reused subscription not updated, subscribe_id=%d order_id=%d", rows[0].SubscribeId, rows[0].OrderId) + } + fmt.Printf("PASS user=%d user_subscribe=%d plan %d -> %d order=%s\n", owner.Id, rows[0].Id, planA.Id, planB.Id, order.OrderNo) + return nil +} + +func (e *replayEnv) replayInviteGiftDays(ctx context.Context) error { + fmt.Println("\n-- bug3 replay: commission=0 invite should grant gift days to both users --") + giftDays := e.cfg.Invite.GiftDays + if giftDays <= 0 { + giftDays = 2 + e.svcCtx.Config.Invite.GiftDays = giftDays + } + e.svcCtx.Config.Invite.ReferralPercentage = 0 + e.svcCtx.Config.Invite.OnlyFirstPurchase = true + + planA, _, err := e.createPlans(ctx, "bug3") + if err != nil { + return err + } + referer, err := e.createUser(ctx, "bug3-referer", 0, 0) + if err != nil { + return err + } + referee, err := e.createUser(ctx, "bug3-referee", referer.Id, 0) + if err != nil { + return err + } + + baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond) + refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire) + if err != nil { + return err + } + refereeSub, err := e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, baseExpire) + if err != nil { + return err + } + order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, true, "bug3") + if err != nil { + return err + } + + payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo}) + worker := orderLogic.NewActivateOrderLogic(e.svcCtx) + if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil { + return err + } + + if err = e.waitForGiftLogs(ctx, order.OrderNo, referer.Id, referee.Id); err != nil { + return err + } + + var refererAfter, refereeAfter modelUser.Subscribe + if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil { + return err + } + if err = e.db.WithContext(ctx).First(&refereeAfter, refereeSub.Id).Error; err != nil { + return err + } + minRefererExpire := baseExpire.Add(time.Duration(giftDays) * 24 * time.Hour) + if refererAfter.ExpireTime.Before(minRefererExpire.Add(-time.Second)) { + return fmt.Errorf("bug3 failed: referer expire not increased by gift days, got=%s want>=%s", refererAfter.ExpireTime, minRefererExpire) + } + if !refereeAfter.ExpireTime.After(baseExpire) { + return fmt.Errorf("bug3 failed: referee expire did not increase, got=%s base=%s", refereeAfter.ExpireTime, baseExpire) + } + + // Idempotency: repeat the same order task and make sure gift logs are still one per user. + if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil { + return err + } + var giftCount int64 + if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}). + Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), []int64{referer.Id, referee.Id}, "%"+order.OrderNo+"%"). + Count(&giftCount).Error; err != nil { + return err + } + if giftCount != 2 { + return fmt.Errorf("bug3 failed: expected 2 gift logs after duplicate task, got %d", giftCount) + } + fmt.Printf("PASS referer=%d referee=%d order=%s gift_days=%d logs=%d\n", referer.Id, referee.Id, order.OrderNo, giftDays, giftCount) + return nil +} + +func (e *replayEnv) replayInviteRulesMatrix(ctx context.Context) error { + fmt.Println("\n-- bug3 replay matrix: invite gift/commission rules --") + giftDays := e.cfg.Invite.GiftDays + if giftDays <= 0 { + giftDays = 2 + } + e.svcCtx.Config.Invite.GiftDays = giftDays + e.svcCtx.Config.Invite.OnlyFirstPurchase = false + + planA, _, err := e.createPlans(ctx, "bug3-matrix") + if err != nil { + return err + } + + cases := []struct { + name string + hasReferer bool + globalReferralPct int64 + isNewOrder bool + wantGiftLogs int64 + wantCommissionLogs int64 + wantCommission int64 + }{ + { + name: "no invite relation first order no gift", + hasReferer: false, + isNewOrder: true, + wantGiftLogs: 0, + }, + { + name: "ordinary invite commission 0 first order gifts both", + hasReferer: true, + isNewOrder: true, + wantGiftLogs: 2, + }, + { + name: "ordinary invite commission 0 non-first order no gift", + hasReferer: true, + isNewOrder: false, + wantGiftLogs: 0, + }, + { + name: "channel commission positive first order gifts referee only", + hasReferer: true, + globalReferralPct: 10, + isNewOrder: true, + wantGiftLogs: 1, + wantCommissionLogs: 1, + wantCommission: 59, + }, + { + name: "channel commission positive non-first order commission only", + hasReferer: true, + globalReferralPct: 10, + isNewOrder: false, + wantGiftLogs: 0, + wantCommissionLogs: 1, + wantCommission: 59, + }, + } + + for idx, tc := range cases { + e.svcCtx.Config.Invite.ReferralPercentage = tc.globalReferralPct + scope := fmt.Sprintf("bug3-rule-%d", idx+1) + var referer *modelUser.User + if tc.hasReferer { + referer, err = e.createUser(ctx, scope+"-referer", 0, 0) + if err != nil { + return err + } + if _, err = e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil { + return err + } + } + var refererID int64 + if referer != nil { + refererID = referer.Id + } + referee, err := e.createUser(ctx, scope+"-referee", refererID, 0) + if err != nil { + return err + } + if _, err = e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil { + return err + } + order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, tc.isNewOrder, scope) + if err != nil { + return err + } + if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil { + return fmt.Errorf("%s: %w", tc.name, err) + } + if err = e.waitForLogCounts(ctx, order.OrderNo, tc.wantGiftLogs, tc.wantCommissionLogs); err != nil { + return fmt.Errorf("%s: %w", tc.name, err) + } + giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), order.OrderNo) + if err != nil { + return err + } + commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), order.OrderNo) + if err != nil { + return err + } + if giftLogs != tc.wantGiftLogs { + return fmt.Errorf("%s: expected gift logs=%d got=%d", tc.name, tc.wantGiftLogs, giftLogs) + } + if commissionLogs != tc.wantCommissionLogs { + return fmt.Errorf("%s: expected commission logs=%d got=%d", tc.name, tc.wantCommissionLogs, commissionLogs) + } + if referer != nil && tc.wantCommission > 0 { + var after modelUser.User + if err = e.db.WithContext(ctx).First(&after, referer.Id).Error; err != nil { + return err + } + if after.Commission != tc.wantCommission { + return fmt.Errorf("%s: expected referer commission=%d got=%d", tc.name, tc.wantCommission, after.Commission) + } + } + fmt.Printf("PASS %-58s gifts=%d commission_logs=%d\n", tc.name, giftLogs, commissionLogs) + } + return nil +} + +func (e *replayEnv) replayFamilyInviteGiftToOwner(ctx context.Context) error { + fmt.Println("\n-- bug3 family replay: member purchase gift days go to owner --") + giftDays := e.cfg.Invite.GiftDays + if giftDays <= 0 { + giftDays = 2 + } + e.svcCtx.Config.Invite.GiftDays = giftDays + e.svcCtx.Config.Invite.ReferralPercentage = 0 + e.svcCtx.Config.Invite.OnlyFirstPurchase = false + + planA, _, err := e.createPlans(ctx, "bug3-family") + if err != nil { + return err + } + referer, err := e.createUser(ctx, "bug3-family-referer", 0, 0) + if err != nil { + return err + } + owner, err := e.createUser(ctx, "bug3-family-owner", 0, 0) + if err != nil { + return err + } + member, err := e.createUser(ctx, "bug3-family-member", referer.Id, 0) + if err != nil { + return err + } + if err = e.createFamily(ctx, owner.Id, member.Id); err != nil { + return err + } + + baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond) + ownerSub, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, baseExpire) + if err != nil { + return err + } + memberSub, err := e.createUserSubscribe(ctx, member.Id, 0, planA.Id, baseExpire) + if err != nil { + return err + } + refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire) + if err != nil { + return err + } + + order, err := e.createPaidOrder(ctx, member.Id, owner.Id, planA.Id, true, "bug3-family") + if err != nil { + return err + } + if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil { + return err + } + if err = e.waitForLogCounts(ctx, order.OrderNo, 2, 0); err != nil { + return err + } + + var ownerAfter, memberAfter, refererAfter modelUser.Subscribe + if err = e.db.WithContext(ctx).First(&ownerAfter, ownerSub.Id).Error; err != nil { + return err + } + if err = e.db.WithContext(ctx).First(&memberAfter, memberSub.Id).Error; err != nil { + return err + } + if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil { + return err + } + if !ownerAfter.ExpireTime.After(baseExpire) { + return fmt.Errorf("family gift failed: owner expire not increased") + } + if !refererAfter.ExpireTime.After(baseExpire) { + return fmt.Errorf("family gift failed: referer expire not increased") + } + if memberAfter.ExpireTime.After(baseExpire.Add(time.Second)) { + return fmt.Errorf("family gift failed: member subscription should not receive gift days") + } + + var memberGiftLogs int64 + if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}). + Where("type = ? AND object_id = ? AND content LIKE ?", modelLog.TypeGift.Uint8(), member.Id, "%"+order.OrderNo+"%"). + Count(&memberGiftLogs).Error; err != nil { + return err + } + if memberGiftLogs != 0 { + return fmt.Errorf("family gift failed: expected no member gift logs, got %d", memberGiftLogs) + } + fmt.Printf("PASS family member purchase gift target owner owner=%d member=%d referer=%d gift_days=%d\n", owner.Id, member.Id, referer.Id, giftDays) + return nil +} + +func (e *replayEnv) activateOrderTwice(ctx context.Context, orderNo string) error { + payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: orderNo}) + worker := orderLogic.NewActivateOrderLogic(e.svcCtx) + if err := worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil { + return err + } + return worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)) +} + +func (e *replayEnv) waitForLogCounts(ctx context.Context, orderNo string, wantGiftLogs, wantCommissionLogs int64) error { + deadline := time.Now().Add(8 * time.Second) + for { + giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), orderNo) + if err != nil { + return err + } + commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), orderNo) + if err != nil { + return err + } + if giftLogs >= wantGiftLogs && commissionLogs >= wantCommissionLogs { + if wantGiftLogs == 0 && wantCommissionLogs == 0 { + time.Sleep(500 * time.Millisecond) + } + return nil + } + if time.Now().After(deadline) { + return fmt.Errorf("timed out waiting for logs: order=%s gift=%d/%d commission=%d/%d", orderNo, giftLogs, wantGiftLogs, commissionLogs, wantCommissionLogs) + } + time.Sleep(100 * time.Millisecond) + } +} + +func (e *replayEnv) countLogs(ctx context.Context, logType uint8, orderNo string) (int64, error) { + var count int64 + err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}). + Where("type = ? AND content LIKE ?", logType, "%"+orderNo+"%"). + Count(&count).Error + return count, err +} + +func (e *replayEnv) waitForGiftLogs(ctx context.Context, orderNo string, userIDs ...int64) error { + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + var count int64 + if err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}). + Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), userIDs, "%"+orderNo+"%"). + Count(&count).Error; err != nil { + return err + } + if count == int64(len(userIDs)) { + return nil + } + time.Sleep(100 * time.Millisecond) + } + return fmt.Errorf("timed out waiting for gift logs for order=%s", orderNo) +} + +func (e *replayEnv) createPlans(ctx context.Context, scope string) (*modelSubscribe.Subscribe, *modelSubscribe.Subscribe, error) { + a := &modelSubscribe.Subscribe{ + Name: marker + "-" + scope + "-A", + Language: "en", + UnitPrice: 599, + UnitTime: "Month", + Traffic: 1024 * 1024 * 1024, + Inventory: -1, + Quota: 0, + NodeGroupIds: modelSubscribe.JSONInt64Slice{}, + } + b := &modelSubscribe.Subscribe{ + Name: marker + "-" + scope + "-B", + Language: "en", + UnitPrice: 699, + UnitTime: "Month", + Traffic: 2 * 1024 * 1024 * 1024, + Inventory: -1, + Quota: 0, + NodeGroupIds: modelSubscribe.JSONInt64Slice{}, + } + if err := e.db.WithContext(ctx).Create(a).Error; err != nil { + return nil, nil, err + } + if err := e.db.WithContext(ctx).Create(b).Error; err != nil { + return nil, nil, err + } + e.ids.plans = append(e.ids.plans, a.Id, b.Id) + return a, b, nil +} + +func (e *replayEnv) createUser(ctx context.Context, scope string, refererID int64, referralPercentage uint8) (*modelUser.User, error) { + onlyFirst := true + enable := true + isAdmin := false + u := &modelUser.User{ + Password: marker, + Algo: "default", + Salt: "default", + RefererId: refererID, + ReferralPercentage: referralPercentage, + OnlyFirstPurchase: &onlyFirst, + Enable: &enable, + IsAdmin: &isAdmin, + EnableBalanceNotify: &enable, + EnableLoginNotify: &enable, + EnableSubscribeNotify: &enable, + EnableTradeNotify: &enable, + Remark: marker + "-" + scope, + } + if err := e.db.WithContext(ctx).Create(u).Error; err != nil { + return nil, err + } + u.ReferCode = uuidx.UserInviteCode(u.Id) + if err := e.db.WithContext(ctx).Model(&modelUser.User{}).Where("id = ?", u.Id).Update("refer_code", u.ReferCode).Error; err != nil { + return nil, err + } + e.ids.users = append(e.ids.users, u.Id) + return u, nil +} + +func (e *replayEnv) createFamily(ctx context.Context, ownerID, memberID int64) error { + now := time.Now() + family := &modelUser.UserFamily{ + OwnerUserId: ownerID, + MaxMembers: modelUser.DefaultFamilyMaxSize, + Status: modelUser.FamilyStatusActive, + } + if err := e.db.WithContext(ctx).Create(family).Error; err != nil { + return err + } + members := []modelUser.UserFamilyMember{ + { + FamilyId: family.Id, + UserId: ownerID, + Role: modelUser.FamilyRoleOwner, + Status: modelUser.FamilyMemberActive, + JoinSource: marker, + JoinedAt: now, + }, + { + FamilyId: family.Id, + UserId: memberID, + Role: modelUser.FamilyRoleMember, + Status: modelUser.FamilyMemberActive, + JoinSource: marker, + JoinedAt: now, + }, + } + return e.db.WithContext(ctx).Create(&members).Error +} + +func (e *replayEnv) createUserSubscribe(ctx context.Context, userID, orderID, planID int64, expire time.Time) (*modelUser.Subscribe, error) { + groupLocked := false + sub := &modelUser.Subscribe{ + UserId: userID, + OrderId: orderID, + SubscribeId: planID, + GroupLocked: &groupLocked, + StartTime: time.Now().Add(-time.Hour), + ExpireTime: expire, + Traffic: 1024 * 1024 * 1024, + Token: marker + "-" + uuidx.NewUUID().String(), + UUID: uuidx.NewUUID().String(), + Status: 1, + Note: marker, + } + if err := e.db.WithContext(ctx).Create(sub).Error; err != nil { + return nil, err + } + e.ids.subscribes = append(e.ids.subscribes, sub.Id) + return sub, nil +} + +func (e *replayEnv) createPaidOrder(ctx context.Context, userID, subscriptionUserID, planID int64, isNew bool, scope string) (*modelOrder.Order, error) { + orderNo := fmt.Sprintf("%s-%s-%d", marker, scope, time.Now().UnixNano()) + order := &modelOrder.Order{ + UserId: userID, + SubscriptionUserId: subscriptionUserID, + OrderNo: orderNo, + Type: 1, + Quantity: 1, + Price: 599, + Amount: 599, + Status: 2, + SubscribeId: planID, + Method: "replay", + IsNew: isNew, + } + if err := e.db.WithContext(ctx).Create(order).Error; err != nil { + return nil, err + } + e.ids.orders = append(e.ids.orders, order.Id) + return order, nil +} + +func (e *replayEnv) cleanup(ctx context.Context) { + fmt.Println("\n-- cleanup replay rows --") + e.cleanupByMarker(ctx) + if len(e.ids.subscribes) > 0 { + _ = e.db.WithContext(ctx).Where("id IN ?", e.ids.subscribes).Delete(&modelUser.Subscribe{}).Error + } + if len(e.ids.orders) > 0 { + _ = e.db.WithContext(ctx).Where("id IN ?", e.ids.orders).Delete(&modelOrder.Order{}).Error + } + if len(e.ids.plans) > 0 { + _ = e.db.WithContext(ctx).Where("id IN ?", e.ids.plans).Delete(&modelSubscribe.Subscribe{}).Error + } + if len(e.ids.users) > 0 { + _ = e.db.WithContext(ctx).Unscoped().Where("id IN ?", e.ids.users).Delete(&modelUser.User{}).Error + } + fmt.Println("cleanup done") +} + +func (e *replayEnv) cleanupByMarker(ctx context.Context) { + _ = e.db.WithContext(ctx). + Where("join_source = ?", marker). + Delete(&modelUser.UserFamilyMember{}).Error + _ = e.db.WithContext(ctx). + Where("owner_user_id IN (SELECT id FROM `user` WHERE remark LIKE ?)", marker+"%"). + Delete(&modelUser.UserFamily{}).Error + _ = e.db.WithContext(ctx). + Where("type IN (33, 34) AND content LIKE ?", "%"+marker+"%"). + Delete(&modelLog.SystemLog{}).Error + _ = e.db.WithContext(ctx). + Where("order_no LIKE ?", marker+"%"). + Delete(&modelOrder.Order{}).Error + _ = e.db.WithContext(ctx). + Where("note = ? OR token LIKE ?", marker, marker+"%"). + Delete(&modelUser.Subscribe{}).Error + _ = e.db.WithContext(ctx). + Where("name LIKE ?", marker+"%"). + Delete(&modelSubscribe.Subscribe{}).Error + _ = e.db.WithContext(ctx).Unscoped(). + Where("remark LIKE ?", marker+"%"). + Delete(&modelUser.User{}).Error +} + +func looksLikeProduction(cfg config.Config) bool { + joined := strings.ToLower(strings.Join([]string{cfg.MySQL.Dbname, cfg.Site.Host, cfg.Host}, " ")) + if strings.Contains(joined, "prod") || strings.Contains(joined, "production") { + return true + } + if cfg.Debug { + return false + } + if strings.Contains(joined, "test") || strings.Contains(joined, "dev") || strings.Contains(joined, "staging") { + return false + } + return true +} + +func must(err error) { + if err != nil { + fatalf("%v", err) + } +} + +func fatalf(format string, args ...interface{}) { + fmt.Fprintf(os.Stderr, "FAIL: "+format+"\n", args...) + os.Exit(1) +}