fix(order): prevent duplicate subscriptions and repair invite gifts
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m6s

This commit is contained in:
shanshanzhong 2026-04-24 21:16:21 -07:00
parent ae62ecc6b3
commit 9db4762904
17 changed files with 3013 additions and 52 deletions

View File

@ -2,6 +2,7 @@ package auth
import (
"context"
"net/mail"
"strings"
"github.com/perfect-panel/server/internal/config"
@ -15,11 +16,10 @@ func IsEmailDomainWhitelisted(email, whitelistCSV string) bool {
if whitelistCSV == "" {
return false
}
parts := strings.SplitN(email, "@", 2)
if len(parts) != 2 {
_, domain, ok := parseStrictEmail(email)
if !ok {
return false
}
domain := strings.ToLower(strings.TrimSpace(parts[1]))
for _, d := range strings.Split(whitelistCSV, ",") {
if strings.ToLower(strings.TrimSpace(d)) == domain {
return true
@ -32,10 +32,16 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool
if !register.EnableTrial {
return false
}
if !IsValidTrialEmail(email) {
return false
}
// 无论白名单是否启用,泛域名邮箱(含 + 别名或 Gmail 点号)始终拒绝赠送
if IsDisposableAlias(email) {
return false
}
if isConfusableGmailDomain(emailDomain(email)) {
return false
}
if !register.EnableTrialEmailWhitelist {
return true
}
@ -52,11 +58,10 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool
// For Gmail-like domains, local part containing "." or "+" is rejected.
// For all other domains, only "+" alias is rejected.
func IsDisposableAlias(email string) bool {
parts := strings.SplitN(strings.ToLower(strings.TrimSpace(email)), "@", 2)
if len(parts) != 2 {
local, domain, ok := parseStrictEmail(email)
if !ok {
return false
}
local, domain := parts[0], parts[1]
// All domains: reject + alias
if strings.ContainsRune(local, '+') {
@ -74,11 +79,10 @@ func IsDisposableAlias(email string) bool {
// Removes dots from local part for Gmail-like providers (gmail.com, googlemail.com).
func NormalizeEmail(email string) string {
email = strings.ToLower(strings.TrimSpace(email))
parts := strings.SplitN(email, "@", 2)
if len(parts) != 2 {
local, domain, ok := parseStrictEmail(email)
if !ok {
return email
}
local, domain := parts[0], parts[1]
// Strip + alias
if idx := strings.IndexByte(local, '+'); idx != -1 {
@ -101,6 +105,51 @@ func isGmailLikeDomain(domain string) bool {
return false
}
func IsValidTrialEmail(email string) bool {
local, domain, ok := parseStrictEmail(email)
if !ok {
return false
}
return local != "" && domain != ""
}
func parseStrictEmail(email string) (local, domain string, ok bool) {
email = strings.ToLower(strings.TrimSpace(email))
if email == "" || strings.ContainsAny(email, " \t\r\n") {
return "", "", false
}
addr, err := mail.ParseAddress(email)
if err != nil || addr.Address != email || addr.Name != "" {
return "", "", false
}
parts := strings.Split(addr.Address, "@")
if len(parts) != 2 {
return "", "", false
}
local = strings.TrimSpace(parts[0])
domain = strings.Trim(strings.TrimSpace(parts[1]), ".")
if local == "" || domain == "" || strings.Contains(domain, "..") || !strings.Contains(domain, ".") {
return "", "", false
}
return local, domain, true
}
func emailDomain(email string) string {
_, domain, ok := parseStrictEmail(email)
if !ok {
return ""
}
return domain
}
func isConfusableGmailDomain(domain string) bool {
switch strings.ToLower(strings.TrimSpace(domain)) {
case "gmaial.com", "gmial.com", "gmai.com", "gamil.com", "gmal.com", "gmail.co", "gmail.con":
return true
}
return false
}
// NormalizedEmailHasTrial returns true if any user with the same normalized email
// already holds a trial subscription. Only performs the cross-user DB check when
// normalization actually changes the email (i.e., dots removed or + alias stripped).

View File

@ -0,0 +1,188 @@
package auth
import (
"testing"
"github.com/perfect-panel/server/internal/config"
)
func TestNormalizeEmail(t *testing.T) {
tests := []struct {
input string
want string
}{
// Gmail dot trick
{"a.v.x.xx@gmail.com", "avxxx@gmail.com"},
{"john.doe@gmail.com", "johndoe@gmail.com"},
{"a.b.c.d.e@gmail.com", "abcde@gmail.com"},
// Gmail + alias
{"user+tag@gmail.com", "user@gmail.com"},
{"a.b+tag@gmail.com", "ab@gmail.com"},
// Googlemail
{"a.b@googlemail.com", "ab@googlemail.com"},
// Non-Gmail: dots preserved
{"john.doe@outlook.com", "john.doe@outlook.com"},
{"john.doe@qq.com", "john.doe@qq.com"},
// + alias stripped for all providers
{"user+spam@outlook.com", "user@outlook.com"},
{"user+spam@qq.com", "user@qq.com"},
// Case insensitive
{"User@Gmail.COM", "user@gmail.com"},
{"A.B@Gmail.com", "ab@gmail.com"},
// No change for normal non-gmail email
{"abc@163.com", "abc@163.com"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := NormalizeEmail(tt.input)
if got != tt.want {
t.Errorf("NormalizeEmail(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNormalizeEmail_NoChangeSkipsCheck(t *testing.T) {
// These emails should NOT trigger cross-user check (normalized == original)
noChangeCases := []string{
"abc@163.com",
"john.doe@outlook.com",
"user@qq.com",
}
for _, email := range noChangeCases {
normalized := NormalizeEmail(email)
lower := email
if normalized == lower {
// correct: no normalization change, NormalizedEmailHasTrial would return false early
}
}
}
func TestShouldGrantTrialForEmail(t *testing.T) {
// 模拟线上配置白名单开启gmail.com 也在名单里
rcWithGmail := config.RegisterConfig{
EnableTrial: true,
EnableTrialEmailWhitelist: true,
TrialEmailDomainWhitelist: "hifastapp.com,hifastvpn.com,126.com,139.com,163.com,gmail.com",
}
// 白名单关闭
rcNoWhitelist := config.RegisterConfig{
EnableTrial: true,
EnableTrialEmailWhitelist: false,
}
tests := []struct {
name string
rc config.RegisterConfig
email string
want bool
reason string
}{
{
name: "gmail dot trick - blocked even if gmail.com in whitelist",
rc: rcWithGmail,
email: "s.m.s.n.fsmbt.d.ndny@gmail.com",
want: false,
reason: "gmail 泛域名(含点号)应拒绝",
},
{
name: "gmail plus alias - blocked",
rc: rcWithGmail,
email: "user+tag@gmail.com",
want: false,
reason: "gmail +别名应拒绝",
},
{
name: "clean gmail - allowed",
rc: rcWithGmail,
email: "normaluser@gmail.com",
want: true,
reason: "干净的 gmail 应放行",
},
{
name: "163 with dot - allowed (non-gmail dot is ok)",
rc: rcWithGmail,
email: "s.m.s.n@163.com",
want: true,
reason: "非 gmail 域点号不拦截",
},
{
name: "163 plus alias - blocked",
rc: rcWithGmail,
email: "user+spam@163.com",
want: false,
reason: "所有域名的 +别名都拦截",
},
{
name: "gmail typo squatting domain - blocked even if accidentally whitelisted",
rc: config.RegisterConfig{EnableTrial: true, EnableTrialEmailWhitelist: true, TrialEmailDomainWhitelist: "gmail.com,gmaial.com"},
email: "1.2.3.4xxx@gmaial.com",
want: false,
reason: "易混淆 Gmail 域名不应发放试用",
},
{
name: "invalid empty local - blocked",
rc: rcWithGmail,
email: "@gmail.com",
want: false,
reason: "邮箱 local 为空应拒绝",
},
{
name: "subdomain spoof - blocked",
rc: rcWithGmail,
email: "user@fake.gmail.com",
want: false,
reason: "白名单必须精确匹配域名,不匹配子域",
},
{
name: "whitelist disabled - gmail dot trick still blocked",
rc: rcNoWhitelist,
email: "s.m.s.n.fsmbt.d.ndny@gmail.com",
want: false,
reason: "白名单未启用,但泛域名仍应拒绝",
},
{
name: "trial disabled - always blocked",
rc: config.RegisterConfig{EnableTrial: false},
email: "user@163.com",
want: false,
reason: "试用未开启",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldGrantTrialForEmail(tt.rc, tt.email)
if got != tt.want {
t.Errorf("ShouldGrantTrialForEmail(%q) = %v, want %v | reason: %s",
tt.email, got, tt.want, tt.reason)
}
})
}
}
func TestIsEmailDomainWhitelisted(t *testing.T) {
whitelist := "gmail.com,edu.cn,outlook.com"
tests := []struct {
email string
want bool
}{
{"user@gmail.com", true},
{"user@edu.cn", true},
{"User@Gmail.COM", true},
{"user@yahoo.com", false},
{"user@fake.gmail.com", false}, // subdomain not matched
{"user@", false},
{"notanemail", false},
{"@gmail.com", false},
}
for _, tt := range tests {
t.Run(tt.email, func(t *testing.T) {
got := IsEmailDomainWhitelisted(tt.email, whitelist)
if got != tt.want {
t.Errorf("IsEmailDomainWhitelisted(%q) = %v, want %v", tt.email, got, tt.want)
}
})
}
}

View File

@ -76,7 +76,10 @@ func CountScopedSubscribePurchaseOrders(
var count int64
query := db.WithContext(ctx).
Model(&modelOrder.Order{}).
Where("user_id IN ? AND subscribe_id = ? AND type IN ? AND amount > 0", scopeUserIDs, subscribeID, []int64{1, 2})
Where("user_id IN ? AND type IN ?", scopeUserIDs, []int64{1, 2})
if subscribeID > 0 {
query = query.Where("subscribe_id = ?", subscribeID)
}
if len(statuses) > 0 {
query = query.Where("status IN ?", statuses)
}

View File

@ -52,12 +52,8 @@ func ResolvePurchaseRoute(
return decision, nil
}
if requestedSubscribeID != anchorSub.SubscribeId {
return nil, ErrSingleModePlanMismatch
}
decision.Route = PurchaseRoutePurchaseToRenewal
decision.ResolvedSubscribeID = anchorSub.SubscribeId
decision.ResolvedSubscribeID = requestedSubscribeID
decision.Anchor = anchorSub
return decision, nil
}

View File

@ -0,0 +1,36 @@
package common
import (
"context"
"testing"
"time"
"github.com/perfect-panel/server/internal/model/user"
"github.com/stretchr/testify/require"
)
func TestResolvePurchaseRoute_AllowsPlanChangeForExistingSubscription(t *testing.T) {
anchor := &user.Subscribe{
Id: 10,
UserId: 20,
OrderId: 30,
SubscribeId: 1,
Token: "existing-token",
ExpireTime: time.Now().Add(time.Hour),
}
decision, err := ResolvePurchaseRoute(
context.Background(),
true,
anchor.UserId,
2,
func(context.Context, int64) (*user.Subscribe, error) {
return anchor, nil
},
)
require.NoError(t, err)
require.Equal(t, PurchaseRoutePurchaseToRenewal, decision.Route)
require.Equal(t, int64(2), decision.ResolvedSubscribeID)
require.Equal(t, anchor, decision.Anchor)
}

View File

@ -12,6 +12,9 @@ func getDiscount(discounts []types.SubscribeDiscount, inputMonths int64, isNewUs
if d.Quantity != inputMonths || d.Discount <= 0 || d.Discount >= 100 {
continue
}
if d.NewUserOnly && !isNewUser {
continue
}
if isNewUser {
// lowest discount value = biggest saving
if best < 0 || d.Discount < best {
@ -50,4 +53,3 @@ func isNewUserOnlyForQuantity(discounts []types.SubscribeDiscount, inputQuantity
}
return hasNewUserOnly && !hasFallback
}

View File

@ -0,0 +1,16 @@
package order
import (
"testing"
"github.com/perfect-panel/server/internal/types"
"github.com/stretchr/testify/require"
)
func TestGetDiscount_SkipsNewUserOnlyTierForExistingUser(t *testing.T) {
discount := getDiscount([]types.SubscribeDiscount{
{Quantity: 1, Discount: 90, NewUserOnly: true},
}, 1, false)
require.Equal(t, float64(1), discount)
}

View File

@ -47,7 +47,7 @@ func resolveNewUserDiscountEligibility(
ctx,
db,
eligibility.ScopeUserIDs,
subscribeID,
0,
[]int64{2, 5},
"",
)

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"math"
"strings"
"time"
"github.com/google/uuid"
@ -111,18 +112,21 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
}
}
// 非单订阅模式下,若用户已有同套餐订阅(含过期),提前路由为续费,防止创建重复订阅
// 全局单订阅口径:若用户已有任意付费订阅(含过期),提前路由为续费/换套餐,
// 防止不同套餐购买创建第二条订阅。
if !l.svcCtx.Config.Subscribe.SingleModel && orderType == 1 {
var existSub user.Subscribe
if e := l.svcCtx.DB.WithContext(l.ctx).
Model(&user.Subscribe{}).
Where("user_id = ? AND subscribe_id = ?", entitlement.EffectiveUserID, targetSubscribeID).
Where("user_id = ? AND token != '' AND (order_id > 0 OR token LIKE 'iap:%')", entitlement.EffectiveUserID).
Order("expire_time DESC").
Order("updated_at DESC").
Order("id DESC").
First(&existSub).Error; e == nil && existSub.Id > 0 && existSub.Token != "" {
orderType = 2
parentOrderID = existSub.OrderId
subscribeToken = existSub.Token
l.Infow("[Purchase] non-single mode purchase routed to renewal (existing subscription found)",
l.Infow("[Purchase] purchase routed to renewal/change plan (existing subscription found)",
logger.Field("existing_subscribe_id", existSub.Id),
logger.Field("existing_status", existSub.Status),
logger.Field("user_id", u.Id),
@ -311,13 +315,13 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
// check subscribe plan quota limit inside transaction to prevent race condition
if orderInfo.Type == 1 && sub.Quota > 0 {
var currentUserSub []user.Subscribe
if e := db.Model(&user.Subscribe{}).Where("user_id = ?", u.Id).Find(&currentUserSub).Error; e != nil {
if e := db.Model(&user.Subscribe{}).Where("user_id = ?", entitlement.EffectiveUserID).Find(&currentUserSub).Error; e != nil {
l.Errorw("[Purchase] Database query error", logger.Field("error", e.Error()), logger.Field("user_id", u.Id))
return e
}
var count int64
for _, v := range currentUserSub {
if v.SubscribeId == targetSubscribeID {
if v.OrderId > 0 || strings.HasPrefix(v.Token, "iap:") {
count++
}
}

View File

@ -0,0 +1,766 @@
package order
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/hibiken/asynq"
modelOrder "github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/payment"
subModel "github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// setupNewUserOnlyDB 创建带必要表的 SQLite 内存数据库
func setupNewUserOnlyDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err, "failed to open in-memory SQLite")
db.Exec("PRAGMA foreign_keys = OFF")
sqls := []string{
`CREATE TABLE IF NOT EXISTS "subscribe" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL DEFAULT '',
language VARCHAR(255) NOT NULL DEFAULT '',
description TEXT,
unit_price INTEGER NOT NULL DEFAULT 0,
unit_time VARCHAR(255) NOT NULL DEFAULT '',
discount TEXT,
replacement INTEGER NOT NULL DEFAULT 0,
inventory INTEGER NOT NULL DEFAULT -1,
traffic INTEGER NOT NULL DEFAULT 0,
speed_limit INTEGER NOT NULL DEFAULT 0,
device_limit INTEGER NOT NULL DEFAULT 0,
quota INTEGER NOT NULL DEFAULT 0,
new_user_only TINYINT DEFAULT 0,
nodes VARCHAR(255),
node_tags VARCHAR(255),
show TINYINT NOT NULL DEFAULT 0,
sell TINYINT NOT NULL DEFAULT 1,
sort INTEGER NOT NULL DEFAULT 0,
deduction_ratio INTEGER DEFAULT 0,
allow_deduction TINYINT DEFAULT 1,
reset_cycle INTEGER DEFAULT 0,
renewal_reset TINYINT DEFAULT 0,
show_original_price TINYINT NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "order" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
parent_id INTEGER DEFAULT NULL,
user_id INTEGER NOT NULL DEFAULT 0,
subscription_user_id INTEGER NOT NULL DEFAULT 0,
order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
type TINYINT NOT NULL DEFAULT 1,
quantity INTEGER NOT NULL DEFAULT 1,
price INTEGER NOT NULL DEFAULT 0,
amount INTEGER NOT NULL DEFAULT 0,
gift_amount INTEGER NOT NULL DEFAULT 0,
discount INTEGER NOT NULL DEFAULT 0,
coupon VARCHAR(255) DEFAULT NULL,
coupon_discount INTEGER NOT NULL DEFAULT 0,
commission INTEGER NOT NULL DEFAULT 0,
payment_id INTEGER NOT NULL DEFAULT 0,
method VARCHAR(255) NOT NULL DEFAULT '',
fee_amount INTEGER NOT NULL DEFAULT 0,
trade_no VARCHAR(255) DEFAULT NULL,
app_account_token VARCHAR(255) DEFAULT NULL,
status TINYINT NOT NULL DEFAULT 1,
subscribe_id INTEGER NOT NULL DEFAULT 0,
subscribe_token VARCHAR(255) DEFAULT NULL,
is_new TINYINT NOT NULL DEFAULT 0,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
password VARCHAR(100) NOT NULL DEFAULT '',
algo VARCHAR(20) DEFAULT 'default',
salt VARCHAR(20) DEFAULT NULL,
avatar TEXT,
balance INTEGER DEFAULT 0,
refer_code VARCHAR(20) DEFAULT '',
referer_id INTEGER DEFAULT 0,
commission INTEGER DEFAULT 0,
referral_percentage INTEGER DEFAULT 0,
only_first_purchase TINYINT DEFAULT 1,
gift_amount INTEGER DEFAULT 0,
enable TINYINT DEFAULT 1,
is_admin TINYINT DEFAULT 0,
enable_balance_notify TINYINT DEFAULT 0,
enable_login_notify TINYINT DEFAULT 0,
enable_subscribe_notify TINYINT DEFAULT 0,
enable_trade_notify TINYINT DEFAULT 0,
rules TEXT,
member_status VARCHAR(20) DEFAULT '',
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "payment" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(100) NOT NULL DEFAULT '',
platform VARCHAR(100) NOT NULL DEFAULT '',
icon VARCHAR(255) DEFAULT '',
domain VARCHAR(255) DEFAULT '',
config TEXT NOT NULL DEFAULT '{}',
description TEXT,
fee_mode TINYINT NOT NULL DEFAULT 0,
fee_percent INTEGER DEFAULT 0,
fee_amount INTEGER DEFAULT 0,
enable TINYINT NOT NULL DEFAULT 1,
token VARCHAR(255) NOT NULL DEFAULT '' UNIQUE
)`,
`CREATE TABLE IF NOT EXISTS "user_subscribe" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
order_id INTEGER NOT NULL DEFAULT 0,
subscribe_id INTEGER NOT NULL DEFAULT 0,
start_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expire_time DATETIME DEFAULT NULL,
finished_at DATETIME DEFAULT NULL,
traffic INTEGER DEFAULT 0,
download INTEGER DEFAULT 0,
upload INTEGER DEFAULT 0,
token VARCHAR(255) DEFAULT '' UNIQUE,
uuid VARCHAR(255) DEFAULT '' UNIQUE,
status TINYINT DEFAULT 0,
note VARCHAR(500) DEFAULT '',
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_device" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip VARCHAR(255) NOT NULL DEFAULT '',
user_id INTEGER NOT NULL DEFAULT 0,
user_agent TEXT,
identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
short_code VARCHAR(255) NOT NULL DEFAULT '',
online TINYINT NOT NULL DEFAULT 0,
enabled TINYINT NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_family" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner_user_id INTEGER NOT NULL DEFAULT 0,
max_members INTEGER NOT NULL DEFAULT 2,
status TINYINT DEFAULT 0,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_family_member" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
family_id INTEGER NOT NULL DEFAULT 0,
user_id INTEGER NOT NULL DEFAULT 0,
role TINYINT DEFAULT 0,
status TINYINT DEFAULT 0,
join_source VARCHAR(32) NOT NULL DEFAULT '',
joined_at DATETIME,
left_at DATETIME DEFAULT NULL,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
}
for _, sql := range sqls {
require.NoError(t, db.Exec(sql).Error)
}
return db
}
// setupNewUserOnlyRedis 启动 miniredis返回 redis.Client 和 miniredis 句柄
func setupNewUserOnlyRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err)
t.Cleanup(mr.Close)
rds := redis.NewClient(&redis.Options{Addr: mr.Addr()})
return rds, mr
}
// buildNewUserOnlySvcCtx 组装最小 ServiceContext含 asynq Queue 使用 miniredis
func buildNewUserOnlySvcCtx(db *gorm.DB, rds *redis.Client, mr *miniredis.Miniredis) *svc.ServiceContext {
queue := asynq.NewClient(asynq.RedisClientOpt{Addr: mr.Addr()})
return &svc.ServiceContext{
DB: db,
Redis: rds,
UserModel: user.NewModel(db, rds),
OrderModel: modelOrder.NewModel(db, rds),
SubscribeModel: subModel.NewModel(db, rds),
PaymentModel: payment.NewModel(db, rds),
Queue: queue,
}
}
// insertTestSubscribe 直接用 SQL 插入 subscribe 行(绕过 GORM hook 的 MySQL 方言)
// new_user_only=true 时同时写入 discount JSON使代码里的 discount 检查生效
func insertTestSubscribe(t *testing.T, db *gorm.DB, id int64, newUserOnly bool) {
t.Helper()
nuOnly := 0
discount := ""
if newUserOnly {
nuOnly = 1
// discount JSON 包含一个 new_user_only=true 的 tier匹配 quantity=1
discount = `[{"quantity":1,"discount":90,"new_user_only":true}]`
}
err := db.Exec(`INSERT INTO "subscribe"
(id, name, unit_price, inventory, sell, sort, new_user_only, discount, created_at, updated_at)
VALUES (?, 'Test Plan', 1000, -1, 1, ?, ?, ?, datetime('now'), datetime('now'))`,
id, id, nuOnly, discount).Error
require.NoError(t, err)
}
// insertTestPayment 插入支付方式行
func insertTestPayment(t *testing.T, db *gorm.DB, id int64) {
t.Helper()
err := db.Exec(`INSERT INTO "payment"
(id, name, platform, config, enable, fee_mode, token)
VALUES (?, 'Balance', 'balance', '{}', 1, 0, ?)`,
id, "test-token").Error
require.NoError(t, err)
}
// insertTestUser 插入用户行createdAt 可控
func insertTestUser(t *testing.T, db *gorm.DB, id int64, createdAt time.Time) *user.User {
t.Helper()
err := db.Exec(`INSERT INTO "user"
(id, password, balance, gift_amount, enable, created_at, updated_at)
VALUES (?, '', 0, 0, 1, ?, datetime('now'))`,
id, createdAt.UTC().Format("2006-01-02 15:04:05")).Error
require.NoError(t, err)
return &user.User{
Id: id,
GiftAmount: 0,
CreatedAt: createdAt,
}
}
func insertTestDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) {
t.Helper()
err := db.Exec(`INSERT INTO "user_device"
(user_id, ip, user_agent, identifier, short_code, online, enabled, created_at, updated_at)
VALUES (?, '127.0.0.1', 'test-agent', ?, '', 0, 1, ?, datetime('now'))`,
userID,
identifier,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error
require.NoError(t, err)
}
func insertTestFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) {
t.Helper()
err := db.Exec(`INSERT INTO "user_family"
(id, owner_user_id, max_members, status, created_at, updated_at)
VALUES (?, ?, 3, 1, datetime('now'), datetime('now'))`,
familyID,
ownerUserID,
).Error
require.NoError(t, err)
}
func insertTestFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) {
t.Helper()
err := db.Exec(`INSERT INTO "user_family_member"
(family_id, user_id, role, status, join_source, joined_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'), datetime('now'))`,
familyID,
userID,
role,
status,
joinSource,
).Error
require.NoError(t, err)
}
// insertTestOrder 插入一条历史订单status=2 表示已支付)
func insertTestOrder(t *testing.T, db *gorm.DB, userID, subscribeID int64, status uint8) {
t.Helper()
err := db.Exec(`INSERT INTO "order"
(user_id, order_no, type, status, subscribe_id, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`,
userID, "existing-order-no", status, subscribeID).Error
require.NoError(t, err)
}
func insertScopedTestOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) {
t.Helper()
err := db.Exec(`INSERT INTO "order"
(user_id, order_no, type, status, subscribe_id, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`,
userID, orderNo, status, subscribeID).Error
require.NoError(t, err)
}
// buildPurchaseCtx 把 user 放入 context模拟中间件行为
func buildPurchaseCtx(u *user.User) context.Context {
return context.WithValue(context.Background(), constant.CtxKeyUser, u)
}
// TestPurchase_NewUserOnly_UserTooOld 验证new_user_only=true用户注册超过 24h → 返回 SubscribeNewUserOnly
func TestPurchase_NewUserOnly_UserTooOld(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(1)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true) // new_user_only = true
insertTestPayment(t, db, payID)
// 用户注册 48 小时前 → 超出 24h 限制
u := insertTestUser(t, db, 100, time.Now().Add(-48*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
_, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode(),
"注册超过24h应返回 SubscribeNewUserOnly 错误码")
// 验证订单未被创建
var count int64
db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count)
assert.Equal(t, int64(0), count, "用户注册超时,订单不应被创建")
}
// TestPurchase_NewUserOnly_AlreadyPurchased 验证new_user_only=true用户是新用户但已购买过
// → 允许下单(不拦截),但不享受新人折扣
func TestPurchase_NewUserOnly_AlreadyPurchased(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(2)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户刚注册2h前→ 满足时间条件
u := insertTestUser(t, db, 200, time.Now().Add(-2*time.Hour))
// 但已有一条 status=2 的历史订单(已支付)
insertTestOrder(t, db, u.Id, subID, 2)
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
// 不应被拦截,允许下单
require.NoError(t, err, "24h内已购用户应允许继续下单不应返回错误")
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo)
// 历史订单 +1新增了一条
var count int64
db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count)
assert.Equal(t, int64(2), count, "应新增一条订单")
// 新订单无折扣Amount=Price=1000
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣Amount 应等于 Price")
assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0")
}
// TestPurchase_NewUserOnly_Success 验证new_user_only=true新用户首次购买 → 成功创建订单
func TestPurchase_NewUserOnly_Success(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(3)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册(新用户),且没有历史订单
u := insertTestUser(t, db, 300, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "新用户首次购买应成功,返回订单号")
// 验证订单已写入数据库
var o modelOrder.Order
err = db.Where("order_no = ?", resp.OrderNo).First(&o).Error
require.NoError(t, err)
assert.Equal(t, u.Id, o.UserId)
assert.Equal(t, subID, o.SubscribeId)
}
// TestPurchase_NewUserOnly_Disabled 验证new_user_only=false 时,老用户也能正常购买
func TestPurchase_NewUserOnly_Disabled(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(4)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false) // new_user_only = false
insertTestPayment(t, db, payID)
// 注册 30 天的老用户
u := insertTestUser(t, db, 400, time.Now().Add(-30*24*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "new_user_only=false时老用户应能正常购买")
}
// TestPurchase_SingleMode_PendingOldOrderCancelled 验证:单订阅模式下,已有 pending 订单时
// 第二次下单应关闭旧单并创建新单(而非复用旧单)
func TestPurchase_SingleMode_PendingOldOrderCancelled(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
svcCtx.Config.Subscribe.SingleModel = true
const subID = int64(5)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false)
insertTestPayment(t, db, payID)
u := insertTestUser(t, db, 500, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
// 第一次下单pending
logic := NewPurchaseLogic(ctx, svcCtx)
resp1, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp1)
firstOrderNo := resp1.OrderNo
assert.NotEmpty(t, firstOrderNo)
// 确认第一单 pending
var firstOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&firstOrder).Error)
assert.Equal(t, uint8(1), firstOrder.Status, "第一单应为 pending")
// 第二次下单(不同 quantity
logic2 := NewPurchaseLogic(ctx, svcCtx)
resp2, err := logic2.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 3,
})
require.NoError(t, err)
require.NotNil(t, resp2)
secondOrderNo := resp2.OrderNo
// 新单与旧单不同
assert.NotEqual(t, firstOrderNo, secondOrderNo, "第二次下单应创建新订单,不复用旧单")
// 旧单应被关闭status=3
var closedOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&closedOrder).Error)
assert.Equal(t, uint8(3), closedOrder.Status, "旧 pending 单应被关闭")
// 新单的 quantity 应为 3
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", secondOrderNo).First(&newOrder).Error)
assert.Equal(t, int64(3), newOrder.Quantity, "新单 quantity 应为 3")
assert.Equal(t, uint8(1), newOrder.Status, "新单应为 pending 状态")
}
// TestPurchase_SingleMode_NoPendingOrder 验证:单订阅模式下,没有旧 pending 单时正常创建
func TestPurchase_SingleMode_NoPendingOrder(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
svcCtx.Config.Subscribe.SingleModel = true
const subID = int64(6)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false)
insertTestPayment(t, db, payID)
u := insertTestUser(t, db, 600, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 2,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "无旧 pending 单时应正常创建新单")
var o modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&o).Error)
assert.Equal(t, int64(2), o.Quantity)
assert.Equal(t, uint8(1), o.Status)
}
// TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock 验证new_user_only=true 套餐,
// 24小时内但已购买过 → 允许下单但不享受新人折扣Discount=0Amount=Price
func TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(7)
const payID = int64(1)
// 套餐unit_price=1000discount=[{quantity:1,discount:80,new_user_only:true}]
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册新用户但已有一条成功订单status=2
u := insertTestUser(t, db, 700, time.Now().Add(-1*time.Hour))
insertTestOrder(t, db, u.Id, subID, 2)
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
// 不应被拦截
require.NoError(t, err, "24h内已购用户不应被拦截应允许下单")
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo)
// 验证订单金额无折扣Amount=Price=1000Discount=0
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000")
assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣Amount 应等于 Price")
assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0")
}
// TestPurchase_NewUserOnly_FirstPurchase_HasDiscount 验证new_user_only=true 套餐,
// 24小时内首次购买 → 允许下单且享受新人折扣
func TestPurchase_NewUserOnly_FirstPurchase_HasDiscount(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(8)
const payID = int64(1)
// 套餐unit_price=1000discount=[{quantity:1,discount:80,new_user_only:true}]8折
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册,无历史订单
u := insertTestUser(t, db, 800, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000")
assert.Equal(t, int64(900), newOrder.Amount, "首次购买应享受9折Amount=900")
assert.Equal(t, int64(100), newOrder.Discount, "折扣金额应为 100")
}
func TestPurchase_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(9)
payID = int64(1)
ownerUserID = int64(901)
memberUserID = int64(902)
familyID = int64(99)
)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-72*time.Hour))
insertTestDevice(t, db, memberUserID, "device-eligibility-old", time.Now().Add(-72*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx)
_, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode())
}
func TestPurchase_NewUserOnly_BindEmailScopeSharesHistory(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(10)
payID = int64(1)
ownerUserID = int64(1001)
memberUserID = int64(1002)
familyID = int64(109)
)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-2*time.Hour))
insertTestDevice(t, db, memberUserID, "device-eligibility-shared", time.Now().Add(-2*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
insertScopedTestOrder(t, db, "existing-scope-order", memberUserID, subID, 2)
logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Amount)
assert.Equal(t, int64(0), newOrder.Discount)
}
func TestPreCreateOrder_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(11)
ownerUserID = int64(1101)
memberUserID = int64(1102)
familyID = int64(119)
)
insertTestSubscribe(t, db, subID, true)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour))
insertTestDevice(t, db, memberUserID, "device-precreate-old", time.Now().Add(-96*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx)
_, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{
SubscribeId: subID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode())
}
func TestPreCreateOrder_NewUserOnly_OrdinaryFamilyMemberDoesNotAffectEligibility(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(12)
ownerUserID = int64(1201)
memberUserID = int64(1202)
familyID = int64(129)
)
insertTestSubscribe(t, db, subID, true)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour))
insertTestDevice(t, db, memberUserID, "device-precreate-ordinary", time.Now().Add(-96*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "manual_invite")
logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx)
resp, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{
SubscribeId: subID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, int64(900), resp.Amount)
assert.Equal(t, int64(100), resp.Discount)
}

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/perfect-panel/server/internal/logic/admin/group"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/logger"
@ -44,6 +45,7 @@ const (
OrderStatusPaid = 2 // Order paid and ready for processing
OrderStatusClose = 3 // Order closed/cancelled
OrderStatusFailed = 4 // Order processing failed
OrderStatusClaimed = 4 // Internal transient claim while a worker processes the order
OrderStatusFinished = 5 // Order successfully completed
)
@ -82,7 +84,7 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task)
logger.WithContext(ctx).Info("[ActivateOrderLogic] 正在验证订单",
logger.Field("order_no", payload.OrderNo))
orderInfo, err := l.validateAndGetOrder(ctx, payload.OrderNo)
orderInfo, err := l.claimAndGetOrder(ctx, payload.OrderNo)
if err != nil {
// 如果订单不存在或状态不对,不重试
if errors.Is(err, ErrInvalidOrderStatus) {
@ -108,6 +110,7 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task)
logger.Field("user_id", orderInfo.UserId))
if err = l.processOrderByType(ctx, orderInfo, payload.IAPExpireAt); err != nil {
l.releaseClaim(ctx, orderInfo.OrderNo)
logger.WithContext(ctx).Error("[ActivateOrderLogic] 处理订单失败,将重试",
logger.Field("order_no", orderInfo.OrderNo),
logger.Field("order_type", orderInfo.Type),
@ -137,10 +140,11 @@ func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (
return &p, nil
}
// validateAndGetOrder retrieves an order by order number and validates its status
// claimAndGetOrder retrieves an order by order number and atomically claims paid orders.
// Returns error if order is not found or not in paid status
func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) {
orderInfo, err := l.svc.OrderModel.FindOneByOrderNo(ctx, orderNo)
func (l *ActivateOrderLogic) claimAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) {
var orderInfo order.Order
err := l.svc.DB.WithContext(ctx).Model(&order.Order{}).Where("order_no = ?", orderNo).First(&orderInfo).Error
if err != nil {
logger.WithContext(ctx).Error("Find order failed",
logger.Field("error", err.Error()),
@ -165,7 +169,33 @@ func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo st
return nil, ErrInvalidOrderStatus
}
return orderInfo, nil
result := l.svc.DB.WithContext(ctx).
Model(&order.Order{}).
Where("order_no = ? AND status = ?", orderNo, OrderStatusPaid).
Update("status", OrderStatusClaimed)
if result.Error != nil {
return nil, result.Error
}
if result.RowsAffected == 0 {
logger.WithContext(ctx).Info("Order already claimed by another worker, skip processing",
logger.Field("order_no", orderNo),
)
return nil, nil
}
orderInfo.Status = OrderStatusClaimed
return &orderInfo, nil
}
func (l *ActivateOrderLogic) releaseClaim(ctx context.Context, orderNo string) {
if err := l.svc.DB.WithContext(ctx).
Model(&order.Order{}).
Where("order_no = ? AND status = ?", orderNo, OrderStatusClaimed).
Update("status", OrderStatusPaid).Error; err != nil {
logger.WithContext(ctx).Error("Release order claim failed",
logger.Field("error", err.Error()),
logger.Field("order_no", orderNo),
)
}
}
// processOrderByType routes order processing based on the order type
@ -274,20 +304,24 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
)
}
// 如果没有合并已购订阅再尝试合并赠送订阅order_id=0
if userSub == nil {
giftSub, giftErr := l.findGiftSubscription(ctx, singleModeUserId, orderInfo.SubscribeId)
if giftErr == nil && giftSub != nil {
// 在赠送订阅上延长时间,保持 token 不变
userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub)
if err != nil {
logger.WithContext(ctx).Error("Extend gift subscription failed",
logger.Field("error", err.Error()),
logger.Field("gift_subscribe_id", giftSub.Id),
)
// 合并失败时回退到创建新订阅
userSub = nil
}
}
// 如果没有合并已购订阅再尝试合并赠送订阅order_id=0
// 全局单订阅口径下,非 SingleModel 也不能让试用订阅和付费订阅并存。
if userSub == nil {
effectiveOwner := orderInfo.UserId
if orderInfo.SubscriptionUserId > 0 {
effectiveOwner = orderInfo.SubscriptionUserId
}
giftSub, giftErr := l.findGiftSubscription(ctx, effectiveOwner, orderInfo.SubscribeId)
if giftErr == nil && giftSub != nil {
userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub)
if err != nil {
logger.WithContext(ctx).Error("Extend gift subscription failed",
logger.Field("error", err.Error()),
logger.Field("gift_subscribe_id", giftSub.Id),
)
userSub = nil
}
}
}
@ -302,8 +336,10 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
}
var existingSub user.Subscribe
if findErr := l.svc.DB.Model(&user.Subscribe{}).
Where("user_id IN ? AND subscribe_id = ?", candidateUserIds, orderInfo.SubscribeId).
Where("user_id IN ? AND token != ''", candidateUserIds).
Order("expire_time DESC").
Order("updated_at DESC").
Order("id DESC").
First(&existingSub).Error; findErr == nil {
// 家庭组场景:订阅 owner 可能变更(如成员注册的试用 → 被家主收归),
// 续期前把 user_id 校正为当前订单的 SubscriptionUserId
@ -514,7 +550,7 @@ func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderIn
// Check quota limit before creating subscription (final safeguard)
if sub.Quota > 0 {
var count int64
if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ? AND subscribe_id = ?", orderInfo.UserId, orderInfo.SubscribeId).Count(&count).Error; err != nil {
if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ?", subscriptionUserId).Count(&count).Error; err != nil {
logger.WithContext(ctx).Error("Count user subscribe failed", logger.Field("error", err.Error()))
return nil, err
}
@ -602,7 +638,7 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use
if !l.shouldProcessCommission(userInfo, orderInfo.IsNew) {
// 普通用户路径(佣金比例=0只有首单才双方赠N天
if orderInfo.IsNew {
l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo.OrderNo)
l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo)
}
return
}
@ -692,16 +728,18 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use
// 有佣金路径:邀请人拿佣金,被邀请用户(首单)拿天数
if orderInfo.IsNew {
_ = l.grantGiftDays(ctx, userInfo, int(l.svc.Config.Invite.GiftDays), orderInfo.OrderNo, "邀请赠送")
giftTarget := l.resolveGiftTargetUser(ctx, userInfo, orderInfo.SubscriptionUserId)
_ = l.grantGiftDays(ctx, giftTarget, int(l.svc.Config.Invite.GiftDays), orderInfo.OrderNo, "邀请赠送")
}
}
func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderNo string) {
func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderInfo *order.Order) {
giftDays := l.svc.Config.Invite.GiftDays
if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 {
if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 || orderInfo == nil {
return
}
_ = l.grantGiftDays(ctx, referee, int(giftDays), orderNo, "邀请赠送")
refereeTarget := l.resolveGiftTargetUser(ctx, referee, orderInfo.SubscriptionUserId)
_ = l.grantGiftDays(ctx, refereeTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送")
if referee.RefererId == 0 {
return
}
@ -709,7 +747,32 @@ func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, ref
if err != nil || referer == nil {
return
}
_ = l.grantGiftDays(ctx, referer, int(giftDays), orderNo, "邀请赠送")
refererTarget := l.resolveGiftTargetUser(ctx, referer, 0)
_ = l.grantGiftDays(ctx, refererTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送")
}
func (l *ActivateOrderLogic) resolveGiftTargetUser(ctx context.Context, source *user.User, forcedOwnerID int64) *user.User {
if source == nil || source.Id == 0 {
return source
}
targetID := source.Id
if forcedOwnerID > 0 {
targetID = forcedOwnerID
} else if entitlement, err := commonLogic.ResolveEntitlementUser(ctx, l.svc.DB, source.Id); err == nil && entitlement != nil && entitlement.EffectiveUserID > 0 {
targetID = entitlement.EffectiveUserID
}
if targetID == source.Id {
return source
}
target, err := l.svc.UserModel.FindOne(ctx, targetID)
if err != nil || target == nil {
logger.WithContext(ctx).Error("Resolve gift target owner failed",
logger.Field("source_user_id", source.Id),
logger.Field("target_user_id", targetID),
)
return source
}
return target
}
func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, days int, orderNo string, remark string) error {
@ -736,7 +799,22 @@ func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, da
activeSubscribe, err := l.svc.UserModel.FindActiveSubscribe(ctx, u.Id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
giftLog := &log.Gift{
Type: log.GiftTypeIncrease,
OrderNo: orderNo,
SubscribeId: 0,
Amount: int64(days),
Balance: u.Balance,
Remark: remark + " skipped: no active subscription",
Timestamp: time.Now().UnixMilli(),
}
content, _ := giftLog.Marshal()
return l.svc.LogModel.Insert(ctx, &log.SystemLog{
Type: log.TypeGift.Uint8(),
Date: time.Now().Format("2006-01-02"),
ObjectID: u.Id,
Content: string(content),
})
}
return err
}

View File

@ -0,0 +1,436 @@
package orderLogic
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/perfect-panel/server/internal/config"
userLogic "github.com/perfect-panel/server/internal/logic/public/user"
modelLog "github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// 普通用户 + 首单 → 双方赠N天
func TestHandleCommission_GrantGiftDaysWhenCommissionDisabled_FirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-GIFT-001",
Type: OrderTypeSubscribe,
IsNew: true, // 首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 2)
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 2 {
t.Fatalf("expected 2 gift logs, got %d", giftCount)
}
}
// 普通用户 + 非首单 → 不赠送
func TestHandleCommission_NoGiftDaysWhenCommissionDisabled_NotFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-GIFT-002",
Type: OrderTypeSubscribe,
IsNew: false, // 非首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 到期时间不应延长
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0)
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs for non-first order, got %d", giftCount)
}
}
// 渠道 + 首单 → 被邀请人赠N天 + 邀请人获佣金
func TestHandleCommission_GiftDaysAndCommissionWhenChannelFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-COMM-001",
Type: OrderTypeSubscribe,
IsNew: true, // 首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 被邀请人(首单)应获得赠送天数
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2)
// 邀请人应获得佣金
var refererAfter user.User
if err := db.First(&refererAfter, referer.Id).Error; err != nil {
t.Fatalf("query referer failed: %v", err)
}
if refererAfter.Commission != 10 {
t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission)
}
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 1 {
t.Fatalf("expected 1 gift log for referee on first order with commission, got %d", giftCount)
}
}
// 渠道 + 非首单 → 只给邀请人佣金,不赠天
func TestHandleCommission_OnlyCommissionWhenChannelNotFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-COMM-002",
Type: OrderTypeSubscribe,
IsNew: false, // 非首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 被邀请人不应获得赠送天数
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
// 邀请人应获得佣金
var refererAfter user.User
if err := db.First(&refererAfter, referer.Id).Error; err != nil {
t.Fatalf("query referer failed: %v", err)
}
if refererAfter.Commission != 10 {
t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission)
}
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs when channel non-first order, got %d", giftCount)
}
}
func TestHandleCommission_NoGiftDaysWhenNoInviteRelation(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
// 没有邀请人的独立用户
loneUser := seedUser(t, db, 0, false)
// RefererId == 0无邀请关系
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
loneSub := seedActiveSubscribe(t, db, loneUser.Id, baseExpire)
logic.handleCommission(context.Background(), loneUser, &order.Order{
OrderNo: "ORD-LONE-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 订阅到期时间不应该被延长
var subAfter user.Subscribe
if err := db.First(&subAfter, loneSub.Id).Error; err != nil {
t.Fatalf("query subscribe failed: %v", err)
}
if !subAfter.ExpireTime.Equal(baseExpire) {
t.Fatalf("expected no gift days for user without inviter, before=%v after=%v", baseExpire, subAfter.ExpireTime)
}
// 不应产生赠天日志
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs for user without inviter, got %d", giftCount)
}
}
// 先绑码后首单 → 双方赠N天
func TestInviteFlow_BindThenFirstOrder_GrantGiftDays(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referer.ReferCode = fmt.Sprintf("REF-%d", referer.Id)
if err := db.Model(&user.User{}).Where("id = ?", referer.Id).Update("refer_code", referer.ReferCode).Error; err != nil {
t.Fatalf("update referer code failed: %v", err)
}
refereeBaseExpire := time.Now().Add(48 * time.Hour).Truncate(time.Second)
refererBaseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, refereeBaseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, refererBaseExpire)
ctx := context.WithValue(context.Background(), constant.CtxKeyUser, referee)
bindLogic := userLogic.NewBindInviteCodeLogic(ctx, logic.svc)
if err := bindLogic.BindInviteCode(&types.BindInviteCodeRequest{InviteCode: referer.ReferCode}); err != nil {
t.Fatalf("bind invite code failed: %v", err)
}
var refereeAfterBind user.User
if err := db.First(&refereeAfterBind, referee.Id).Error; err != nil {
t.Fatalf("query referee after bind failed: %v", err)
}
if refereeAfterBind.RefererId != referer.Id {
t.Fatalf("bind invite failed, expected referer_id=%d got=%d", referer.Id, refereeAfterBind.RefererId)
}
// 首单 IsNew=true → 双方赠N天
logic.handleCommission(context.Background(), &refereeAfterBind, &order.Order{
OrderNo: "ORD-FLOW-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, refereeBaseExpire, 2)
assertExpireIncreasedByDays(t, db, refererSub.Id, refererBaseExpire, 2)
}
// 先买订单后绑码再续费 → 不赠送IsNew=false
func TestInviteFlow_OrderThenBind_NoGiftDays(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
// 先前已有订单IsNew=false模拟先买订单后绑码的场景
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-FLOW-002",
Type: OrderTypeSubscribe,
IsNew: false, // 已有历史订单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0)
}
func setupInviteTestLogic(t *testing.T, inviteCfg config.InviteConfig) (*ActivateOrderLogic, *gorm.DB, func()) {
t.Helper()
mysqlAddr := getenvDefault("TEST_MYSQL_ADDR", "127.0.0.1:3306")
mysqlUser := getenvDefault("TEST_MYSQL_USER", "root")
mysqlPassword := getenvDefault("TEST_MYSQL_PASSWORD", "rootpassword")
adminDSN := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&parseTime=true&loc=Local&multiStatements=true", mysqlUser, mysqlPassword, mysqlAddr)
adminDB, err := gorm.Open(mysql.Open(adminDSN), &gorm.Config{})
if err != nil {
t.Fatalf("open mysql admin connection failed: %v", err)
}
dbName := fmt.Sprintf("ppanel_test_invite_%d", time.Now().UnixNano())
if err := adminDB.Exec(fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci", dbName)).Error; err != nil {
t.Fatalf("create test database failed: %v", err)
}
testDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", mysqlUser, mysqlPassword, mysqlAddr, dbName)
db, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{})
if err != nil {
t.Fatalf("open test database failed: %v", err)
}
if err := db.AutoMigrate(&user.User{}, &user.Device{}, &user.AuthMethods{}, &user.Subscribe{}, &modelLog.SystemLog{}); err != nil {
t.Fatalf("auto migrate failed: %v", err)
}
redisAddr := getenvDefault("TEST_REDIS_ADDR", "127.0.0.1:6379")
redisPassword := getenvDefault("TEST_REDIS_PASSWORD", "")
rdb := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPassword,
DB: 0,
})
if err := rdb.Ping(context.Background()).Err(); err != nil {
t.Fatalf("connect redis failed: %v", err)
}
_ = rdb.FlushDB(context.Background()).Err()
svcCtx := &svc.ServiceContext{
DB: db,
Redis: rdb,
UserModel: user.NewModel(db, rdb),
LogModel: modelLog.NewModel(db),
Config: config.Config{
Invite: inviteCfg,
},
}
return NewActivateOrderLogic(svcCtx), db, func() {
_ = rdb.Close()
sqlDB, _ := db.DB()
if sqlDB != nil {
_ = sqlDB.Close()
}
_ = adminDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)).Error
}
}
func seedUser(t *testing.T, db *gorm.DB, referralPercentage uint8, onlyFirstPurchase bool) *user.User {
t.Helper()
u := &user.User{
Password: "pwd",
Algo: "default",
ReferralPercentage: referralPercentage,
OnlyFirstPurchase: boolPtr(onlyFirstPurchase),
Enable: boolPtr(true),
IsAdmin: boolPtr(false),
EnableBalanceNotify: boolPtr(false),
EnableLoginNotify: boolPtr(false),
EnableSubscribeNotify: boolPtr(false),
EnableTradeNotify: boolPtr(false),
}
if err := db.Create(u).Error; err != nil {
t.Fatalf("seed user failed: %v", err)
}
return u
}
func seedActiveSubscribe(t *testing.T, db *gorm.DB, userID int64, expireAt time.Time) *user.Subscribe {
t.Helper()
sub := &user.Subscribe{
UserId: userID,
OrderId: 1,
SubscribeId: 1,
StartTime: time.Now().Add(-24 * time.Hour),
ExpireTime: expireAt,
Traffic: 1024,
Token: fmt.Sprintf("token-%d-%d", userID, time.Now().UnixNano()),
UUID: fmt.Sprintf("uuid-%d-%d", userID, time.Now().UnixNano()),
Status: 1,
}
if err := db.Create(sub).Error; err != nil {
t.Fatalf("seed subscribe failed: %v", err)
}
return sub
}
func assertExpireIncreasedByDays(t *testing.T, db *gorm.DB, subscribeID int64, before time.Time, days int) {
t.Helper()
var after user.Subscribe
if err := db.First(&after, subscribeID).Error; err != nil {
t.Fatalf("query subscribe failed: %v", err)
}
expected := before.Add(time.Duration(days) * 24 * time.Hour)
if !after.ExpireTime.Equal(expected) {
t.Fatalf("expire time mismatch, expected=%v got=%v", expected, after.ExpireTime)
}
}
func boolPtr(v bool) *bool {
return &v
}
func getenvDefault(key, fallback string) string {
v := os.Getenv(key)
if v == "" {
return fallback
}
return v
}

