Compare commits

...

3 Commits

Author SHA1 Message Date
3b3ed7b3c1 test(auth): add HTTP device no-trial check
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m10s
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-29 23:00:18 -07:00
b52e01eaa2 fix(auth): grant trial only on email bind
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m17s
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-29 22:36:17 -07:00
32e3dc3c73 fix(order): cover invite gifts and inactive renewals
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m36s
2026-04-29 21:52:28 -07:00
5 changed files with 716 additions and 136 deletions

View File

@ -12,7 +12,6 @@ import (
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/jwt"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/uuidx"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
@ -136,8 +135,6 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ
}
}
l.tryGrantTrialForDeviceLogin(userInfo, req.Identifier)
// Generate session id
sessionId := uuidx.NewUUID().String()
@ -294,108 +291,3 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
return userInfo, nil
}
func (l *DeviceLoginLogic) tryGrantTrialForDeviceLogin(userInfo *user.User, identifier string) {
if userInfo == nil || userInfo.Id == 0 {
return
}
if !IsTrialConfigReady(l.svcCtx.Config.Register) {
l.Debugw("skip device trial grant because trial config is not ready",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("enable_trial", l.svcCtx.Config.Register.EnableTrial),
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
logger.Field("trial_time", l.svcCtx.Config.Register.TrialTime),
logger.Field("trial_time_unit", l.svcCtx.Config.Register.TrialTimeUnit),
)
return
}
if userInfo.CreatedAt.IsZero() || time.Since(userInfo.CreatedAt) > 24*time.Hour {
l.Debugw("skip device trial grant because user is outside trial backfill window",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("user_created_at", userInfo.CreatedAt),
)
return
}
var count int64
if err := l.svcCtx.DB.WithContext(l.ctx).
Model(&user.Subscribe{}).
Where("user_id = ?", userInfo.Id).
Count(&count).Error; err != nil {
l.Errorw("failed to query existing subscriptions before device trial grant",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("error", err.Error()),
)
return
}
if count > 0 {
l.Debugw("skip device trial grant because user already has subscriptions",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("subscription_count", count),
)
return
}
trialSubscribe, err := l.activeTrial(userInfo.Id)
if err != nil {
l.Errorw("failed to activate trial subscription for device login",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
logger.Field("error", err.Error()),
)
return
}
if clearErr := l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); clearErr != nil {
l.Errorw("ClearSubscribeCache failed",
logger.Field("error", clearErr.Error()),
logger.Field("userSubscribeId", trialSubscribe.Id),
)
}
if clearErr := l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); clearErr != nil {
l.Errorw("ClearSubscribeCache failed",
logger.Field("error", clearErr.Error()),
logger.Field("subscribeId", trialSubscribe.SubscribeId),
)
}
if clearErr := l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); clearErr != nil {
l.Errorf("ClearServerAllCache error: %v", clearErr.Error())
}
l.Infow("device trial subscription granted",
logger.Field("user_id", userInfo.Id),
logger.Field("identifier", identifier),
logger.Field("user_subscribe_id", trialSubscribe.Id),
logger.Field("trial_subscribe_id", trialSubscribe.SubscribeId),
logger.Field("expire_time", trialSubscribe.ExpireTime),
)
}
func (l *DeviceLoginLogic) activeTrial(uid int64) (*user.Subscribe, error) {
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
if err != nil {
return nil, err
}
startTime := time.Now()
userSub := &user.Subscribe{
UserId: uid,
OrderId: 0,
SubscribeId: sub.Id,
StartTime: startTime,
ExpireTime: tool.AddTime(l.svcCtx.Config.Register.TrialTimeUnit, l.svcCtx.Config.Register.TrialTime, startTime),
Traffic: sub.Traffic,
Download: 0,
Upload: 0,
Token: uuidx.NewUUID().String(),
UUID: uuidx.NewUUID().String(),
Status: 1,
}
if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub); err != nil {
return nil, err
}
return userSub, nil
}

View File