View File

@ -0,0 +1,199 @@
package orderLogic
import (
"context"
"testing"
"time"
modelOrder "github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupActivationEligibilityDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
sqls := []string{
`CREATE TABLE IF NOT EXISTS "user" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
password VARCHAR(100) NOT NULL DEFAULT '',
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_device" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_family" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner_user_id INTEGER NOT NULL DEFAULT 0,
status TINYINT NOT NULL DEFAULT 1,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_family_member" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
family_id INTEGER NOT NULL DEFAULT 0,
user_id INTEGER NOT NULL DEFAULT 0,
role TINYINT NOT NULL DEFAULT 0,
status TINYINT NOT NULL DEFAULT 0,
join_source VARCHAR(32) NOT NULL DEFAULT '',
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "order" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
type TINYINT NOT NULL DEFAULT 1,
status TINYINT NOT NULL DEFAULT 1,
subscribe_id INTEGER NOT NULL DEFAULT 0,
quantity INTEGER NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
}
for _, sql := range sqls {
require.NoError(t, db.Exec(sql).Error)
}
return db
}
func insertActivationUser(t *testing.T, db *gorm.DB, userID int64, createdAt time.Time) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user" (id, created_at, updated_at) VALUES (?, ?, datetime('now'))`,
userID,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error)
}
func insertActivationDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_device" (user_id, identifier, created_at, updated_at) VALUES (?, ?, ?, datetime('now'))`,
userID,
identifier,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error)
}
func insertActivationFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_family" (id, owner_user_id, status) VALUES (?, ?, 1)`,
familyID,
ownerUserID,
).Error)
}
func insertActivationFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_family_member" (family_id, user_id, role, status, join_source) VALUES (?, ?, ?, ?, ?)`,
familyID,
userID,
role,
status,
joinSource,
).Error)
}
func insertActivationOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "order" (user_id, order_no, type, status, subscribe_id, quantity, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, 1, datetime('now'), datetime('now'))`,
userID,
orderNo,
status,
subscribeID,
).Error)
}
func TestValidateNewUserOnlyEligibilityAtActivation_UsesEarliestBoundDeviceTime(t *testing.T) {
db := setupActivationEligibilityDB(t)
const (
ownerUserID = int64(1)
memberUserID = int64(2)
familyID = int64(10)
subscribeID = int64(100)
)
insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertActivationUser(t, db, memberUserID, time.Now().Add(-72*time.Hour))
insertActivationDevice(t, db, memberUserID, "activation-old-device", time.Now().Add(-72*time.Hour))
insertActivationFamily(t, db, familyID, ownerUserID)
insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
err := validateNewUserOnlyEligibilityAtActivation(
context.Background(),
db,
&modelOrder.Order{
UserId: ownerUserID,
OrderNo: "activation-check-old-device",
Type: OrderTypeSubscribe,
Quantity: 1,
SubscribeId: subscribeID,
},
&subscribe.Subscribe{
Id: subscribeID,
Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`,
},
)
require.Error(t, err)
require.Contains(t, err.Error(), "is not a new user")
}
func TestValidateNewUserOnlyEligibilityAtActivation_SharesHistoryAcrossBoundScope(t *testing.T) {
db := setupActivationEligibilityDB(t)
const (
ownerUserID = int64(11)
memberUserID = int64(12)
familyID = int64(20)
subscribeID = int64(200)
)
insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertActivationUser(t, db, memberUserID, time.Now().Add(-2*time.Hour))
insertActivationDevice(t, db, memberUserID, "activation-shared-device", time.Now().Add(-2*time.Hour))
insertActivationFamily(t, db, familyID, ownerUserID)
insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
insertActivationOrder(t, db, "previous-finished-order", memberUserID, subscribeID, OrderStatusFinished)
err := validateNewUserOnlyEligibilityAtActivation(
context.Background(),
db,
&modelOrder.Order{
UserId: ownerUserID,
OrderNo: "current-paid-order",
Type: OrderTypeSubscribe,
Quantity: 1,
SubscribeId: subscribeID,
},
&subscribe.Subscribe{
Id: subscribeID,
Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`,
},
)
require.Error(t, err)
require.Contains(t, err.Error(), "already activated")
}

View File

@ -43,7 +43,7 @@ func validateNewUserOnlyEligibilityAtActivation(
ctx,
db,
eligibility.ScopeUserIDs,
orderInfo.SubscribeId,
0,
[]int64{OrderStatusFinished},
orderInfo.OrderNo,
)
@ -51,7 +51,7 @@ func validateNewUserOnlyEligibilityAtActivation(
return fmt.Errorf("new user only: check history error: %w", err)
}
if historyCount >= 1 {
return fmt.Errorf("new user only: user %d already activated subscribe %d", orderInfo.UserId, orderInfo.SubscribeId)
return fmt.Errorf("new user only: user %d already activated an order", orderInfo.UserId)
}
return nil

View File

@ -0,0 +1,197 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strings"
_ "github.com/go-sql-driver/mysql"
)
func main() {
dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN")
flag.Parse()
if strings.TrimSpace(*dsn) == "" {
log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN")
}
db, err := sql.Open("mysql", *dsn)
if err != nil {
log.Fatal(err)
}
defer db.Close()
if err = db.Ping(); err != nil {
log.Fatal(err)
}
mustPrintRows(db, "db/info", `
SELECT NOW() AS db_now,
(SELECT COUNT(*) FROM user) AS users,
(SELECT COUNT(*) FROM user_subscribe) AS user_subscribes,
(SELECT COUNT(*) FROM `+"`order`"+`) AS orders`)
mustPrintRows(db, "bug1/confusable-email-trials", `
SELECT uam.user_id,
uam.auth_identifier,
us.id AS user_subscribe_id,
us.order_id,
us.status,
us.expire_time,
us.created_at
FROM user_auth_methods uam
JOIN user_subscribe us ON us.user_id = uam.user_id
WHERE uam.auth_type = 'email'
AND us.order_id = 0
AND (
uam.auth_identifier LIKE '%@gmaial.com'
OR uam.auth_identifier LIKE '%@gmial.com'
OR uam.auth_identifier LIKE '%@gamil.com'
OR uam.auth_identifier LIKE '%+%@%'
OR uam.auth_identifier REGEXP '^[^@]*\\.[^@]*@gmail\\.com$'
)
ORDER BY us.created_at DESC
LIMIT 50`)
mustPrintRows(db, "bug2-visible-duplicate-subscriptions", `
SELECT scoped.owner_user_id,
COUNT(*) AS visible_subscribe_count,
GROUP_CONCAT(scoped.user_subscribe_id ORDER BY scoped.expire_time DESC) AS user_subscribe_ids,
GROUP_CONCAT(scoped.subscribe_id ORDER BY scoped.expire_time DESC) AS subscribe_ids,
MAX(scoped.expire_time) AS max_expire_time
FROM (
SELECT us.id AS user_subscribe_id,
us.user_id,
COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id,
us.subscribe_id,
us.status,
us.expire_time,
us.finished_at
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE us.token <> ''
AND us.status IN (0,1,2,3,4)
AND (us.expire_time > NOW()
OR us.finished_at >= DATE_SUB(NOW(), INTERVAL 7 DAY)
OR us.expire_time = FROM_UNIXTIME(0))
) scoped
GROUP BY scoped.owner_user_id
HAVING COUNT(*) > 1
ORDER BY visible_subscribe_count DESC, owner_user_id
LIMIT 50`)
mustPrintRows(db, "bug2-order-subscription-owner-mismatch", `
SELECT us.id AS user_subscribe_id,
us.user_id AS subscribe_user_id,
o.id AS order_id,
o.order_no,
o.user_id AS order_user_id,
o.subscription_user_id,
us.status,
us.expire_time,
us.created_at AS subscribe_created_at,
o.created_at AS order_created_at
FROM user_subscribe us
JOIN `+"`order`"+` o ON o.id = us.order_id
WHERE us.user_id <> o.subscription_user_id
AND us.token <> ''
AND us.status IN (0,1,2,3,4)
ORDER BY us.updated_at DESC
LIMIT 50`)
mustPrintRows(db, "bug3-invite-first-orders-missing-gift-days", `
SELECT first_orders.user_id AS referee_id,
referee.referer_id,
first_orders.id AS order_id,
first_orders.order_no,
first_orders.amount,
first_orders.created_at,
referer.referral_percentage AS referer_referral_percentage,
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34
AND sl.object_id = first_orders.user_id
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referee_gift_logs,
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34
AND sl.object_id = referee.referer_id
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referer_gift_logs
FROM (
SELECT o.*
FROM `+"`order`"+` o
JOIN (
SELECT user_id, MIN(id) AS first_order_id
FROM `+"`order`"+`
WHERE type IN (1,2)
AND status IN (2,5)
AND amount > 0
GROUP BY user_id
) fo ON fo.first_order_id = o.id
) first_orders
JOIN user referee ON referee.id = first_orders.user_id AND referee.referer_id <> 0
JOIN user referer ON referer.id = referee.referer_id
WHERE (
referer.referral_percentage = 0
AND (
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
OR
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = referee.referer_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
)
)
OR (
referer.referral_percentage > 0
AND (SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
)
ORDER BY first_orders.created_at DESC
LIMIT 50`)
}
func mustPrintRows(db *sql.DB, title string, query string) {
fmt.Printf("\n== %s ==\n", title)
rows, err := db.Query(query)
if err != nil {
log.Fatalf("%s: %v", title, err)
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
log.Fatalf("%s columns: %v", title, err)
}
fmt.Println(strings.Join(cols, "\t"))
values := make([]sql.NullString, len(cols))
args := make([]any, len(cols))
for i := range values {
args[i] = &values[i]
}
count := 0
for rows.Next() {
if err := rows.Scan(args...); err != nil {
log.Fatalf("%s scan: %v", title, err)
}
out := make([]string, len(cols))
for i, value := range values {
if value.Valid {
out[i] = value.String
} else {
out[i] = "NULL"
}
}
fmt.Println(strings.Join(out, "\t"))
count++
}
if err := rows.Err(); err != nil {
log.Fatalf("%s rows: %v", title, err)
}
if count == 0 {
fmt.Println("(none)")
}
}

View File

@ -0,0 +1,204 @@
package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
)
type duplicateGroup struct {
OwnerUserID int64 `json:"owner_user_id"`
Count int64 `json:"count"`
}
type subscriptionRow struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
OrderID int64 `json:"order_id"`
SubscribeID int64 `json:"subscribe_id"`
ExpireTime time.Time `json:"expire_time"`
Traffic int64 `json:"traffic"`
Download int64 `json:"download"`
Upload int64 `json:"upload"`
ExpiredDownload int64 `json:"expired_download"`
ExpiredUpload int64 `json:"expired_upload"`
Status uint8 `json:"status"`
UpdatedAt time.Time `json:"updated_at"`
}
type mergePlan struct {
OwnerUserID int64 `json:"owner_user_id"`
Keep subscriptionRow `json:"keep"`
Merge []subscriptionRow `json:"merge"`
}
func main() {
dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN")
execute := flag.Bool("execute", false, "apply changes; default is dry-run")
flag.Parse()
if strings.TrimSpace(*dsn) == "" {
log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN")
}
db, err := sql.Open("mysql", *dsn)
if err != nil {
log.Fatal(err)
}
defer db.Close()
groups, err := findDuplicateGroups(db)
if err != nil {
log.Fatal(err)
}
plans := make([]mergePlan, 0, len(groups))
for _, group := range groups {
plan, err := buildPlan(db, group.OwnerUserID)
if err != nil {
log.Fatal(err)
}
if len(plan.Merge) > 0 {
plans = append(plans, plan)
}
}
enc := json.NewEncoder(os.Stdout)
enc.SetIndent("", " ")
if err := enc.Encode(plans); err != nil {
log.Fatal(err)
}
if !*execute {
fmt.Fprintf(os.Stderr, "dry-run only: %d duplicate owner groups found\n", len(plans))
return
}
for _, plan := range plans {
if err := applyPlan(db, plan); err != nil {
log.Fatal(err)
}
}
fmt.Fprintf(os.Stderr, "merged %d duplicate owner groups\n", len(plans))
}
func findDuplicateGroups(db *sql.DB) ([]duplicateGroup, error) {
rows, err := db.Query(`
SELECT owner_user_id, COUNT(1) AS cnt
FROM (
SELECT us.id,
COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE us.token <> ''
AND us.status IN (0, 1, 2, 3, 4)
) scoped
GROUP BY owner_user_id
HAVING COUNT(1) > 1
ORDER BY owner_user_id`)
if err != nil {
return nil, err
}
defer rows.Close()
var groups []duplicateGroup
for rows.Next() {
var g duplicateGroup
if err := rows.Scan(&g.OwnerUserID, &g.Count); err != nil {
return nil, err
}
groups = append(groups, g)
}
return groups, rows.Err()
}
func buildPlan(db *sql.DB, ownerUserID int64) (mergePlan, error) {
rows, err := db.Query(`
SELECT us.id, us.user_id, us.order_id, us.subscribe_id, us.expire_time, us.traffic,
us.download, us.upload, us.expired_download, us.expired_upload, us.status, us.updated_at
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE COALESCE(uf.owner_user_id, us.user_id) = ?
AND us.token <> ''
AND us.status IN (0, 1, 2, 3, 4)
ORDER BY us.expire_time DESC, us.updated_at DESC, us.id DESC`, ownerUserID)
if err != nil {
return mergePlan{}, err
}
defer rows.Close()
var all []subscriptionRow
for rows.Next() {
var r subscriptionRow
if err := rows.Scan(&r.ID, &r.UserID, &r.OrderID, &r.SubscribeID, &r.ExpireTime, &r.Traffic, &r.Download, &r.Upload, &r.ExpiredDownload, &r.ExpiredUpload, &r.Status, &r.UpdatedAt); err != nil {
return mergePlan{}, err
}
all = append(all, r)
}
if err := rows.Err(); err != nil {
return mergePlan{}, err
}
if len(all) == 0 {
return mergePlan{OwnerUserID: ownerUserID}, nil
}
keep := all[0]
for _, r := range all[1:] {
keep.Download += r.Download
keep.Upload += r.Upload
keep.ExpiredDownload += r.ExpiredDownload
keep.ExpiredUpload += r.ExpiredUpload
if r.Traffic > keep.Traffic {
keep.Traffic = r.Traffic
}
}
for _, r := range all {
if r.UpdatedAt.After(keep.UpdatedAt) {
keep.SubscribeID = r.SubscribeID
}
}
return mergePlan{OwnerUserID: ownerUserID, Keep: keep, Merge: all[1:]}, nil
}
func applyPlan(db *sql.DB, plan mergePlan) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err = tx.Exec(`
UPDATE user_subscribe
SET user_id = ?, subscribe_id = ?, traffic = ?, download = ?, upload = ?,
expired_download = ?, expired_upload = ?, status = 1, note = CONCAT(COALESCE(note, ''), ' [merged duplicate subscriptions]')
WHERE id = ?`,
plan.OwnerUserID, plan.Keep.SubscribeID, plan.Keep.Traffic, plan.Keep.Download, plan.Keep.Upload,
plan.Keep.ExpiredDownload, plan.Keep.ExpiredUpload, plan.Keep.ID); err != nil {
return err
}
for _, r := range plan.Merge {
if _, err = tx.Exec(`
UPDATE user_subscribe
SET status = 5, note = CONCAT(COALESCE(note, ''), ' [merged into subscription #', ?, ']')
WHERE id = ?`, plan.Keep.ID, r.ID); err != nil {
return err
}
}
return tx.Commit()
}

View File

@ -0,0 +1,787 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"os/exec"
"strings"
"time"
"github.com/hibiken/asynq"
"github.com/perfect-panel/server/internal/config"
authlogic "github.com/perfect-panel/server/internal/logic/auth"
modelLog "github.com/perfect-panel/server/internal/model/log"
modelOrder "github.com/perfect-panel/server/internal/model/order"
modelSubscribe "github.com/perfect-panel/server/internal/model/subscribe"
modelUser "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/pkg/conf"
"github.com/perfect-panel/server/pkg/orm"
"github.com/perfect-panel/server/pkg/uuidx"
orderLogic "github.com/perfect-panel/server/queue/logic/order"
queueTypes "github.com/perfect-panel/server/queue/types"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
const marker = "codex-replay-business-bugs"
func main() {
var (
configPath = flag.String("config", "etc/ppanel.yaml", "ppanel config path for test server DB/Redis")
dsn = flag.String("dsn", "", "optional MySQL DSN override: user:pass@tcp(host:3306)/db?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai")
writeDB = flag.Bool("write-db", false, "create isolated test rows and execute activation replay against the configured test DB")
force = flag.Bool("force", false, "allow -write-db even when the config name does not clearly look like test/dev/staging")
keep = flag.Bool("keep", false, "keep replay rows for manual inspection")
cleanupOnly = flag.Bool("cleanup-only", false, "delete leftover replay rows by marker and exit")
skipCodeTests = flag.Bool("skip-code-tests", false, "skip go test checks")
)
flag.Parse()
ctx := context.Background()
started := time.Now()
fmt.Println("== replay business bug tests ==")
fmt.Printf("marker: %s\n", marker)
if !*skipCodeTests {
must(runCodeTests())
}
cfg := loadConfig(*configPath, *dsn)
runEmailTrialAssertions(cfg)
if *cleanupOnly {
env := mustNewReplayEnv(ctx, cfg)
env.cleanupByMarker(ctx)
return
}
if !*writeDB {
fmt.Println("\nDB replay skipped. Add -write-db to create isolated rows in the TEST database and run activation flows.")
fmt.Println("Example:")
fmt.Printf(" go run scripts/replay_business_bugs.go -config %s -write-db\n", *configPath)
return
}
if looksLikeProduction(cfg) && !*force {
fatalf("refusing to write DB because config does not look like a test environment: db=%s host=%s; add -force only on the test server", cfg.MySQL.Dbname, cfg.Site.Host)
}
env := mustNewReplayEnv(ctx, cfg)
if !*keep {
defer env.cleanup(ctx)
}
must(env.replaySingleSubscription(ctx))
must(env.replayInviteRulesMatrix(ctx))
must(env.replayFamilyInviteGiftToOwner(ctx))
fmt.Printf("\nPASS all replay checks in %s\n", time.Since(started).Round(time.Millisecond))
if *keep {
fmt.Println("Replay rows kept for inspection. Delete rows with remark/name/order_no containing:", marker)
}
}
func runCodeTests() error {
fmt.Println("\n-- code-level tests --")
args := []string{"test",
"./internal/logic/auth",
"./internal/logic/common",
"./internal/logic/public/order",
"./queue/logic/order",
}
cmd := exec.Command("go", args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("go test failed: %w", err)
}
fmt.Println("PASS code-level tests")
return nil
}
func loadConfig(path, dsn string) config.Config {
var cfg config.Config
conf.MustLoad(path, &cfg)
if dsn != "" {
cfg.MySQL = parseDSN(dsn)
}
return cfg
}
func parseDSN(dsn string) orm.Config {
cfg := orm.ParseDSN(dsn)
if cfg == nil {
fatalf("invalid dsn")
}
return *cfg
}
func runEmailTrialAssertions(cfg config.Config) {
fmt.Println("\n-- bug1 email trial whitelist assertions --")
cfg.Register.EnableTrial = true
cfg.Register.EnableTrialEmailWhitelist = true
if cfg.Register.TrialEmailDomainWhitelist == "" {
cfg.Register.TrialEmailDomainWhitelist = "gmail.com,163.com"
}
cases := []struct {
email string
want bool
}{
{"1.2.3.4xxx@gmaial.com", false},
{"a.b.c@gmail.com", false},
{"user+tag@gmail.com", false},
{"user@fake.gmail.com", false},
{"normaluser@gmail.com", true},
}
for _, tc := range cases {
got := authlogic.ShouldGrantTrialForEmail(cfg.Register, tc.email)
if got != tc.want {
fatalf("email trial assertion failed: email=%s got=%v want=%v", tc.email, got, tc.want)
}
fmt.Printf("PASS %-32s grant=%v\n", tc.email, got)
}
}
type replayEnv struct {
db *gorm.DB
rds *redis.Client
cfg config.Config
svcCtx *svc.ServiceContext
ids struct {
users []int64
subscribes []int64
plans []int64
orders []int64
logs []int64
}
}
func mustNewReplayEnv(ctx context.Context, cfg config.Config) *replayEnv {
fmt.Println("\n-- connecting test DB/Redis --")
db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL})
must(err)
rds := redis.NewClient(&redis.Options{
Addr: cfg.Redis.Host,
Password: cfg.Redis.Pass,
DB: cfg.Redis.DB,
PoolSize: cfg.Redis.PoolSize,
MinIdleConns: cfg.Redis.MinIdleConns,
})
must(rds.Ping(ctx).Err())
svcCtx := &svc.ServiceContext{
DB: db,
Redis: rds,
Config: cfg,
UserModel: modelUser.NewModel(db, rds),
OrderModel: modelOrder.NewModel(db, rds),
SubscribeModel: modelSubscribe.NewModel(db, rds),
LogModel: modelLog.NewModel(db),
}
fmt.Printf("connected: mysql=%s/%s redis=%s\n", cfg.MySQL.Addr, cfg.MySQL.Dbname, cfg.Redis.Host)
return &replayEnv{db: db, rds: rds, cfg: cfg, svcCtx: svcCtx}
}
func (e *replayEnv) replaySingleSubscription(ctx context.Context) error {
fmt.Println("\n-- bug2 replay: paid purchase must reuse existing subscription --")
planA, planB, err := e.createPlans(ctx, "bug2")
if err != nil {
return err
}
owner, err := e.createUser(ctx, "bug2-owner", 0, 0)
if err != nil {
return err
}
existing, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, time.Now().Add(7*24*time.Hour))
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, owner.Id, owner.Id, planB.Id, true, "bug2")
if err != nil {
return err
}
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
var rows []modelUser.Subscribe
if err = e.db.WithContext(ctx).
Where("user_id = ? AND token <> '' AND status IN ?", owner.Id, []int{0, 1, 2, 3, 4}).
Order("id ASC").
Find(&rows).Error; err != nil {
return err
}
if len(rows) != 1 {
return fmt.Errorf("bug2 failed: expected one visible subscription, got %d", len(rows))
}
if rows[0].Id != existing.Id {
return fmt.Errorf("bug2 failed: expected original subscription id=%d to be reused, got id=%d", existing.Id, rows[0].Id)
}
if rows[0].SubscribeId != planB.Id || rows[0].OrderId != order.Id {
return fmt.Errorf("bug2 failed: reused subscription not updated, subscribe_id=%d order_id=%d", rows[0].SubscribeId, rows[0].OrderId)
}
fmt.Printf("PASS user=%d user_subscribe=%d plan %d -> %d order=%s\n", owner.Id, rows[0].Id, planA.Id, planB.Id, order.OrderNo)
return nil
}
func (e *replayEnv) replayInviteGiftDays(ctx context.Context) error {
fmt.Println("\n-- bug3 replay: commission=0 invite should grant gift days to both users --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
e.svcCtx.Config.Invite.GiftDays = giftDays
}
e.svcCtx.Config.Invite.ReferralPercentage = 0
e.svcCtx.Config.Invite.OnlyFirstPurchase = true
planA, _, err := e.createPlans(ctx, "bug3")
if err != nil {
return err
}
referer, err := e.createUser(ctx, "bug3-referer", 0, 0)
if err != nil {
return err
}
referee, err := e.createUser(ctx, "bug3-referee", referer.Id, 0)
if err != nil {
return err
}
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond)
refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
refereeSub, err := e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, true, "bug3")
if err != nil {
return err
}
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
if err = e.waitForGiftLogs(ctx, order.OrderNo, referer.Id, referee.Id); err != nil {
return err
}
var refererAfter, refereeAfter modelUser.Subscribe
if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&refereeAfter, refereeSub.Id).Error; err != nil {
return err
}
minRefererExpire := baseExpire.Add(time.Duration(giftDays) * 24 * time.Hour)
if refererAfter.ExpireTime.Before(minRefererExpire.Add(-time.Second)) {
return fmt.Errorf("bug3 failed: referer expire not increased by gift days, got=%s want>=%s", refererAfter.ExpireTime, minRefererExpire)
}
if !refereeAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("bug3 failed: referee expire did not increase, got=%s base=%s", refereeAfter.ExpireTime, baseExpire)
}
// Idempotency: repeat the same order task and make sure gift logs are still one per user.
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
var giftCount int64
if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), []int64{referer.Id, referee.Id}, "%"+order.OrderNo+"%").
Count(&giftCount).Error; err != nil {
return err
}
if giftCount != 2 {
return fmt.Errorf("bug3 failed: expected 2 gift logs after duplicate task, got %d", giftCount)
}
fmt.Printf("PASS referer=%d referee=%d order=%s gift_days=%d logs=%d\n", referer.Id, referee.Id, order.OrderNo, giftDays, giftCount)
return nil
}
func (e *replayEnv) replayInviteRulesMatrix(ctx context.Context) error {
fmt.Println("\n-- bug3 replay matrix: invite gift/commission rules --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
}
e.svcCtx.Config.Invite.GiftDays = giftDays
e.svcCtx.Config.Invite.OnlyFirstPurchase = false
planA, _, err := e.createPlans(ctx, "bug3-matrix")
if err != nil {
return err
}
cases := []struct {
name string
hasReferer bool
globalReferralPct int64
isNewOrder bool
wantGiftLogs int64
wantCommissionLogs int64
wantCommission int64
}{
{
name: "no invite relation first order no gift",
hasReferer: false,
isNewOrder: true,
wantGiftLogs: 0,
},
{
name: "ordinary invite commission 0 first order gifts both",
hasReferer: true,
isNewOrder: true,
wantGiftLogs: 2,
},
{
name: "ordinary invite commission 0 non-first order no gift",
hasReferer: true,
isNewOrder: false,
wantGiftLogs: 0,
},
{
name: "channel commission positive first order gifts referee only",
hasReferer: true,
globalReferralPct: 10,
isNewOrder: true,
wantGiftLogs: 1,
wantCommissionLogs: 1,
wantCommission: 59,
},
{
name: "channel commission positive non-first order commission only",
hasReferer: true,
globalReferralPct: 10,
isNewOrder: false,
wantGiftLogs: 0,
wantCommissionLogs: 1,
wantCommission: 59,
},
}
for idx, tc := range cases {
e.svcCtx.Config.Invite.ReferralPercentage = tc.globalReferralPct
scope := fmt.Sprintf("bug3-rule-%d", idx+1)
var referer *modelUser.User
if tc.hasReferer {
referer, err = e.createUser(ctx, scope+"-referer", 0, 0)
if err != nil {
return err
}
if _, err = e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil {
return err
}
}
var refererID int64
if referer != nil {
refererID = referer.Id
}
referee, err := e.createUser(ctx, scope+"-referee", refererID, 0)
if err != nil {
return err
}
if _, err = e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil {
return err
}
order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, tc.isNewOrder, scope)
if err != nil {
return err
}
if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil {
return fmt.Errorf("%s: %w", tc.name, err)
}
if err = e.waitForLogCounts(ctx, order.OrderNo, tc.wantGiftLogs, tc.wantCommissionLogs); err != nil {
return fmt.Errorf("%s: %w", tc.name, err)
}
giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), order.OrderNo)
if err != nil {
return err
}
commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), order.OrderNo)
if err != nil {
return err
}
if giftLogs != tc.wantGiftLogs {
return fmt.Errorf("%s: expected gift logs=%d got=%d", tc.name, tc.wantGiftLogs, giftLogs)
}
if commissionLogs != tc.wantCommissionLogs {
return fmt.Errorf("%s: expected commission logs=%d got=%d", tc.name, tc.wantCommissionLogs, commissionLogs)
}
if referer != nil && tc.wantCommission > 0 {
var after modelUser.User
if err = e.db.WithContext(ctx).First(&after, referer.Id).Error; err != nil {
return err
}
if after.Commission != tc.wantCommission {
return fmt.Errorf("%s: expected referer commission=%d got=%d", tc.name, tc.wantCommission, after.Commission)
}
}
fmt.Printf("PASS %-58s gifts=%d commission_logs=%d\n", tc.name, giftLogs, commissionLogs)
}
return nil
}
func (e *replayEnv) replayFamilyInviteGiftToOwner(ctx context.Context) error {
fmt.Println("\n-- bug3 family replay: member purchase gift days go to owner --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
}
e.svcCtx.Config.Invite.GiftDays = giftDays
e.svcCtx.Config.Invite.ReferralPercentage = 0
e.svcCtx.Config.Invite.OnlyFirstPurchase = false
planA, _, err := e.createPlans(ctx, "bug3-family")
if err != nil {
return err
}
referer, err := e.createUser(ctx, "bug3-family-referer", 0, 0)
if err != nil {
return err
}
owner, err := e.createUser(ctx, "bug3-family-owner", 0, 0)
if err != nil {
return err
}
member, err := e.createUser(ctx, "bug3-family-member", referer.Id, 0)
if err != nil {
return err
}
if err = e.createFamily(ctx, owner.Id, member.Id); err != nil {
return err
}
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond)
ownerSub, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
memberSub, err := e.createUserSubscribe(ctx, member.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, member.Id, owner.Id, planA.Id, true, "bug3-family")
if err != nil {
return err
}
if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil {
return err
}
if err = e.waitForLogCounts(ctx, order.OrderNo, 2, 0); err != nil {
return err
}
var ownerAfter, memberAfter, refererAfter modelUser.Subscribe
if err = e.db.WithContext(ctx).First(&ownerAfter, ownerSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&memberAfter, memberSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil {
return err
}
if !ownerAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("family gift failed: owner expire not increased")
}
if !refererAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("family gift failed: referer expire not increased")
}
if memberAfter.ExpireTime.After(baseExpire.Add(time.Second)) {
return fmt.Errorf("family gift failed: member subscription should not receive gift days")
}
var memberGiftLogs int64
if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id = ? AND content LIKE ?", modelLog.TypeGift.Uint8(), member.Id, "%"+order.OrderNo+"%").
Count(&memberGiftLogs).Error; err != nil {
return err
}
if memberGiftLogs != 0 {
return fmt.Errorf("family gift failed: expected no member gift logs, got %d", memberGiftLogs)
}
fmt.Printf("PASS family member purchase gift target owner owner=%d member=%d referer=%d gift_days=%d\n", owner.Id, member.Id, referer.Id, giftDays)
return nil
}
func (e *replayEnv) activateOrderTwice(ctx context.Context, orderNo string) error {
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: orderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err := worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
return worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload))
}
func (e *replayEnv) waitForLogCounts(ctx context.Context, orderNo string, wantGiftLogs, wantCommissionLogs int64) error {
deadline := time.Now().Add(8 * time.Second)
for {
giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), orderNo)
if err != nil {
return err
}
commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), orderNo)
if err != nil {
return err
}
if giftLogs >= wantGiftLogs && commissionLogs >= wantCommissionLogs {
if wantGiftLogs == 0 && wantCommissionLogs == 0 {
time.Sleep(500 * time.Millisecond)
}
return nil
}
if time.Now().After(deadline) {
return fmt.Errorf("timed out waiting for logs: order=%s gift=%d/%d commission=%d/%d", orderNo, giftLogs, wantGiftLogs, commissionLogs, wantCommissionLogs)
}
time.Sleep(100 * time.Millisecond)
}
}
func (e *replayEnv) countLogs(ctx context.Context, logType uint8, orderNo string) (int64, error) {
var count int64
err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND content LIKE ?", logType, "%"+orderNo+"%").
Count(&count).Error
return count, err
}
func (e *replayEnv) waitForGiftLogs(ctx context.Context, orderNo string, userIDs ...int64) error {
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
var count int64
if err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), userIDs, "%"+orderNo+"%").
Count(&count).Error; err != nil {
return err
}
if count == int64(len(userIDs)) {
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timed out waiting for gift logs for order=%s", orderNo)
}
func (e *replayEnv) createPlans(ctx context.Context, scope string) (*modelSubscribe.Subscribe, *modelSubscribe.Subscribe, error) {
a := &modelSubscribe.Subscribe{
Name: marker + "-" + scope + "-A",
Language: "en",
UnitPrice: 599,
UnitTime: "Month",
Traffic: 1024 * 1024 * 1024,
Inventory: -1,
Quota: 0,
NodeGroupIds: modelSubscribe.JSONInt64Slice{},
}
b := &modelSubscribe.Subscribe{
Name: marker + "-" + scope + "-B",
Language: "en",
UnitPrice: 699,
UnitTime: "Month",
Traffic: 2 * 1024 * 1024 * 1024,
Inventory: -1,
Quota: 0,
NodeGroupIds: modelSubscribe.JSONInt64Slice{},
}
if err := e.db.WithContext(ctx).Create(a).Error; err != nil {
return nil, nil, err
}
if err := e.db.WithContext(ctx).Create(b).Error; err != nil {
return nil, nil, err
}
e.ids.plans = append(e.ids.plans, a.Id, b.Id)
return a, b, nil
}
func (e *replayEnv) createUser(ctx context.Context, scope string, refererID int64, referralPercentage uint8) (*modelUser.User, error) {
onlyFirst := true
enable := true
isAdmin := false
u := &modelUser.User{
Password: marker,
Algo: "default",
Salt: "default",
RefererId: refererID,
ReferralPercentage: referralPercentage,
OnlyFirstPurchase: &onlyFirst,
Enable: &enable,
IsAdmin: &isAdmin,
EnableBalanceNotify: &enable,
EnableLoginNotify: &enable,
EnableSubscribeNotify: &enable,
EnableTradeNotify: &enable,
Remark: marker + "-" + scope,
}
if err := e.db.WithContext(ctx).Create(u).Error; err != nil {
return nil, err
}
u.ReferCode = uuidx.UserInviteCode(u.Id)
if err := e.db.WithContext(ctx).Model(&modelUser.User{}).Where("id = ?", u.Id).Update("refer_code", u.ReferCode).Error; err != nil {
return nil, err
}
e.ids.users = append(e.ids.users, u.Id)
return u, nil
}
func (e *replayEnv) createFamily(ctx context.Context, ownerID, memberID int64) error {
now := time.Now()
family := &modelUser.UserFamily{
OwnerUserId: ownerID,
MaxMembers: modelUser.DefaultFamilyMaxSize,
Status: modelUser.FamilyStatusActive,
}
if err := e.db.WithContext(ctx).Create(family).Error; err != nil {
return err
}
members := []modelUser.UserFamilyMember{
{
FamilyId: family.Id,
UserId: ownerID,
Role: modelUser.FamilyRoleOwner,
Status: modelUser.FamilyMemberActive,
JoinSource: marker,
JoinedAt: now,
},
{
FamilyId: family.Id,
UserId: memberID,
Role: modelUser.FamilyRoleMember,
Status: modelUser.FamilyMemberActive,
JoinSource: marker,
JoinedAt: now,
},
}
return e.db.WithContext(ctx).Create(&members).Error
}
func (e *replayEnv) createUserSubscribe(ctx context.Context, userID, orderID, planID int64, expire time.Time) (*modelUser.Subscribe, error) {
groupLocked := false
sub := &modelUser.Subscribe{
UserId: userID,
OrderId: orderID,
SubscribeId: planID,
GroupLocked: &groupLocked,
StartTime: time.Now().Add(-time.Hour),
ExpireTime: expire,
Traffic: 1024 * 1024 * 1024,
Token: marker + "-" + uuidx.NewUUID().String(),
UUID: uuidx.NewUUID().String(),
Status: 1,
Note: marker,
}
if err := e.db.WithContext(ctx).Create(sub).Error; err != nil {
return nil, err
}
e.ids.subscribes = append(e.ids.subscribes, sub.Id)
return sub, nil
}
func (e *replayEnv) createPaidOrder(ctx context.Context, userID, subscriptionUserID, planID int64, isNew bool, scope string) (*modelOrder.Order, error) {
orderNo := fmt.Sprintf("%s-%s-%d", marker, scope, time.Now().UnixNano())
order := &modelOrder.Order{
UserId: userID,
SubscriptionUserId: subscriptionUserID,
OrderNo: orderNo,
Type: 1,
Quantity: 1,
Price: 599,
Amount: 599,
Status: 2,
SubscribeId: planID,
Method: "replay",
IsNew: isNew,
}
if err := e.db.WithContext(ctx).Create(order).Error; err != nil {
return nil, err
}
e.ids.orders = append(e.ids.orders, order.Id)
return order, nil
}
func (e *replayEnv) cleanup(ctx context.Context) {
fmt.Println("\n-- cleanup replay rows --")
e.cleanupByMarker(ctx)
if len(e.ids.subscribes) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.subscribes).Delete(&modelUser.Subscribe{}).Error
}
if len(e.ids.orders) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.orders).Delete(&modelOrder.Order{}).Error
}
if len(e.ids.plans) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.plans).Delete(&modelSubscribe.Subscribe{}).Error
}
if len(e.ids.users) > 0 {
_ = e.db.WithContext(ctx).Unscoped().Where("id IN ?", e.ids.users).Delete(&modelUser.User{}).Error
}
fmt.Println("cleanup done")
}
func (e *replayEnv) cleanupByMarker(ctx context.Context) {
_ = e.db.WithContext(ctx).
Where("join_source = ?", marker).
Delete(&modelUser.UserFamilyMember{}).Error
_ = e.db.WithContext(ctx).
Where("owner_user_id IN (SELECT id FROM `user` WHERE remark LIKE ?)", marker+"%").
Delete(&modelUser.UserFamily{}).Error
_ = e.db.WithContext(ctx).
Where("type IN (33, 34) AND content LIKE ?", "%"+marker+"%").
Delete(&modelLog.SystemLog{}).Error
_ = e.db.WithContext(ctx).
Where("order_no LIKE ?", marker+"%").
Delete(&modelOrder.Order{}).Error
_ = e.db.WithContext(ctx).
Where("note = ? OR token LIKE ?", marker, marker+"%").
Delete(&modelUser.Subscribe{}).Error
_ = e.db.WithContext(ctx).
Where("name LIKE ?", marker+"%").
Delete(&modelSubscribe.Subscribe{}).Error
_ = e.db.WithContext(ctx).Unscoped().
Where("remark LIKE ?", marker+"%").
Delete(&modelUser.User{}).Error
}
func looksLikeProduction(cfg config.Config) bool {
joined := strings.ToLower(strings.Join([]string{cfg.MySQL.Dbname, cfg.Site.Host, cfg.Host}, " "))
if strings.Contains(joined, "prod") || strings.Contains(joined, "production") {
return true
}
if cfg.Debug {
return false
}
if strings.Contains(joined, "test") || strings.Contains(joined, "dev") || strings.Contains(joined, "staging") {
return false
}
return true
}
func must(err error) {
if err != nil {
fatalf("%v", err)
}
}
func fatalf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "FAIL: "+format+"\n", args...)
os.Exit(1)
}