@ -447,6 +447,13 @@ func orderMergeRemainingTimeStatus(status uint8) bool {
}
}
func subscriptionRenewalBaseTime(now time.Time, userSub *user.Subscribe) time.Time {
if userSub != nil && orderMergeRemainingTimeStatus(userSub.Status) && userSub.ExpireTime.After(now) {
return userSub.ExpireTime
}
return now
}
func pickSubscriptionIdentitySource(candidates []user.Subscribe) *user.Subscribe {
if len(candidates) == 0 {
return nil
@ -959,10 +966,7 @@ func (l *ActivateOrderLogic) findGiftSubscription(ctx context.Context, userId in
// 若购买套餐与赠送套餐不同,同步更新套餐 ID 和流量配额并重置已用量(套餐变更语义)。
func (l *ActivateOrderLogic) extendGiftSubscription(ctx context.Context, giftSub *user.Subscribe, orderInfo *order.Order, sub *subscribe.Subscribe) (*user.Subscribe, error) {
now := time.Now()
baseTime := giftSub.ExpireTime
if baseTime.Before(now) {
baseTime = now
}
baseTime := subscriptionRenewalBaseTime(now, giftSub)
newExpireTime := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
giftSub.OrderId = orderInfo.Id
@ -1417,10 +1421,7 @@ func (l *ActivateOrderLogic) getUserSubscription(ctx context.Context, token stri
// updateSubscriptionWithIAPExpire 用于 Apple IAP 续费:按累计加时语义更新到期时间。
func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order, iapExpireAt int64) error {
now := time.Now()
baseTime := userSub.ExpireTime
if baseTime.Before(now) {
baseTime = now
}
baseTime := subscriptionRenewalBaseTime(now, userSub)
newExpire := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
if iapExpireAt > 0 {
appleExpire := time.Unix(iapExpireAt, 0)
@ -1455,11 +1456,9 @@ func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context
// expiration time extension and traffic reset if configured
func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order) error {
now := time.Now()
if userSub.ExpireTime.Before(now) {
userSub.ExpireTime = now
}
today := time.Now().Day()
resetDay := userSub.ExpireTime.Day()
baseTime := subscriptionRenewalBaseTime(now, userSub)
today := now.Day()
resetDay := baseTime.Day()
// 套餐变更更新套餐ID和流量配额并重置已用流量
if userSub.SubscribeId != orderInfo.SubscribeId {
@ -1486,7 +1485,7 @@ func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, u
}
userSub.OrderId = orderInfo.Id
userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, userSub.ExpireTime)
userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
userSub.Status = 1
// 续费时重置过期流量字段
userSub.ExpiredDownload = 0

View File

@ -0,0 +1,215 @@
//go:build ignore
package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/perfect-panel/server/internal/config"
modelLog "github.com/perfect-panel/server/internal/model/log"
modelUser "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/pkg/conf"
"github.com/perfect-panel/server/pkg/orm"
"github.com/perfect-panel/server/pkg/tool"
"gorm.io/gorm"
)
func main() {
var (
configPath = flag.String("config", "etc/ppanel.yaml", "config file path")
dsn = flag.String("dsn", "", "optional MySQL DSN override")
baseURL = flag.String("base-url", "", "server base URL, for example http://154.12.35.103")
identifier = flag.String("identifier", "", "optional device identifier")
ip = flag.String("ip", "198.18.77.77", "X-Forwarded-For test IP")
userAgent = flag.String("user-agent", "CodexDeviceNoTrialHTTP/1.0", "device user agent")
cleanup = flag.Bool("cleanup", false, "delete test user/device/log rows after verification")
)
flag.Parse()
if strings.TrimSpace(*baseURL) == "" {
fail("base-url is required")
}
if *identifier == "" {
*identifier = fmt.Sprintf("codex-device-no-trial-%d", time.Now().UnixNano())
}
ctx := context.Background()
cfg := loadConfig(*configPath, *dsn)
db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL})
must(err)
fmt.Println("== HTTP device login no-trial test ==")
fmt.Printf("base_url=%s\n", strings.TrimRight(*baseURL, "/"))
fmt.Printf("mysql=%s/%s\n", cfg.MySQL.Addr, cfg.MySQL.Dbname)
fmt.Printf("identifier=%s ip=%s user_agent=%s\n", *identifier, *ip, *userAgent)
if err = ensureIdentifierUnused(ctx, db, *identifier); err != nil {
fail("%v", err)
}
status, body, err := postDeviceLogin(ctx, *baseURL, *identifier, *ip, *userAgent)
if err != nil {
fail("device login request failed: %v", err)
}
fmt.Printf("http_status=%d body=%s\n", status, truncate(body, 500))
if status < 200 || status >= 300 {
fail("device login returned non-2xx status: %d", status)
}
device, err := findDevice(ctx, db, *identifier)
if err != nil {
fail("created device not found in DB: %v", err)
}
fmt.Printf("device: id=%d sn=%s user_id=%d created_at=%s\n",
device.Id,
tool.DeviceIdToHash(device.Id),
device.UserId,
device.CreatedAt.Format(time.RFC3339),
)
var subs []modelUser.Subscribe
if err = db.WithContext(ctx).
Where("user_id = ?", device.UserId).
Order("id ASC").
Find(&subs).Error; err != nil {
fail("query user_subscribe failed: %v", err)
}
if len(subs) == 0 {
fmt.Printf("user_subscribe rows: 0\n")
}
for i := range subs {
sub := &subs[i]
fmt.Printf("subscribe: id=%d order_id=%d subscribe_id=%d status=%d start=%s expire=%s\n",
sub.Id,
sub.OrderId,
sub.SubscribeId,
sub.Status,
sub.StartTime.Format(time.RFC3339),
sub.ExpireTime.Format(time.RFC3339),
)
fail("FAIL: HTTP device login created subscription user_subscribe_id=%d user_id=%d", sub.Id, device.UserId)
}
fmt.Printf("PASS: HTTP device login created no subscription for user_id=%d\n", device.UserId)
if *cleanup {
if err = cleanupTestRows(ctx, db, device.UserId); err != nil {
fail("cleanup failed: %v", err)
}
fmt.Printf("cleanup: hard-deleted test rows for user_id=%d\n", device.UserId)
}
}
func postDeviceLogin(ctx context.Context, baseURL, identifier, ip, userAgent string) (int, string, error) {
payload := map[string]string{
"identifier": identifier,
"user_agent": userAgent,
}
data, err := json.Marshal(payload)
if err != nil {
return 0, "", err
}
url := strings.TrimRight(baseURL, "/") + "/v1/auth/login/device"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
if err != nil {
return 0, "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
req.Header.Set("X-Forwarded-For", ip)
req.Header.Set("X-Real-IP", ip)
client := &http.Client{Timeout: 20 * time.Second}
resp, err := client.Do(req)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, "", err
}
return resp.StatusCode, string(body), nil
}
func loadConfig(path, dsn string) config.Config {
var cfg config.Config
conf.MustLoad(path, &cfg)
if dsn != "" {
parsed := orm.ParseDSN(dsn)
if parsed == nil {
fail("invalid dsn")
}
cfg.MySQL = *parsed
}
return cfg
}
func ensureIdentifierUnused(ctx context.Context, db *gorm.DB, identifier string) error {
var count int64
if err := db.WithContext(ctx).
Model(&modelUser.Device{}).
Where("identifier = ?", identifier).
Count(&count).Error; err != nil {
return err
}
if count > 0 {
return fmt.Errorf("identifier already exists: %s", identifier)
}
return nil
}
func findDevice(ctx context.Context, db *gorm.DB, identifier string) (*modelUser.Device, error) {
var device modelUser.Device
err := db.WithContext(ctx).
Where("identifier = ?", identifier).
First(&device).Error
return &device, err
}
func cleanupTestRows(ctx context.Context, db *gorm.DB, userID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("object_id = ?", userID).Delete(&modelLog.SystemLog{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Subscribe{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.AuthMethods{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Device{}).Error; err != nil {
return err
}
return tx.Unscoped().Where("id = ?", userID).Delete(&modelUser.User{}).Error
})
}
func truncate(s string, n int) string {
s = strings.TrimSpace(s)
if len(s) <= n {
return s
}
return s[:n] + "...<truncated>"
}
func must(err error) {
if err != nil {
fail("%v", err)
}
}
func fail(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, format+"\n", args...)
os.Exit(1)
}

View File

@ -63,7 +63,7 @@ func main() {
*ip = fmt.Sprintf("198.18.%d.%d", now%200+1, now/200%200+1)
}
fmt.Println("== device trial registration test ==")
fmt.Println("== device registration no-trial test ==")
fmt.Printf("mysql: %s/%s\n", env.cfg.MySQL.Addr, env.cfg.MySQL.Dbname)
fmt.Printf("redis: %s db=%d\n", env.cfg.Redis.Host, env.cfg.Redis.DB)
fmt.Printf("device.enable=%v\n", env.svcCtx.Config.Device.Enable)
@ -111,11 +111,6 @@ func main() {
Find(&subs).Error; err != nil {
fail(fmt.Errorf("query user_subscribe failed: %w", err))
}
if len(subs) == 0 {
fail(fmt.Errorf("FAIL: no user_subscribe rows created for user_id=%d", device.UserId))
}
var trial *modelUser.Subscribe
for i := range subs {
sub := &subs[i]
fmt.Printf("subscribe: id=%d order_id=%d subscribe_id=%d status=%d start=%s expire=%s token_empty=%v\n",
@ -131,17 +126,11 @@ func main() {
sub.SubscribeId == env.svcCtx.Config.Register.TrialSubscribe &&
(sub.Status == 0 || sub.Status == 1) &&
sub.ExpireTime.After(time.Now()) {
trial = sub
fail(fmt.Errorf("FAIL: device registration unexpectedly granted trial user_subscribe_id=%d user_id=%d", sub.Id, device.UserId))
}
}
if trial == nil {
fail(fmt.Errorf("FAIL: trial subscription was not granted for user_id=%d", device.UserId))
}
fmt.Printf("PASS: trial granted user_subscribe_id=%d expire_time=%s\n",
trial.Id,
trial.ExpireTime.Format(time.RFC3339),
)
fmt.Printf("PASS: device registration created no active trial subscription for user_id=%d\n", device.UserId)
if *cleanup {
if err = cleanupTestRows(ctx, env.db, device.UserId); err != nil {

View File

@ -0,0 +1,485 @@
package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
)
const inviteGiftMarker = "codex-test-invite-gift-days"
type giftLog struct {
Type uint16 `json:"type"`
OrderNo string `json:"order_no"`
SubscribeId int64 `json:"subscribe_id"`
Amount int64 `json:"amount"`
Balance int64 `json:"balance"`
Remark string `json:"remark,omitempty"`
Timestamp int64 `json:"timestamp"`
}
type commissionLog struct {
Type uint16 `json:"type"`
Amount int64 `json:"amount"`
OrderNo string `json:"order_no"`
Timestamp int64 `json:"timestamp"`
}
type userSubscribe struct {
ID int64
UserID int64
ExpireTime time.Time
}
func main() {
var (
dsn = flag.String("dsn", "", "MySQL DSN, for example root:pass@tcp(host:3306)/ppanel?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai")
writeDB = flag.Bool("write-db", false, "create isolated rows, simulate invite gifts, and clean them up")
keep = flag.Bool("keep", false, "keep rows for manual inspection")
cleanupOnly = flag.Bool("cleanup-only", false, "delete leftover rows created by this script and exit")
giftDays = flag.Int("gift-days", 3, "days to add to both invite users")
commission = flag.Int64("commission-percent", 10, "commission percent for commission-path simulation")
)
flag.Parse()
if *dsn == "" {
exitf("-dsn is required")
}
ctx := context.Background()
db, err := sql.Open("mysql", *dsn)
mustNoErr(err)
defer db.Close()
db.SetMaxIdleConns(1)
db.SetMaxOpenConns(1)
mustNoErr(db.PingContext(ctx))
if *cleanupOnly {
mustNoErr(cleanup(ctx, db))
fmt.Println("cleanup done")
return
}
if !*writeDB {
fmt.Println("dry run only. Add -write-db to create isolated invite rows in the TEST database.")
return
}
if *giftDays <= 0 {
exitf("-gift-days must be positive")
}
mustNoErr(cleanup(ctx, db))
if !*keep {
defer func() {
if err := cleanup(context.Background(), db); err != nil {
fmt.Fprintf(os.Stderr, "cleanup failed: %v\n", err)
}
}()
}
planID := mustCreatePlan(ctx, db)
runSelfInviteScenario(ctx, db, planID, *giftDays)
runFamilyInviteScenario(ctx, db, planID, *giftDays)
runCommissionScenario(ctx, db, planID, *giftDays, *commission)
if *keep {
fmt.Println("rows kept; cleanup with -cleanup-only. inviteGiftMarker:", inviteGiftMarker)
}
}
func runSelfInviteScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int) {
refererID := mustCreateUser(ctx, db, "self-referer", 0)
refereeID := mustCreateUser(ctx, db, "self-referee", refererID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererSubID := mustCreateUserSubscribe(ctx, db, refererID, planID, baseExpire)
refereeSubID := mustCreateUserSubscribe(ctx, db, refereeID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-self-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererID, refereeID, 0, giftDays))
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererID, refereeID, 0, giftDays))
assertExpire(ctx, db, "referer", refererSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referee", refereeSubID, baseExpire, giftDays)
logs := mustGiftLogCount(ctx, db, orderNo)
if logs != 2 {
exitf("gift log count mismatch after duplicate simulation: got=%d want=2", logs)
}
fmt.Printf("PASS self invite: referer=%d referee=%d order=%s gift_days=%d logs=%d\n", refererID, refereeID, orderNo, giftDays, logs)
}
func runFamilyInviteScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int) {
refererOwnerID := mustCreateUser(ctx, db, "family-referer-owner", 0)
refererMemberID := mustCreateUser(ctx, db, "family-referer-member", 0)
refereeOwnerID := mustCreateUser(ctx, db, "family-referee-owner", 0)
refereeMemberID := mustCreateUser(ctx, db, "family-referee-member", refererMemberID)
mustCreateFamily(ctx, db, refererOwnerID, refererMemberID)
mustCreateFamily(ctx, db, refereeOwnerID, refereeMemberID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererOwnerSubID := mustCreateUserSubscribe(ctx, db, refererOwnerID, planID, baseExpire)
refereeOwnerSubID := mustCreateUserSubscribe(ctx, db, refereeOwnerID, planID, baseExpire)
refererMemberSubID := mustCreateUserSubscribe(ctx, db, refererMemberID, planID, baseExpire)
refereeMemberSubID := mustCreateUserSubscribe(ctx, db, refereeMemberID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-family-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererMemberID, refereeMemberID, refereeOwnerID, giftDays))
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererMemberID, refereeMemberID, refereeOwnerID, giftDays))
assertExpire(ctx, db, "referer owner", refererOwnerSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referee owner", refereeOwnerSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referer member", refererMemberSubID, baseExpire, 0)
assertExpire(ctx, db, "referee member", refereeMemberSubID, baseExpire, 0)
logs := mustGiftLogCount(ctx, db, orderNo)
if logs != 2 {
exitf("family gift log count mismatch after duplicate simulation: got=%d want=2", logs)
}
fmt.Printf("PASS family invite: referer_member=%d->owner=%d referee_member=%d->owner=%d order=%s gift_days=%d logs=%d\n",
refererMemberID, refererOwnerID, refereeMemberID, refereeOwnerID, orderNo, giftDays, logs)
}
func runCommissionScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int, commissionPercent int64) {
if commissionPercent <= 0 {
fmt.Println("SKIP commission invite: commission-percent <= 0")
return
}
const amount int64 = 599
refererID := mustCreateUser(ctx, db, "commission-referer", 0)
refereeID := mustCreateUser(ctx, db, "commission-referee", refererID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererSubID := mustCreateUserSubscribe(ctx, db, refererID, planID, baseExpire)
refereeSubID := mustCreateUserSubscribe(ctx, db, refereeID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-commission-first-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteCommission(ctx, db, orderNo, refererID, refereeID, 0, giftDays, amount, commissionPercent, true))
mustNoErr(simulateInviteCommission(ctx, db, orderNo, refererID, refereeID, 0, giftDays, amount, commissionPercent, true))
wantCommission := amount * commissionPercent / 100
assertExpire(ctx, db, "commission referer", refererSubID, baseExpire, 0)
assertExpire(ctx, db, "commission referee", refereeSubID, baseExpire, giftDays)
assertCommission(ctx, db, refererID, wantCommission)
assertLogCount(ctx, db, "commission first gift", 34, orderNo, 1)
assertLogCount(ctx, db, "commission first commission", 33, orderNo, 1)
nonFirstRefererID := mustCreateUser(ctx, db, "commission-nonfirst-referer", 0)
nonFirstRefereeID := mustCreateUser(ctx, db, "commission-nonfirst-referee", nonFirstRefererID)
nonFirstRefererSubID := mustCreateUserSubscribe(ctx, db, nonFirstRefererID, planID, baseExpire)
nonFirstRefereeSubID := mustCreateUserSubscribe(ctx, db, nonFirstRefereeID, planID, baseExpire)
nonFirstOrderNo := fmt.Sprintf("%s-commission-nonfirst-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteCommission(ctx, db, nonFirstOrderNo, nonFirstRefererID, nonFirstRefereeID, 0, giftDays, amount, commissionPercent, false))
mustNoErr(simulateInviteCommission(ctx, db, nonFirstOrderNo, nonFirstRefererID, nonFirstRefereeID, 0, giftDays, amount, commissionPercent, false))
assertExpire(ctx, db, "commission non-first referer", nonFirstRefererSubID, baseExpire, 0)
assertExpire(ctx, db, "commission non-first referee", nonFirstRefereeSubID, baseExpire, 0)
assertCommission(ctx, db, nonFirstRefererID, wantCommission)
assertLogCount(ctx, db, "commission non-first gift", 34, nonFirstOrderNo, 0)
assertLogCount(ctx, db, "commission non-first commission", 33, nonFirstOrderNo, 1)
fmt.Printf("PASS commission invite: percent=%d first_order_commission=%d non_first_commission=%d\n",
commissionPercent, wantCommission, wantCommission)
}
func assertExpire(ctx context.Context, db *sql.DB, label string, subID int64, before time.Time, addedDays int) {
got := mustExpire(ctx, db, subID)
want := before.Add(time.Duration(addedDays) * 24 * time.Hour)
if !got.Equal(want) {
exitf("%s expire mismatch: got=%s want=%s", label, got, want)
}
fmt.Printf("PASS %s subscribe=%d expire %s -> %s\n", label, subID, before.Format(time.RFC3339), got.Format(time.RFC3339))
}
func simulateInviteGiftBoth(ctx context.Context, db *sql.DB, orderNo string, refererID, refereeID, forcedRefereeOwnerID int64, days int) error {
refereeTargetID, err := resolveGiftTargetUser(ctx, db, refereeID, forcedRefereeOwnerID)
if err != nil {
return fmt.Errorf("resolve referee gift target: %w", err)
}
refererTargetID, err := resolveGiftTargetUser(ctx, db, refererID, 0)
if err != nil {
return fmt.Errorf("resolve referer gift target: %w", err)
}
if err := grantGiftDays(ctx, db, refereeTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant referee gift: %w", err)
}
if err := grantGiftDays(ctx, db, refererTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant referer gift: %w", err)
}
return nil
}
func simulateInviteCommission(ctx context.Context, db *sql.DB, orderNo string, refererID, refereeID, forcedRefereeOwnerID int64, days int, amount int64, commissionPercent int64, isFirstOrder bool) error {
if err := grantCommission(ctx, db, refererID, orderNo, amount, commissionPercent); err != nil {
return fmt.Errorf("grant commission: %w", err)
}
if isFirstOrder {
refereeTargetID, err := resolveGiftTargetUser(ctx, db, refereeID, forcedRefereeOwnerID)
if err != nil {
return fmt.Errorf("resolve referee gift target: %w", err)
}
if err := grantGiftDays(ctx, db, refereeTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant commission-path referee gift: %w", err)
}
}
return nil
}
func resolveGiftTargetUser(ctx context.Context, db *sql.DB, userID int64, forcedOwnerID int64) (int64, error) {
if forcedOwnerID > 0 {
return forcedOwnerID, nil
}
var ownerID int64
err := db.QueryRowContext(ctx, `
SELECT uf.owner_user_id
FROM user_family_member ufm
JOIN user_family uf ON uf.id = ufm.family_id AND uf.deleted_at IS NULL
WHERE ufm.user_id = ?
AND ufm.deleted_at IS NULL
AND ufm.status = 1
AND ufm.role = 2
AND uf.status = 1
ORDER BY ufm.role
LIMIT 1`, userID).Scan(&ownerID)
if err == sql.ErrNoRows {
return userID, nil
}
if err != nil {
return 0, err
}
if ownerID > 0 && ownerID != userID {
return ownerID, nil
}
return userID, nil
}
func grantCommission(ctx context.Context, db *sql.DB, refererID int64, orderNo string, amount int64, commissionPercent int64) error {
var existing int64
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM system_logs WHERE type = 33 AND object_id = ? AND content LIKE ?",
refererID, "%\""+orderNo+"\"%",
).Scan(&existing)
if err != nil {
return err
}
if existing > 0 {
return nil
}
commissionAmount := amount * commissionPercent / 100
if _, err = db.ExecContext(ctx,
"UPDATE `user` SET commission = commission + ?, updated_at = ? WHERE id = ?",
commissionAmount, time.Now(), refererID,
); err != nil {
return err
}
content, err := json.Marshal(commissionLog{
Type: 331,
Amount: commissionAmount,
OrderNo: orderNo,
Timestamp: time.Now().UnixMilli(),
})
if err != nil {
return err
}
_, err = db.ExecContext(ctx,
"INSERT INTO system_logs (`type`, object_id, content, created_at, `date`) VALUES (33, ?, ?, ?, ?)",
refererID, string(content), time.Now(), time.Now().Format("2006-01-02"),
)
return err
}
func grantGiftDays(ctx context.Context, db *sql.DB, userID int64, orderNo string, days int) error {
var existing int64
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM system_logs WHERE type = 34 AND object_id = ? AND content LIKE ?",
userID, "%\""+orderNo+"\"%",
).Scan(&existing)
if err != nil {
return err
}
if existing > 0 {
return nil
}
sub, err := findActiveSubscribe(ctx, db, userID)
if err != nil {
return err
}
nextExpire := sub.ExpireTime
if !sub.ExpireTime.Equal(time.UnixMilli(0)) {
nextExpire = sub.ExpireTime.Add(time.Duration(days) * 24 * time.Hour)
if _, err = db.ExecContext(ctx,
"UPDATE user_subscribe SET expire_time = ?, updated_at = ? WHERE id = ?",
nextExpire, time.Now(), sub.ID,
); err != nil {
return err
}
}
content, err := json.Marshal(giftLog{
Type: 341,
OrderNo: orderNo,
SubscribeId: sub.ID,
Amount: int64(days),
Balance: 0,
Remark: "邀请赠送",
Timestamp: time.Now().UnixMilli(),
})
if err != nil {
return err
}
_, err = db.ExecContext(ctx,
"INSERT INTO system_logs (`type`, object_id, content, created_at, `date`) VALUES (34, ?, ?, ?, ?)",
userID, string(content), time.Now(), time.Now().Format("2006-01-02"),
)
return err
}
func findActiveSubscribe(ctx context.Context, db *sql.DB, userID int64) (*userSubscribe, error) {
var row userSubscribe
err := db.QueryRowContext(ctx, `
SELECT id, user_id, expire_time
FROM user_subscribe
WHERE user_id = ?
AND status IN (0, 1)
AND (expire_time > ? OR expire_time = '1970-01-01 08:00:00')
ORDER BY expire_time DESC, id DESC
LIMIT 1`, userID, time.Now()).Scan(&row.ID, &row.UserID, &row.ExpireTime)
if err != nil {
return nil, err
}
return &row, nil
}
func mustCreatePlan(ctx context.Context, db *sql.DB) int64 {
var sort int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COALESCE(MAX(sort), 0) + 1 FROM subscribe").Scan(&sort))
res, err := db.ExecContext(ctx, `
INSERT INTO subscribe
(name, language, description, unit_price, unit_time, discount, replacement, inventory, traffic, speed_limit, device_limit, quota, new_user_only, nodes, node_tags, node_group_ids, node_group_id, traffic_limit, `+"`show`"+`, sell, sort, deduction_ratio, allow_deduction, reset_cycle, renewal_reset, show_original_price, created_at, updated_at)
VALUES (?, 'en', '', 599, 'Month', '', 0, -1, 1073741824, 0, 0, 0, false, '', '', '[]', 0, '', false, false, ?, 0, true, 0, false, true, ?, ?)`,
inviteGiftMarker+"-plan", sort, time.Now(), time.Now())
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
return id
}
func mustCreateUser(ctx context.Context, db *sql.DB, role string, refererID int64) int64 {
res, err := db.ExecContext(ctx, `
INSERT INTO `+"`user`"+`
(password, algo, avatar, balance, refer_code, referer_id, commission, referral_percentage, only_first_purchase, gift_amount, enable, is_admin, enable_balance_notify, enable_login_notify, enable_subscribe_notify, enable_trade_notify, rules, member_status, remark, created_at, updated_at, salt)
VALUES (?, 'default', '', 0, '', ?, 0, 0, true, 0, true, false, true, true, true, true, '', '', ?, ?, ?, 'default')`,
inviteGiftMarker, refererID, inviteGiftMarker+"-"+role, time.Now(), time.Now())
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
_, err = db.ExecContext(ctx, "UPDATE `user` SET refer_code = ?, updated_at = ? WHERE id = ?", fmt.Sprintf("codex%d", id), time.Now(), id)
mustNoErr(err)
return id
}
func mustCreateFamily(ctx context.Context, db *sql.DB, ownerID, memberID int64) int64 {
res, err := db.ExecContext(ctx, `
INSERT INTO user_family
(owner_user_id, max_members, status, created_at, updated_at)
VALUES (?, 3, 1, ?, ?)`, ownerID, time.Now(), time.Now())
mustNoErr(err)
familyID, err := res.LastInsertId()
mustNoErr(err)
now := time.Now()
_, err = db.ExecContext(ctx, `
INSERT INTO user_family_member
(family_id, user_id, role, status, join_source, joined_at, created_at, updated_at)
VALUES
(?, ?, 1, 1, ?, ?, ?, ?),
(?, ?, 2, 1, ?, ?, ?, ?)`,
familyID, ownerID, inviteGiftMarker, now, now, now,
familyID, memberID, inviteGiftMarker, now, now, now)
mustNoErr(err)
return familyID
}
func mustCreateUserSubscribe(ctx context.Context, db *sql.DB, userID, planID int64, expire time.Time) int64 {
token := fmt.Sprintf("%s-token-%d-%d", inviteGiftMarker, userID, time.Now().UnixNano())
uuid := fmt.Sprintf("%08d-0000-4000-8000-%012d", userID, time.Now().UnixNano()%1_000_000_000_000)
res, err := db.ExecContext(ctx, `
INSERT INTO user_subscribe
(user_id, order_id, subscribe_id, node_group_id, group_locked, traffic, download, upload, expired_download, expired_upload, token, uuid, status, note, created_at, updated_at, start_time, expire_time)
VALUES (?, 0, ?, 0, false, 1073741824, 0, 0, 0, 0, ?, ?, 1, ?, ?, ?, ?, ?)`,
userID, planID, token, uuid, inviteGiftMarker, time.Now(), time.Now(), time.Now().Add(-time.Hour), expire)
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
return id
}
func mustExpire(ctx context.Context, db *sql.DB, subID int64) time.Time {
var expire time.Time
mustNoErr(db.QueryRowContext(ctx, "SELECT expire_time FROM user_subscribe WHERE id = ?", subID).Scan(&expire))
return expire
}
func mustGiftLogCount(ctx context.Context, db *sql.DB, orderNo string) int64 {
var count int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_logs WHERE type = 34 AND content LIKE ?", "%"+orderNo+"%").Scan(&count))
return count
}
func assertCommission(ctx context.Context, db *sql.DB, userID int64, want int64) {
var got int64
mustNoErr(db.QueryRowContext(ctx, "SELECT commission FROM `user` WHERE id = ?", userID).Scan(&got))
if got != want {
exitf("commission mismatch: user=%d got=%d want=%d", userID, got, want)
}
fmt.Printf("PASS commission user=%d amount=%d\n", userID, got)
}
func assertLogCount(ctx context.Context, db *sql.DB, label string, logType uint8, orderNo string, want int64) {
var got int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_logs WHERE type = ? AND content LIKE ?", logType, "%"+orderNo+"%").Scan(&got))
if got != want {
exitf("%s log count mismatch: got=%d want=%d", label, got, want)
}
fmt.Printf("PASS %s logs=%d\n", label, got)
}
func cleanup(ctx context.Context, db *sql.DB) error {
stmts := []string{
"DELETE FROM user_family_member WHERE join_source = '" + inviteGiftMarker + "'",
"DELETE FROM user_family WHERE owner_user_id IN (SELECT id FROM `user` WHERE remark LIKE '" + inviteGiftMarker + "%')",
"DELETE FROM system_logs WHERE type IN (33, 34) AND content LIKE '%" + inviteGiftMarker + "%'",
"DELETE FROM user_subscribe WHERE note = '" + inviteGiftMarker + "' OR token LIKE '" + inviteGiftMarker + "%'",
"DELETE FROM subscribe WHERE name LIKE '" + inviteGiftMarker + "%'",
"DELETE FROM `user` WHERE remark LIKE '" + inviteGiftMarker + "%'",
}
for _, stmt := range stmts {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return fmt.Errorf("%s: %w", stmt, err)
}
}
return nil
}
func mustNoErr(err error) {
if err != nil {
exitf("%v", err)
}
}
func exitf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
fmt.Fprintln(os.Stderr, "FAIL:", strings.TrimSpace(msg))
os.Exit(1)
}