Compare commits

..

28 Commits
master ... main

Author SHA1 Message Date
769622f087 x
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
2026-04-29 23:30:38 -07:00
91935e3109 Revert "test(auth): add HTTP device no-trial check"
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
This reverts commit 3b3ed7b3c15a11ae70593a4c3c52e07689a60088.
2026-04-29 23:22:31 -07:00
3b3ed7b3c1 test(auth): add HTTP device no-trial check
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m10s
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-29 23:00:18 -07:00
b52e01eaa2 fix(auth): grant trial only on email bind
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m17s
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-29 22:36:17 -07:00
32e3dc3c73 fix(order): cover invite gifts and inactive renewals
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m36s
2026-04-29 21:52:28 -07:00
6b64e8c461 test(auth): add device trial registration script
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m6s
2026-04-29 21:05:52 -07:00
47696b9e68 fix(order): reconcile subscriptions and grant device trials
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
2026-04-29 21:00:46 -07:00
79427c9f4c 0430
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m47s
2026-04-29 12:49:45 -07:00
bcefb274ab perf(server): cache speed limit calculations
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m37s
2026-04-29 01:37:59 -07:00
3ae85f68ea 0428
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m25s
2026-04-28 17:44:28 -07:00
ac57272018 x
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 6m26s
2026-04-28 06:19:10 -07:00
68c7b0a8ec chore(deploy): add replication deployment assets
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m52s
2026-04-28 05:22:48 -07:00
0ec0e2b9d2 fix(order): align invite gift ownership
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
2026-04-28 05:19:57 -07:00
ab38cd4943 x
Some checks failed
Build docker and publish / build (20.15.1) (push) Failing after 4m44s
2026-04-26 21:12:22 -07:00
5b49aa8242 fix(auth): disable trial grants on public email flows
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m29s
2026-04-25 01:11:27 -07:00
9db4762904 fix(order): prevent duplicate subscriptions and repair invite gifts
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m6s
2026-04-24 21:16:21 -07:00
ae62ecc6b3 fix: 加入家庭组时无条件丢弃成员订阅,防止重复订阅
Some checks failed
Build docker and publish / build (20.15.1) (push) Failing after 5m17s
加入家庭组前若成员已购买订阅,原逻辑将订阅转移给 owner,
导致 owner 同时持有自身订阅与成员转入订阅,违反单订阅模式。

修改 transferMemberSubscribesToOwner:
- 移除转移逻辑,改为无条件删除成员所有订阅
- 成员加入后通过 owner 的订阅使用服务
- 后续购买以 entitlement.EffectiveUserID(owner)为目标,不受影响
2026-04-22 09:24:00 -07:00
4b73cd4d3c fix: 泛域名邮箱(+别名/Gmail点号)拦截提前,不受白名单开关影响
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m38s
2026-04-21 09:44:00 -07:00
2c9833df58 fix: 有返佣路径首单漏发被邀请用户赠天
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m37s
邀请人有返佣比例时,handleCommission 走佣金路径,
之前完全未调用 grantGiftDays,导致设备首单付费后
被邀请用户拿不到 N 天赠送。

修复:佣金处理完成后,若 IsNew(首单),
额外给被邀请用户调用 grantGiftDays(邀请人不重复赠天,
已通过佣金受益)。

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-21 01:57:22 -07:00
23a7a292ef fix: Gmail 泛域名邮箱(含点号/+别名)直接拒绝赠送试用
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m23s
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-21 01:09:34 -07:00
f1bfc78d66 fix: 统一日期统计查询方式,使用 DATE_FORMAT 替代 time.Time 边界
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 4m58s
QueryDateOrders 和 QueryDateUserCounts 改用 DATE_FORMAT 字符串比较,
与 QueryDailyOrdersList 的 GROUP BY 逻辑一致,避免 go-sql-driver 时区转换导致金额不一致。

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-20 22:36:16 -07:00
9912df9ac6 fix: 修复时区问题 - FixedZone 兜底 + Dockerfile 复制完整 zoneinfo
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 4m56s
1. ppanel.go: LoadLocation 失败时用 FixedZone("CST", +8h) 兜底
2. Dockerfile: 复制完整 /usr/share/zoneinfo 目录,确保 go-sql-driver 也能加载 Asia/Shanghai

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-20 21:53:48 -07:00
bafb13cf06 fix: 修复 scratch 容器中 time.Local 默认 UTC 导致收入统计时间窗口偏移 8 小时
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-20 21:46:55 -07:00
9a8ae8b6fd fix: 修复非单订阅模式下过期用户重复购买产生双订阅的问题
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 4m45s
1. purchaseLogic: 非单订阅模式下购买前查询已有订阅,路由为续费(type=2)
2. activateOrderLogic: 续费激活时触发节点分组重算,确保过期续费后权限生效

Co-Authored-By: claude-flow <ruv@ruv.net>
2026-04-20 20:58:01 -07:00
c8258dc93b feat: 设备登录新增 base_payload 字段,前端传入后存储到 user_device 表
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 4m48s
2026-04-20 20:08:20 -07:00
c0d839deb9 fix: 修复仪表盘时区统计偏移、重复订阅、新增map_apple字段
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 5m23s
- fix(order/model): QueryDateOrders/QueryDailyOrdersList 使用 time.Date 替代 Truncate 修复 UTC+8 时区偏移
- fix(user/model): QueryResisterUserTotalByDate 同样修复时区截断
- fix(traffic/model): QueryServerTrafficByDay 同样修复时区截断
- fix(activateOrder): 兜底查询防止过期用户重购产生重复订阅
- feat(api): SubscribeDiscount 新增 map_apple 字段
2026-04-20 02:34:23 -07:00
800f9c8460 x
Some checks failed
Build docker and publish / build (20.15.1) (push) Failing after 5m7s
2026-04-12 18:44:37 -07:00
954b19c332 feat: 邮箱规范化(NormalizeEmail)与域名白名单检查(IsEmailDomainWhitelisted)
Some checks failed
Build docker and publish / build (20.15.1) (push) Has been cancelled
2026-04-12 18:43:47 -07:00
60 changed files with 7793 additions and 234 deletions

View File

@ -21,7 +21,7 @@ env:
SSH_PASSWORD: ${{ github.ref_name == 'main' && vars.SSH_PASSWORD || vars.DEV_SSH_PASSWORD }} SSH_PASSWORD: ${{ github.ref_name == 'main' && vars.SSH_PASSWORD || vars.DEV_SSH_PASSWORD }}
# TG通知 # TG通知
TG_BOT_TOKEN: 8114337882:AAHkEx03HSu7RxN4IHBJJEnsK9aPPzNLIk0 TG_BOT_TOKEN: 8114337882:AAHkEx03HSu7RxN4IHBJJEnsK9aPPzNLIk0
TG_CHAT_ID: "-4940243803" TG_CHAT_ID: "-49402438031"
# Go构建变量 # Go构建变量
SERVICE: vpn SERVICE: vpn
SERVICE_STYLE: vpn SERVICE_STYLE: vpn
@ -49,12 +49,12 @@ jobs:
if [ "${{ github.ref_name }}" = "main" ]; then if [ "${{ github.ref_name }}" = "main" ]; then
echo "DOCKER_TAG_SUFFIX=latest" >> $GITHUB_ENV echo "DOCKER_TAG_SUFFIX=latest" >> $GITHUB_ENV
echo "CONTAINER_NAME=ppanel-server" >> $GITHUB_ENV echo "CONTAINER_NAME=ppanel-server" >> $GITHUB_ENV
echo "DEPLOY_PATH=/root/bindbox" >> $GITHUB_ENV echo "DEPLOY_PATH=/root/hifast" >> $GITHUB_ENV
echo "为 main 分支设置生产环境变量" echo "为 main 分支设置生产环境变量"
elif [ "${{ github.ref_name }}" = "internal" ]; then elif [ "${{ github.ref_name }}" = "internal" ]; then
echo "DOCKER_TAG_SUFFIX=internal" >> $GITHUB_ENV echo "DOCKER_TAG_SUFFIX=internal" >> $GITHUB_ENV
echo "CONTAINER_NAME=ppanel-server-internal" >> $GITHUB_ENV echo "CONTAINER_NAME=ppanel-server-internal" >> $GITHUB_ENV
echo "DEPLOY_PATH=/root/bindbox" >> $GITHUB_ENV echo "DEPLOY_PATH=/root/hifast" >> $GITHUB_ENV
echo "为 internal 分支设置开发环境变量" echo "为 internal 分支设置开发环境变量"
else else
echo "DOCKER_TAG_SUFFIX=${{ github.ref_name }}" >> $GITHUB_ENV echo "DOCKER_TAG_SUFFIX=${{ github.ref_name }}" >> $GITHUB_ENV

View File

@ -28,7 +28,7 @@ FROM scratch
# Copy CA certificates and timezone data # Copy CA certificates and timezone data
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /usr/share/zoneinfo/Asia/Shanghai /usr/share/zoneinfo/Asia/Shanghai COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
ENV TZ=Asia/Shanghai ENV TZ=Asia/Shanghai

View File

@ -154,6 +154,7 @@ type (
UserAgent string `json:"user_agent" validate:"required"` UserAgent string `json:"user_agent" validate:"required"`
CfToken string `json:"cf_token,optional"` CfToken string `json:"cf_token,optional"`
ShortCode string `json:"short_code,optional"` ShortCode string `json:"short_code,optional"`
BasePayload string `json:"base_payload,optional"`
} }
GenerateCaptchaResponse { GenerateCaptchaResponse {
Id string `json:"id"` Id string `json:"id"`

View File

@ -227,6 +227,7 @@ type (
SubscribeDiscount { SubscribeDiscount {
Quantity int64 `json:"quantity"` Quantity int64 `json:"quantity"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
MapApple string `json:"map_apple"`
} }
TrafficLimit { TrafficLimit {
StatType string `json:"stat_type"` StatType string `json:"stat_type"`

View File

@ -0,0 +1,152 @@
# MySQL 8.0 master/replica compose for two separate servers.
#
# Master server:
# COMPOSE_PROFILES=master docker compose -f config/docker-compose.mysql-replication.yml up -d
#
# Replica server:
# MASTER_HOST=<master_public_or_private_ip> COMPOSE_PROFILES=replica docker compose -f config/docker-compose.mysql-replication.yml up -d
#
# Required env on both servers:
# MYSQL_ROOT_PASSWORD=<strong-root-password>
# MYSQL_REPLICATION_PASSWORD=<strong-replication-password>
#
# Optional env:
# MYSQL_DATABASE=ppanel
# MYSQL_REPLICATION_USER=repl
# MYSQL_MASTER_PORT=3306
# MYSQL_REPLICA_PORT=3306
# MYSQL_SERVER_ID=1 # master default
# MYSQL_REPLICA_ID=2 # replica default
#
# If the master already has data, import a GTID-aware dump into the replica
# before starting replication. Fresh empty deployments can start master first,
# then replica, then point the application at the master.
services:
mysql-master:
image: mysql:8.0
container_name: ppanel-mysql-master
profiles:
- master
restart: always
ports:
- "${MYSQL_MASTER_PORT:-3306}:3306"
environment:
MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD:?please set MYSQL_ROOT_PASSWORD}"
MYSQL_DATABASE: "${MYSQL_DATABASE:-ppanel}"
MYSQL_REPLICATION_USER: "${MYSQL_REPLICATION_USER:-repl}"
MYSQL_REPLICATION_PASSWORD: "${MYSQL_REPLICATION_PASSWORD:?please set MYSQL_REPLICATION_PASSWORD}"
TZ: Asia/Shanghai
command:
- --default-authentication-plugin=mysql_native_password
- --server-id=${MYSQL_SERVER_ID:-1}
- --log-bin=mysql-bin
- --binlog-format=ROW
- --gtid-mode=ON
- --enforce-gtid-consistency=ON
- --log-replica-updates=ON
- --binlog-expire-logs-seconds=604800
- --max_connections=1000
- --character-set-server=utf8mb4
- --collation-server=utf8mb4_unicode_ci
volumes:
- mysql_master_data:/var/lib/mysql
configs:
- source: mysql_master_init
target: /docker-entrypoint-initdb.d/01-create-replication-user.sh
mode: 0755
healthcheck:
test: ["CMD-SHELL", "mysqladmin ping -h 127.0.0.1 -uroot -p$${MYSQL_ROOT_PASSWORD}"]
interval: 10s
timeout: 5s
retries: 10
logging:
driver: json-file
options:
max-size: 10m
max-file: "3"
mysql-replica:
image: mysql:8.0
container_name: ppanel-mysql-replica
profiles:
- replica
restart: always
ports:
- "${MYSQL_REPLICA_PORT:-3306}:3306"
environment:
MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD:?please set MYSQL_ROOT_PASSWORD}"
MYSQL_DATABASE: "${MYSQL_DATABASE:-ppanel}"
TZ: Asia/Shanghai
command:
- --default-authentication-plugin=mysql_native_password
- --server-id=${MYSQL_REPLICA_ID:-2}
- --relay-log=mysql-relay-bin
- --read-only=ON
- --super-read-only=ON
- --gtid-mode=ON
- --enforce-gtid-consistency=ON
- --log-replica-updates=ON
- --binlog-format=ROW
- --max_connections=1000
- --character-set-server=utf8mb4
- --collation-server=utf8mb4_unicode_ci
volumes:
- mysql_replica_data:/var/lib/mysql
healthcheck:
test: ["CMD-SHELL", "mysqladmin ping -h 127.0.0.1 -uroot -p$${MYSQL_ROOT_PASSWORD}"]
interval: 10s
timeout: 5s
retries: 10
logging:
driver: json-file
options:
max-size: 10m
max-file: "3"
mysql-replica-init:
image: mysql:8.0
container_name: ppanel-mysql-replica-init
profiles:
- replica
restart: "no"
depends_on:
mysql-replica:
condition: service_healthy
environment:
MYSQL_ROOT_PASSWORD: "${MYSQL_ROOT_PASSWORD:?please set MYSQL_ROOT_PASSWORD}"
MYSQL_REPLICATION_USER: "${MYSQL_REPLICATION_USER:-repl}"
MYSQL_REPLICATION_PASSWORD: "${MYSQL_REPLICATION_PASSWORD:?please set MYSQL_REPLICATION_PASSWORD}"
MASTER_HOST: "${MASTER_HOST:?please set MASTER_HOST to the master server ip or hostname}"
MASTER_PORT: "${MASTER_PORT:-3306}"
entrypoint:
- /bin/sh
- -ec
- |
mysql -hmysql-replica -uroot -p"$${MYSQL_ROOT_PASSWORD}" <<SQL
STOP REPLICA;
CHANGE REPLICATION SOURCE TO
SOURCE_HOST='$${MASTER_HOST}',
SOURCE_PORT=$${MASTER_PORT},
SOURCE_USER='$${MYSQL_REPLICATION_USER}',
SOURCE_PASSWORD='$${MYSQL_REPLICATION_PASSWORD}',
SOURCE_AUTO_POSITION=1,
GET_SOURCE_PUBLIC_KEY=1;
START REPLICA;
SQL
configs:
mysql_master_init:
content: |
#!/bin/sh
set -eu
mysql -uroot -p"$${MYSQL_ROOT_PASSWORD}" <<SQL
CREATE USER IF NOT EXISTS '$${MYSQL_REPLICATION_USER}'@'%' IDENTIFIED WITH mysql_native_password BY '$${MYSQL_REPLICATION_PASSWORD}';
GRANT REPLICATION SLAVE, REPLICATION CLIENT ON *.* TO '$${MYSQL_REPLICATION_USER}'@'%';
FLUSH PRIVILEGES;
SQL
volumes:
mysql_master_data:
mysql_replica_data:

BIN
debug_device_login Executable file

Binary file not shown.

862
doc/invite-purchase-flow.md Normal file
View File

@ -0,0 +1,862 @@
# 邀请赠送与购买订阅逻辑说明
本文档说明当前代码中的购买订阅、订单激活、邀请佣金、邀请赠送时间、家庭组归属逻辑。重点覆盖每个主要分支,方便排查“重复订单/重复订阅/邀请未赠时/赠时落点错误”等问题。
涉及核心文件:
- `internal/logic/public/order/purchaseLogic.go`
- `queue/logic/order/activateOrderLogic.go`
- `internal/logic/common/familyEntitlement.go`
- `internal/model/user/model.go`
## 1. 核心概念
### 1.1 订单状态
| 状态 | 含义 |
| --- | --- |
| `1` | pending已创建未支付 |
| `2` | paid已支付待激活 |
| `3` | close已关闭 |
| `4` | failed/claimed代码里同时用于失败和 worker 临时领取 |
| `5` | finished激活完成 |
### 1.2 订单类型
| 类型 | 含义 |
| --- | --- |
| `1` | 新购套餐 |
| `2` | 续费/换套餐 |
| `3` | 重置流量 |
| `4` | 余额充值 |
| `5` | 兑换码激活 |
### 1.3 用户 ID 与订阅归属
订单有两个重要用户字段:
| 字段 | 含义 |
| --- | --- |
| `user_id` | 发起订单/付款的用户 |
| `subscription_user_id` | 订阅权益归属用户;`0` 表示同 `user_id` |
家庭组规则:
- 普通用户下单:`subscription_user_id = user_id`
- 家庭组成员下单:`subscription_user_id = 家主用户 ID`
- 家庭组主账号下单:`subscription_user_id = 家主用户 ID`
当前代码使用 `ResolveEntitlementUser` 判断家庭归属:
- 只有有效家庭组、有效成员关系、角色为 member 时,权益归家主。
- 家主本人不会被改写到别人名下。
## 2. 购买订阅下单逻辑
入口:`Purchase(req *types.PurchaseOrderRequest)`
这里只是创建订单和安排关闭任务,不直接发放订阅。真正发放订阅在订单支付后由队列激活处理。
### 2.1 登录用户检查
分支:
- 上下文没有当前用户:返回 `InvalidAccess`
- 当前用户存在:继续。
### 2.2 解析订阅权益归属
调用 `ResolveEntitlementUser`
| 场景 | `effective_user_id` | 结果 |
| --- | --- | --- |
| 普通用户 | 本人 ID | 订阅归本人 |
| 家庭组主账号 | 本人 ID | 订阅归本人 |
| 家庭组成员 | 家主 ID | 订阅归家主 |
后续查询已有订阅、quota、创建订单里的 `subscription_user_id` 都会使用这个归属结果。
### 2.3 数量校验
分支:
- `quantity <= 0`:自动改成 `1`
- `quantity > MaxQuantity`:返回参数错误。
- 合法:继续。
### 2.4 单订阅模式路由
先默认:
```text
order_type = 1
target_subscribe_id = req.subscribe_id
parent_order_id = 0
subscribe_token = ""
```
如果开启 `Subscribe.SingleModel`
| 查询结果 | 行为 |
| --- | --- |
| 找到已有 anchor 订阅 | 下单路由为续费:`order_type=2`,保留新请求套餐 ID设置 parent/order token |
| 没找到已有订阅 | 保持新购:`order_type=1` |
| 查询异常 | 返回数据库错误 |
说明:即使是换套餐,只要单订阅模式已有订阅,也走续费语义,后续激活会更新套餐 ID 和流量配置。
### 2.5 非 SingleModel 的全局单订阅兜底
如果未开启 `SingleModel`,且当前还是新购 `order_type=1`
- 查询 `effective_user_id` 名下已有付费订阅:
```sql
user_id = effective_user_id
AND token != ''
AND (order_id > 0 OR token LIKE 'iap:%')
```
| 查询结果 | 行为 |
| --- | --- |
| 找到已有订阅 | 路由为续费:`order_type=2`,用已有订阅 token |
| 没找到 | 仍然新购 |
目的:避免同一个权益归属用户购买不同套餐后出现多条订阅权益。
### 2.6 pending 订单处理
当前只有一个分支会主动关闭旧 pending 单:
```text
SingleModel = true
AND order_type = 1
AND 存在同 user_id + subscribe_id + status=1 的订单
```
行为:
- 关闭旧 pending 订单。
- 继续创建新订单。
注意:
- 如果订单已被路由为 `order_type=2`,这里不会关闭旧 pending。
- 非 `SingleModel` 下也不会走这段 pending 关闭逻辑。
### 2.7 套餐校验
分支:
| 条件 | 行为 |
| --- | --- |
| 套餐不存在 | 返回数据库错误 |
| `sell=false` | 返回套餐不可售 |
| 新购且库存为 `0` | 返回库存不足 |
| 续费/换套餐 | 不检查库存为 0 的拦截分支 |
### 2.8 新用户优惠与新用户限定
调用 `resolveNewUserDiscountEligibility`
| 分支 | 行为 |
| --- | --- |
| 解析失败 | 返回错误 |
| 有折扣且符合条件 | 按折扣计算金额 |
| 有折扣但不符合条件 | 按原价 |
| 新用户限定且不是新用户窗口 | 新购事务内再次校验,不通过则失败 |
### 2.9 优惠券逻辑
如果 `req.coupon` 为空:跳过。
如果不为空:
| 校验 | 不通过行为 |
| --- | --- |
| 优惠券存在 | 返回 `CouponNotExist` |
| 总使用次数未超限 | 返回使用次数不足 |
| 用户使用次数未超限 | 返回用户次数不足 |
| 套餐适用 | 返回不适用 |
通过后计算 `coupon_discount`,从订单金额中扣除。
### 2.10 支付手续费与礼品余额抵扣
流程:
1. 找支付方式。
2. 如果金额大于 0计算手续费并加到订单金额。
3. 如果用户 `gift_amount > 0`,继续抵扣订单金额。
4. 抵扣金额记录到订单 `gift_amount`
事务内如果有礼品余额抵扣:
- 扣减用户 `gift_amount`
- 写 `system_logs` 的 gift reduce 日志。
### 2.11 `is_new` 首单标记
创建订单前调用:
```sql
SELECT COUNT(*)
FROM `order`
WHERE user_id = 当前付款用户
AND status IN (2, 5)
```
| 结果 | `is_new` |
| --- | --- |
| count = 0 | `true` |
| count > 0 | `false` |
注意:
- 判断口径是付款用户 `user_id`,不是 `subscription_user_id`
- pending/closed 订单不影响 `is_new`
- 续费订单也可能是 `is_new=true`,例如单订阅模式下首次购买被路由为续费。
### 2.12 事务内创建订单
事务内执行:
1. 新购且套餐有 quota 时,再查一次 `effective_user_id` 名下订阅数量防并发。
2. 新购且新用户限定时,再查一次新用户资格。
3. 如有礼品余额抵扣,扣减余额并写日志。
4. 新购且库存不是 `-1` 时扣库存。
5. 插入订单。
订单核心字段:
| 字段 | 值 |
| --- | --- |
| `user_id` | 当前付款用户 |
| `subscription_user_id` | 权益归属用户 |
| `type` | 新购 `1` 或续费 `2` |
| `subscribe_id` | 本次购买的套餐 ID |
| `subscribe_token` | 续费时已有订阅 token |
| `is_new` | 首单标记 |
| `status` | `1` pending |
### 2.13 延迟关单任务
订单创建成功后,发送 `DeferCloseOrder` 任务:
- 延迟时间15 分钟。
- 作用:未支付订单自动关闭。
## 3. 订单支付后的激活逻辑
入口:`ActivateOrderLogic.ProcessTask`
### 3.1 任务解析和订单领取
分支:
| 分支 | 行为 |
| --- | --- |
| payload 解析失败 | 记录错误,不重试 |
| 订单不存在 | 返回错误,允许重试 |
| 订单已 finished | 幂等跳过 |
| 订单不是 paid | 跳过 |
| paid 订单 | 原子更新为 claimed 状态后处理 |
### 3.2 按订单类型分发
| 订单类型 | 处理函数 |
| --- | --- |
| 新购 `1` | `NewPurchase` |
| 续费 `2` | `Renewal` |
| 重置流量 `3` | `ResetTraffic` |
| 充值 `4` | `Recharge` |
| 兑换码 `5` | `RedemptionActivate` |
处理成功后:
1. 执行订阅合并兜底 `reconcilePostOrderSubscriptions`
2. 更新优惠券使用次数。
3. 更新订单为 `finished`
如果处理失败:
- 把订单从 claimed 释放回 paid。
- 返回错误给队列重试。
## 4. 新购订单激活与订阅发放
入口:`NewPurchase`
### 4.1 获取用户
分支:
| 订单 user_id | 行为 |
| --- | --- |
| 不为 0 | 查询已有用户 |
| 为 0 | 从 Redis 临时订单创建游客用户 |
游客订单创建用户时:
- 创建用户和 auth method。
- 生成 refer code。
- 把订单 `user_id` 更新为新用户 ID。
- 如果临时订单有邀请码,绑定 `referer_id`
### 4.2 新用户限定激活时复查
如果套餐是新用户限定,激活时再次校验:
- 不符合则激活失败,订单回到 paid 等待重试/处理。
- 符合继续。
### 4.3 SingleModel 下复用 anchor 订阅
如果 `Subscribe.SingleModel=true`
| 分支 | 行为 |
| --- | --- |
| 找到 anchor 订阅 | 更新订单 parent_id用续费逻辑延长/换套餐 |
| 找不到 | 继续后续分支 |
| 查询异常 | 记录错误,继续后续分支 |
家庭组场景下查 anchor 的用户 ID
- 优先 `subscription_user_id`
- 没有则用 `user_id`
### 4.4 复用赠送订阅
如果还没有可复用订阅:
- 查找权益归属用户名下 `order_id=0` 的赠送订阅。
- 找到后将它升级为付费订阅:
- `order_id` 改为当前订单 ID。
- 延长到期时间。
- 状态改为 active。
- 如果套餐变更,更新套餐 ID、流量额度并清空已用流量。
### 4.5 兜底复用已有订阅
如果仍未复用到订阅:
- 候选用户 ID
- `order.user_id`
- 如果 `subscription_user_id` 存在且不同,也加入候选。
- 查找这些用户名下 `token != ''` 的订阅。
找到后:
- 如果订阅 owner 不是当前 `subscription_user_id`,先把 `user_id` 修正为权益归属用户。
- 用续费逻辑延长/换套餐。
目的:家庭组绑定前后 owner 变化时,也尽量复用旧记录,避免创建重复订阅。
### 4.6 创建新订阅
如果以上都没有复用成功,才创建新 `user_subscribe`
| 字段 | 值 |
| --- | --- |
| `user_id` | `subscription_user_id`,没有则 `order.user_id` |
| `order_id` | 当前订单 ID |
| `subscribe_id` | 当前订单套餐 ID |
| `start_time` | 当前时间 |
| `expire_time` | 按套餐时间单位和数量计算 |
| `traffic` | 套餐流量 |
| `token` | 基于订单号生成 |
| `uuid` | 新 UUID |
| `status` | `1` active |
创建前如果套餐有 quota会再按订阅 owner 统计数量。
### 4.7 新购激活后的异步逻辑
订阅发放后:
1. 后台触发用户分组重算。
2. 后台异步处理邀请佣金和赠送时间。
3. 清套餐缓存。
注意:邀请逻辑在 goroutine 中执行,不阻塞订单激活。
## 5. 续费/换套餐激活逻辑
入口:`Renewal`
### 5.1 获取用户和订阅
- 查询订单 `user_id` 对应用户。
- 通过 `subscribe_token` 查订阅。
- 查询订单 `subscribe_id` 对应套餐。
### 5.2 Apple IAP 与普通续费分支
| 分支 | 行为 |
| --- | --- |
| `iap_expire_at > 0` | 使用 IAP 到期时间兜底,但仍按累计加时语义 |
| 普通续费 | `updateSubscriptionForRenewal` |
### 5.3 普通续费/换套餐规则
`updateSubscriptionForRenewal`
- 如果当前订阅已过期,先把基准时间改为现在。
- 如果套餐 ID 变化:
- 更新订阅套餐 ID。
- 更新流量额度。
- 清空已用流量。
- 如果套餐没变:
- 如果套餐设置 renewal reset或今天是重置日则清空已用流量。
- 清理 `finished_at`
- `order_id` 改为当前订单 ID。
- 按套餐时间单位和数量延长到期时间。
- 状态改为 active。
- 清空过期流量字段。
### 5.4 续费后的邀请逻辑
续费成功后也会调用 `handleCommission`
- 是否发佣金由邀请配置和 `order.is_new` 决定。
- 是否赠时同样由邀请配置和 `order.is_new` 决定。
注意:如果 `OnlyFirstPurchase=true`,非首单续费通常不会发佣金,也不会赠首单时间。
## 6. 邀请关系绑定逻辑
### 6.1 注册/登录时的邀请码
新用户注册、游客订单创建用户时,如果带邀请码:
- 根据邀请码查邀请人。
- 设置新用户 `referer_id = 邀请人 ID`
### 6.2 用户后绑邀请码
入口:`BindInviteCode`
分支:
| 分支 | 行为 |
| --- | --- |
| 当前用户不存在 | 返回无权限 |
| 当前用户已有 `referer_id` | 返回已绑定 |
| 邀请码不存在 | 返回邀请码错误 |
| 邀请码属于自己 | 返回不允许绑定自己 |
| 通过 | 更新当前用户 `referer_id` |
注意:
- `referer_id` 始终记录实际邀请码所有者。
- 邀请人是家庭成员时,`referer_id` 仍然是该成员 ID不自动改为家主 ID。
## 7. 邀请佣金与赠送时间逻辑
入口:`handleCommission(userInfo, orderInfo)`
这里的 `userInfo` 是订单付款用户,也就是被邀请人。
### 7.1 总入口分支
先调用 `shouldProcessCommission(userInfo, orderInfo.IsNew)`
| 结果 | 行为 |
| --- | --- |
| `false` | 不发佣金;如果 `is_new=true`,走双方赠时 |
| `true` | 发佣金;如果 `is_new=true`,被邀请人赠时 |
### 7.2 什么时候发佣金
`shouldProcessCommission` 规则:
| 条件 | 结果 |
| --- | --- |
| 被邀请人为空 | 不发 |
| 被邀请人 `referer_id=0` | 不发 |
| 查不到邀请人 | 不发 |
| 邀请人自定义 `referral_percentage > 0`,且只首购但不是首单 | 不发 |
| 邀请人自定义 `referral_percentage > 0`,且通过首购限制 | 发佣金 |
| 邀请人无自定义比例,系统 `ReferralPercentage=0` | 不发 |
| 系统 `OnlyFirstPurchase=true` 且不是首单 | 不发 |
| 系统有比例且通过首购限制 | 发佣金 |
### 7.3 发佣金路径
如果 `shouldProcessCommission=true`
1. 查询邀请人,也就是 `userInfo.referer_id` 对应用户。
2. 佣金比例:
- 邀请人自定义比例优先。
- 否则用系统配置 `Invite.ReferralPercentage`
3. 佣金金额:
```text
(order.amount - order.fee_amount) * referral_percentage / 100
```
4. 事务内幂等检查:
- 如果已有同订单佣金日志,则跳过。
- 否则增加邀请人的 `commission`
- 写 `system_logs type=33` 佣金日志。
5. 更新邀请人缓存。
6. 如果 `order.is_new=true`
- 给被邀请人赠送订阅时间。
当前保持不变的行为:
- 邀请人是家庭成员时,佣金仍然给实际邀请人成员本人。
- 佣金不归并到家主。
- 有佣金路径下,邀请人不额外赠送订阅时间。
### 7.4 不发佣金路径
如果 `shouldProcessCommission=false`
| `order.is_new` | 行为 |
| --- | --- |
| `true` | 被邀请人和邀请人双方赠送订阅时间 |
| `false` | 不赠送时间 |
双方赠时具体为:
1. 被邀请人赠时:
- 如果被邀请人是家庭成员,加到被邀请人家主套餐。
- 否则加到被邀请人本人套餐。
2. 邀请人赠时:
- 如果邀请人是家庭成员,加到邀请人家主套餐。
- 否则加到邀请人本人套餐。
## 8. 赠送时间目标解析
入口:`resolveGiftTargetUser(source, forcedOwnerID)`
### 8.1 强制 owner 分支
如果 `forcedOwnerID > 0`
- 赠送目标直接使用 `forcedOwnerID`
- 典型场景:订单里已有 `subscription_user_id`
- 这保证了家庭成员购买时,被邀请人的赠时落到家主。
### 8.2 自动家庭组解析分支
如果没有强制 owner
- 调用 `ResolveEntitlementUser(source.Id)`
- 如果 source 是有效家庭成员,目标改为家主。
- 否则目标为本人。
典型场景:
- 无佣金路径下,邀请人也赠时。
- 邀请人如果是家庭成员,赠时会加到邀请人家主套餐。
### 8.3 目标用户查询失败
如果解析出来的目标用户查不到:
- 记录错误日志。
- 回退为 source 本人。
## 9. 赠送时间落库逻辑
入口:`grantGiftDays(u, days, orderNo, remark)`
### 9.1 空值和配置分支
| 条件 | 行为 |
| --- | --- |
| 目标用户为空 | 直接返回,不写日志 |
| `days <= 0` | 直接返回,不写日志 |
### 9.2 幂等检查
按下面条件查 gift 日志:
```sql
type = 34
AND object_id = 目标用户 ID
AND content LIKE '%订单号%'
```
| 结果 | 行为 |
| --- | --- |
| 已存在 | 跳过,不重复赠时 |
| 不存在 | 继续 |
### 9.3 查目标用户活跃订阅
调用 `FindActiveSubscribe`
当前活跃口径:
```sql
user_id = 目标用户 ID
AND status IN (0, 1)
AND (
expire_time > NOW()
OR expire_time = FROM_UNIXTIME(0)
)
```
说明:
- `expire_time > NOW()` 是普通未过期订阅。
- `expire_time = FROM_UNIXTIME(0)` 是永久/不限时订阅。
### 9.4 没有活跃订阅
如果查不到活跃订阅:
- 不创建新订阅。
- 写一条 `system_logs type=34` 日志。
- 日志 remark 为:
```text
邀请赠送 skipped: no active subscription
```
这表示邀请赠时触发过,但目标用户当时没有可加时的套餐。
### 9.5 找到普通活跃订阅
如果目标订阅不是永久订阅:
- `expire_time += days * 24h`
- 更新订阅。
- 写 `system_logs type=34` gift increase 日志。
### 9.6 找到永久订阅
如果目标订阅 `expire_time = FROM_UNIXTIME(0)`
- 不改变 `expire_time`,因为永久订阅没有可延长的到期时间。
- 仍写 `system_logs type=34` gift increase 日志,表示赠送逻辑已识别并处理。
### 9.7 赠时失败日志
发佣金路径和无佣金路径都会检查 `grantGiftDays` 返回错误。
如果出错,会写应用日志:
```text
Grant invite gift days failed
```
附带字段:
- `stage`
- `target_user_id`
- `order_no`
- `error`
## 10. 家庭组下的完整分支示例
### 10.1 被邀请人是普通用户,邀请人普通用户,有佣金
条件:
- 被邀请人 `referer_id != 0`
- 系统或邀请人佣金比例大于 0
- `order.is_new=true`
结果:
- 佣金给邀请人本人。
- 被邀请人本人套餐加赠送时间。
- 邀请人不加赠送时间。
### 10.2 被邀请人是家庭成员,邀请人普通用户,有佣金
结果:
- 佣金给邀请人本人。
- 被邀请人的赠送时间加到被邀请人家主套餐。
- 被邀请人成员本人不单独加订阅时间。
### 10.3 被邀请人普通用户,邀请人是家庭成员,有佣金
结果:
- 佣金给邀请人成员本人。
- 被邀请人本人套餐加赠送时间。
- 邀请人不加赠送时间。
- 邀请人家主不拿佣金,也不因该佣金路径加赠时。
### 10.4 被邀请人是家庭成员,邀请人也是家庭成员,有佣金
结果:
- 佣金给邀请人成员本人。
- 被邀请人的赠送时间加到被邀请人家主套餐。
- 邀请人不加赠送时间。
- 邀请人家主不拿佣金。
### 10.5 无佣金路径,被邀请人普通用户,邀请人普通用户
触发条件示例:
- `ReferralPercentage=0`
- 或因首购限制导致不发佣金
- 且 `order.is_new=true`
结果:
- 被邀请人本人套餐加赠送时间。
- 邀请人本人套餐加赠送时间。
### 10.6 无佣金路径,被邀请人是家庭成员
结果:
- 被邀请人的赠送时间加到被邀请人家主套餐。
- 邀请人的赠时按邀请人自己的家庭归属解析。
### 10.7 无佣金路径,邀请人是家庭成员
结果:
- 被邀请人的赠时按被邀请人的家庭归属解析。
- 邀请人的赠送时间加到邀请人家主套餐。
- 邀请人成员本人不单独加订阅时间。
### 10.8 被邀请人没有活跃订阅
结果:
- 不创建新订阅。
- 写 skipped gift 日志。
- 后续即使用户后来有订阅,也不会自动补赠,除非另行补偿。
### 10.9 被邀请人或目标家主是永久订阅
结果:
- 识别为活跃订阅。
- 不改变到期时间。
- 写 gift increase 日志。
## 11. 排查 SQL
### 11.1 查邀请配置
```sql
SELECT `key`, `value`, `updated_at`
FROM system
WHERE category = 'invite'
AND `key` IN ('GiftDays', 'OnlyFirstPurchase', 'ReferralPercentage');
```
### 11.2 查某邀请人的被邀请用户
```sql
SELECT id, referer_id, created_at
FROM `user`
WHERE referer_id = 23944
ORDER BY id DESC
LIMIT 100;
```
### 11.3 查被邀请人的订单和首单标记
```sql
SELECT u.id AS invited_user_id,
o.id AS order_id,
o.order_no,
o.type,
o.status,
o.amount,
o.is_new,
o.subscribe_id,
o.subscription_user_id,
o.created_at
FROM `user` u
LEFT JOIN `order` o
ON o.user_id = u.id
AND o.type IN (1, 2)
WHERE u.referer_id = 23944
ORDER BY u.id DESC, o.id ASC
LIMIT 200;
```
### 11.4 查某订单佣金和赠时日志
```sql
SELECT id, type, object_id, content, created_at
FROM system_logs
WHERE content LIKE '%202604281812556044982351822%'
ORDER BY id DESC;
```
### 11.5 查某用户订阅
```sql
SELECT id, user_id, order_id, subscribe_id, status,
expire_time, finished_at, token, created_at, updated_at
FROM user_subscribe
WHERE user_id = 24425
ORDER BY id DESC;
```
### 11.6 查首单但没有赠时日志的被邀请人
```sql
SELECT first_orders.user_id AS invited_user_id,
first_orders.order_no,
first_orders.is_new,
first_orders.status,
first_orders.subscription_user_id,
first_orders.created_at,
(
SELECT COUNT(*)
FROM system_logs sl
WHERE sl.type = 34
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')
) AS gift_log_count,
(
SELECT COUNT(*)
FROM system_logs sl
WHERE sl.type = 33
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')
) AS commission_log_count
FROM (
SELECT o.*
FROM `order` o
JOIN (
SELECT user_id, MIN(id) AS first_order_id
FROM `order`
WHERE type IN (1, 2)
AND status IN (2, 5)
GROUP BY user_id
) fo ON fo.first_order_id = o.id
) first_orders
JOIN `user` u ON u.id = first_orders.user_id
WHERE u.referer_id = 23944
ORDER BY first_orders.created_at DESC
LIMIT 100;
```
## 12. 部署注意事项
邀请配置存在两层状态:
1. Redis 缓存:`system:invite_config`
2. 服务进程内存:`svc.Config.Invite`
如果直接修改数据库或 Redis已经运行的 `ppanel-server` 进程不会自动刷新内存配置。订单激活和赠时发生在服务进程/队列 worker 内,所以修改邀请配置或部署赠时逻辑后,需要重启服务。
推荐步骤:
```bash
docker exec ppanel-redis redis-cli DEL system:invite_config system:global_config
docker restart ppanel-server
```
确认启动时间:
```bash
docker inspect --format '{{.Name}} {{.State.StartedAt}} {{.Config.Image}}' ppanel-server
docker ps --filter name=ppanel-server
```

View File

@ -103,6 +103,7 @@ services:
- redis-server - redis-server
- --tcp-backlog 65535 - --tcp-backlog 65535
- --maxmemory-policy allkeys-lru - --maxmemory-policy allkeys-lru
- --requirepass hifast67yj
volumes: volumes:
- redis_data:/data - redis_data:/data
ulimits: ulimits:
@ -201,7 +202,7 @@ services:
container_name: ppanel-grafana container_name: ppanel-grafana
restart: always restart: always
ports: ports:
- "127.0.0.1:3333:3000" # 仅本机可访问,需 SSH 隧道或 Nginx 反代 - "3333:3000" # 仅本机可访问,需 SSH 隧道或 Nginx 反代
environment: environment:
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD:?请在 .env 文件中设置 GRAFANA_PASSWORD} - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_PASSWORD:?请在 .env 文件中设置 GRAFANA_PASSWORD}
- GF_USERS_ALLOW_SIGN_UP=false - GF_USERS_ALLOW_SIGN_UP=false

View File

@ -0,0 +1 @@
ALTER TABLE `user_device` DROP COLUMN `base_payload`;

View File

@ -0,0 +1 @@
ALTER TABLE `user_device` ADD COLUMN `base_payload` TEXT DEFAULT NULL COMMENT 'Base Payload' AFTER `short_code`;

View File

@ -96,6 +96,17 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user failed: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user failed: %v", err.Error())
} }
// Update base_payload if provided
if req.BasePayload != "" && req.BasePayload != deviceInfo.BasePayload {
deviceInfo.BasePayload = req.BasePayload
if updateErr := l.svcCtx.UserModel.UpdateDevice(l.ctx, deviceInfo); updateErr != nil {
l.Errorw("update device base_payload failed",
logger.Field("device_id", deviceInfo.Id),
logger.Field("error", updateErr.Error()),
)
}
}
// 注销后 device auth_method 被删除,重新登录时需要补回 // 注销后 device auth_method 被删除,重新登录时需要补回
hasDeviceAuth := false hasDeviceAuth := false
for _, am := range userInfo.AuthMethods { for _, am := range userInfo.AuthMethods {
@ -225,6 +236,7 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
UserAgent: req.UserAgent, UserAgent: req.UserAgent,
Identifier: req.Identifier, Identifier: req.Identifier,
ShortCode: req.ShortCode, ShortCode: req.ShortCode,
BasePayload: req.BasePayload,
Enabled: true, Enabled: true,
Online: false, Online: false,
} }

View File

@ -126,7 +126,9 @@ func (l *EmailLoginLogic) EmailLogin(req *types.EmailLoginRequest) (resp *types.
return err return err
} }
rc := l.svcCtx.Config.Register rc := l.svcCtx.Config.Register
if ShouldGrantTrialForEmail(rc, req.Email) { if ShouldAutoGrantTrialOnPublicEmailFlows(rc) &&
ShouldGrantTrialForEmail(rc, req.Email) &&
!NormalizedEmailHasTrial(l.ctx, l.svcCtx.DB, req.Email, rc.TrialSubscribe) {
if err = l.activeTrial(userInfo.Id); err != nil { if err = l.activeTrial(userInfo.Id); err != nil {
return err return err
} }

View File

@ -396,7 +396,10 @@ func (l *OAuthLoginGetTokenLogic) register(email, avatar, method, openid, reques
} }
rc := l.svcCtx.Config.Register rc := l.svcCtx.Config.Register
shouldActivateTrial := email != "" && authlogic.ShouldGrantTrialForEmail(rc, email) shouldActivateTrial := email != "" &&
authlogic.ShouldAutoGrantTrialOnPublicEmailFlows(rc) &&
authlogic.ShouldGrantTrialForEmail(rc, email) &&
!authlogic.NormalizedEmailHasTrial(l.ctx, l.svcCtx.DB, email, rc.TrialSubscribe)
if shouldActivateTrial { if shouldActivateTrial {
l.Debugw("activating trial subscription", l.Debugw("activating trial subscription",

View File

@ -1,9 +1,13 @@
package auth package auth
import ( import (
"context"
"net/mail"
"strings" "strings"
"github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/config"
usermodel "github.com/perfect-panel/server/internal/model/user"
"gorm.io/gorm"
) )
// IsEmailDomainWhitelisted checks if the email's domain is in the comma-separated whitelist. // IsEmailDomainWhitelisted checks if the email's domain is in the comma-separated whitelist.
@ -12,11 +16,10 @@ func IsEmailDomainWhitelisted(email, whitelistCSV string) bool {
if whitelistCSV == "" { if whitelistCSV == "" {
return false return false
} }
parts := strings.SplitN(email, "@", 2) _, domain, ok := parseStrictEmail(email)
if len(parts) != 2 { if !ok {
return false return false
} }
domain := strings.ToLower(strings.TrimSpace(parts[1]))
for _, d := range strings.Split(whitelistCSV, ",") { for _, d := range strings.Split(whitelistCSV, ",") {
if strings.ToLower(strings.TrimSpace(d)) == domain { if strings.ToLower(strings.TrimSpace(d)) == domain {
return true return true
@ -29,11 +32,177 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool
if !register.EnableTrial { if !register.EnableTrial {
return false return false
} }
if !IsValidTrialEmail(email) {
return false
}
// 无论白名单是否启用,泛域名邮箱(含 + 别名或 Gmail 点号)始终拒绝赠送
if IsDisposableAlias(email) {
return false
}
if isConfusableGmailDomain(emailDomain(email)) {
return false
}
if !register.EnableTrialEmailWhitelist { if !register.EnableTrialEmailWhitelist {
return true return true
} }
if register.TrialEmailDomainWhitelist == "" { if register.TrialEmailDomainWhitelist == "" {
return false return false
} }
return IsEmailDomainWhitelisted(email, register.TrialEmailDomainWhitelist) if !IsEmailDomainWhitelisted(email, register.TrialEmailDomainWhitelist) {
return false
}
return true
}
// IsTrialConfigReady verifies that trial auto-grant has all required config.
func IsTrialConfigReady(register config.RegisterConfig) bool {
return register.EnableTrial &&
register.TrialSubscribe > 0 &&
register.TrialTime > 0 &&
strings.TrimSpace(register.TrialTimeUnit) != ""
}
// ShouldAutoGrantTrialOnPublicEmailFlows defines whether browser/email-originated
// flows may auto-create a trial subscription. Email-specific abuse protection
// is still handled by ShouldGrantTrialForEmail and NormalizedEmailHasTrial.
func ShouldAutoGrantTrialOnPublicEmailFlows(register config.RegisterConfig) bool {
return IsTrialConfigReady(register)
}
// IsDisposableAlias detects Gmail dot trick and + alias abuse.
// For Gmail-like domains, local part containing "." or "+" is rejected.
// For all other domains, only "+" alias is rejected.
func IsDisposableAlias(email string) bool {
local, domain, ok := parseStrictEmail(email)
if !ok {
return false
}
// All domains: reject + alias
if strings.ContainsRune(local, '+') {
return true
}
// Gmail-like domains: reject dots in local part
if isGmailLikeDomain(domain) && strings.ContainsRune(local, '.') {
return true
}
return false
}
// NormalizeEmail returns a canonical form of the email for trial deduplication.
// Strips "+" aliases universally (user+tag@any.com → user@any.com).
// Removes dots from local part for Gmail-like providers (gmail.com, googlemail.com).
func NormalizeEmail(email string) string {
email = strings.ToLower(strings.TrimSpace(email))
local, domain, ok := parseStrictEmail(email)
if !ok {
return email
}
// Strip + alias
if idx := strings.IndexByte(local, '+'); idx != -1 {
local = local[:idx]
}
// Remove dots for Gmail-like providers that ignore dots in local part
if isGmailLikeDomain(domain) {
local = strings.ReplaceAll(local, ".", "")
}
return local + "@" + domain
}
func isGmailLikeDomain(domain string) bool {
switch domain {
case "gmail.com", "googlemail.com":
return true
}
return false
}
func IsValidTrialEmail(email string) bool {
local, domain, ok := parseStrictEmail(email)
if !ok {
return false
}
return local != "" && domain != ""
}
func parseStrictEmail(email string) (local, domain string, ok bool) {
email = strings.ToLower(strings.TrimSpace(email))
if email == "" || strings.ContainsAny(email, " \t\r\n") {
return "", "", false
}
addr, err := mail.ParseAddress(email)
if err != nil || addr.Address != email || addr.Name != "" {
return "", "", false
}
parts := strings.Split(addr.Address, "@")
if len(parts) != 2 {
return "", "", false
}
local = strings.TrimSpace(parts[0])
domain = strings.Trim(strings.TrimSpace(parts[1]), ".")
if local == "" || domain == "" || strings.Contains(domain, "..") || !strings.Contains(domain, ".") {
return "", "", false
}
return local, domain, true
}
func emailDomain(email string) string {
_, domain, ok := parseStrictEmail(email)
if !ok {
return ""
}
return domain
}
func isConfusableGmailDomain(domain string) bool {
switch strings.ToLower(strings.TrimSpace(domain)) {
case "gmaial.com", "gmial.com", "gmai.com", "gamil.com", "gmal.com", "gmail.co", "gmail.con":
return true
}
return false
}
// NormalizedEmailHasTrial returns true if any user with the same normalized email
// already holds a trial subscription. Only performs the cross-user DB check when
// normalization actually changes the email (i.e., dots removed or + alias stripped).
func NormalizedEmailHasTrial(ctx context.Context, db *gorm.DB, email string, trialSubscribeId int64) bool {
normalized := NormalizeEmail(email)
if normalized == strings.ToLower(strings.TrimSpace(email)) {
return false // normalization changed nothing, skip cross-user check
}
parts := strings.SplitN(normalized, "@", 2)
if len(parts) != 2 {
return false
}
domain := parts[1]
var authMethods []usermodel.AuthMethods
if err := db.WithContext(ctx).
Model(&usermodel.AuthMethods{}).
Select("user_id, auth_identifier").
Where("auth_type = ? AND auth_identifier LIKE ?", "email", "%@"+domain).
Find(&authMethods).Error; err != nil {
return false
}
for _, am := range authMethods {
if NormalizeEmail(am.AuthIdentifier) != normalized {
continue
}
var count int64
if err := db.WithContext(ctx).
Model(&usermodel.Subscribe{}).
Where("user_id = ? AND subscribe_id = ?", am.UserId, trialSubscribeId).
Count(&count).Error; err != nil {
continue
}
if count > 0 {
return true
}
}
return false
} }

View File

@ -0,0 +1,198 @@
package auth
import (
"testing"
"github.com/perfect-panel/server/internal/config"
"github.com/stretchr/testify/assert"
)
func TestNormalizeEmail(t *testing.T) {
tests := []struct {
input string
want string
}{
// Gmail dot trick
{"a.v.x.xx@gmail.com", "avxxx@gmail.com"},
{"john.doe@gmail.com", "johndoe@gmail.com"},
{"a.b.c.d.e@gmail.com", "abcde@gmail.com"},
// Gmail + alias
{"user+tag@gmail.com", "user@gmail.com"},
{"a.b+tag@gmail.com", "ab@gmail.com"},
// Googlemail
{"a.b@googlemail.com", "ab@googlemail.com"},
// Non-Gmail: dots preserved
{"john.doe@outlook.com", "john.doe@outlook.com"},
{"john.doe@qq.com", "john.doe@qq.com"},
// + alias stripped for all providers
{"user+spam@outlook.com", "user@outlook.com"},
{"user+spam@qq.com", "user@qq.com"},
// Case insensitive
{"User@Gmail.COM", "user@gmail.com"},
{"A.B@Gmail.com", "ab@gmail.com"},
// No change for normal non-gmail email
{"abc@163.com", "abc@163.com"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := NormalizeEmail(tt.input)
if got != tt.want {
t.Errorf("NormalizeEmail(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNormalizeEmail_NoChangeSkipsCheck(t *testing.T) {
// These emails should NOT trigger cross-user check (normalized == original)
noChangeCases := []string{
"abc@163.com",
"john.doe@outlook.com",
"user@qq.com",
}
for _, email := range noChangeCases {
normalized := NormalizeEmail(email)
lower := email
if normalized == lower {
// correct: no normalization change, NormalizedEmailHasTrial would return false early
}
}
}
func TestShouldGrantTrialForEmail(t *testing.T) {
// 模拟线上配置白名单开启gmail.com 也在名单里
rcWithGmail := config.RegisterConfig{
EnableTrial: true,
EnableTrialEmailWhitelist: true,
TrialEmailDomainWhitelist: "hifastapp.com,hifastvpn.com,126.com,139.com,163.com,gmail.com",
}
// 白名单关闭
rcNoWhitelist := config.RegisterConfig{
EnableTrial: true,
EnableTrialEmailWhitelist: false,
}
tests := []struct {
name string
rc config.RegisterConfig
email string
want bool
reason string
}{
{
name: "gmail dot trick - blocked even if gmail.com in whitelist",
rc: rcWithGmail,
email: "s.m.s.n.fsmbt.d.ndny@gmail.com",
want: false,
reason: "gmail 泛域名(含点号)应拒绝",
},
{
name: "gmail plus alias - blocked",
rc: rcWithGmail,
email: "user+tag@gmail.com",
want: false,
reason: "gmail +别名应拒绝",
},
{
name: "clean gmail - allowed",
rc: rcWithGmail,
email: "normaluser@gmail.com",
want: true,
reason: "干净的 gmail 应放行",
},
{
name: "163 with dot - allowed (non-gmail dot is ok)",
rc: rcWithGmail,
email: "s.m.s.n@163.com",
want: true,
reason: "非 gmail 域点号不拦截",
},
{
name: "163 plus alias - blocked",
rc: rcWithGmail,
email: "user+spam@163.com",
want: false,
reason: "所有域名的 +别名都拦截",
},
{
name: "gmail typo squatting domain - blocked even if accidentally whitelisted",
rc: config.RegisterConfig{EnableTrial: true, EnableTrialEmailWhitelist: true, TrialEmailDomainWhitelist: "gmail.com,gmaial.com"},
email: "1.2.3.4xxx@gmaial.com",
want: false,
reason: "易混淆 Gmail 域名不应发放试用",
},
{
name: "invalid empty local - blocked",
rc: rcWithGmail,
email: "@gmail.com",
want: false,
reason: "邮箱 local 为空应拒绝",
},
{
name: "subdomain spoof - blocked",
rc: rcWithGmail,
email: "user@fake.gmail.com",
want: false,
reason: "白名单必须精确匹配域名,不匹配子域",
},
{
name: "whitelist disabled - gmail dot trick still blocked",
rc: rcNoWhitelist,
email: "s.m.s.n.fsmbt.d.ndny@gmail.com",
want: false,
reason: "白名单未启用,但泛域名仍应拒绝",
},
{
name: "trial disabled - always blocked",
rc: config.RegisterConfig{EnableTrial: false},
email: "user@163.com",
want: false,
reason: "试用未开启",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ShouldGrantTrialForEmail(tt.rc, tt.email)
if got != tt.want {
t.Errorf("ShouldGrantTrialForEmail(%q) = %v, want %v | reason: %s",
tt.email, got, tt.want, tt.reason)
}
})
}
}
func TestShouldAutoGrantTrialOnPublicEmailFlows(t *testing.T) {
assert.False(t, ShouldAutoGrantTrialOnPublicEmailFlows(config.RegisterConfig{}))
assert.False(t, ShouldAutoGrantTrialOnPublicEmailFlows(config.RegisterConfig{
EnableTrial: true,
EnableTrialEmailWhitelist: true,
TrialEmailDomainWhitelist: "gmail.com,example.com",
}))
}
func TestIsEmailDomainWhitelisted(t *testing.T) {
whitelist := "gmail.com,edu.cn,outlook.com"
tests := []struct {
email string
want bool
}{
{"user@gmail.com", true},
{"user@edu.cn", true},
{"User@Gmail.COM", true},
{"user@yahoo.com", false},
{"user@fake.gmail.com", false}, // subdomain not matched
{"user@", false},
{"notanemail", false},
{"@gmail.com", false},
}
for _, tt := range tests {
t.Run(tt.email, func(t *testing.T) {
got := IsEmailDomainWhitelisted(tt.email, whitelist)
if got != tt.want {
t.Errorf("IsEmailDomainWhitelisted(%q) = %v, want %v", tt.email, got, tt.want)
}
})
}
}

View File

@ -148,7 +148,9 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp *
// Activate trial subscription after transaction success (moved outside transaction to reduce lock time) // Activate trial subscription after transaction success (moved outside transaction to reduce lock time)
rc := l.svcCtx.Config.Register rc := l.svcCtx.Config.Register
if ShouldGrantTrialForEmail(rc, req.Email) { if ShouldAutoGrantTrialOnPublicEmailFlows(rc) &&
ShouldGrantTrialForEmail(rc, req.Email) &&
!NormalizedEmailHasTrial(l.ctx, l.svcCtx.DB, req.Email, rc.TrialSubscribe) {
trialSubscribe, err = l.activeTrial(userInfo.Id) trialSubscribe, err = l.activeTrial(userInfo.Id)
if err != nil { if err != nil {
l.Errorw("Failed to activate trial subscription", logger.Field("error", err.Error())) l.Errorw("Failed to activate trial subscription", logger.Field("error", err.Error()))

View File

@ -76,7 +76,10 @@ func CountScopedSubscribePurchaseOrders(
var count int64 var count int64
query := db.WithContext(ctx). query := db.WithContext(ctx).
Model(&modelOrder.Order{}). Model(&modelOrder.Order{}).
Where("user_id IN ? AND subscribe_id = ? AND type IN ? AND amount > 0", scopeUserIDs, subscribeID, []int64{1, 2}) Where("user_id IN ? AND type IN ?", scopeUserIDs, []int64{1, 2})
if subscribeID > 0 {
query = query.Where("subscribe_id = ?", subscribeID)
}
if len(statuses) > 0 { if len(statuses) > 0 {
query = query.Where("status IN ?", statuses) query = query.Where("status IN ?", statuses)
} }

View File

@ -52,12 +52,8 @@ func ResolvePurchaseRoute(
return decision, nil return decision, nil
} }
if requestedSubscribeID != anchorSub.SubscribeId {
return nil, ErrSingleModePlanMismatch
}
decision.Route = PurchaseRoutePurchaseToRenewal decision.Route = PurchaseRoutePurchaseToRenewal
decision.ResolvedSubscribeID = anchorSub.SubscribeId decision.ResolvedSubscribeID = requestedSubscribeID
decision.Anchor = anchorSub decision.Anchor = anchorSub
return decision, nil return decision, nil
} }

View File

@ -0,0 +1,36 @@
package common
import (
"context"
"testing"
"time"
"github.com/perfect-panel/server/internal/model/user"
"github.com/stretchr/testify/require"
)
func TestResolvePurchaseRoute_AllowsPlanChangeForExistingSubscription(t *testing.T) {
anchor := &user.Subscribe{
Id: 10,
UserId: 20,
OrderId: 30,
SubscribeId: 1,
Token: "existing-token",
ExpireTime: time.Now().Add(time.Hour),
}
decision, err := ResolvePurchaseRoute(
context.Background(),
true,
anchor.UserId,
2,
func(context.Context, int64) (*user.Subscribe, error) {
return anchor, nil
},
)
require.NoError(t, err)
require.Equal(t, PurchaseRoutePurchaseToRenewal, decision.Route)
require.Equal(t, int64(2), decision.ResolvedSubscribeID)
require.Equal(t, anchor, decision.Anchor)
}

View File

@ -0,0 +1,108 @@
package common
import (
"strings"
ordermodel "github.com/perfect-panel/server/internal/model/order"
usermodel "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/pkg/logger"
)
const (
SubscriptionTraceType = "subscription_flow"
SubscriptionTraceFlowOrder = "order_subscription"
SubscriptionTraceFlowEmailBind = "email_bind_subscription"
)
func SubscriptionTraceFields(flow string, stage string, fields ...logger.LogField) []logger.LogField {
base := []logger.LogField{
logger.Field("trace_type", SubscriptionTraceType),
logger.Field("flow", flow),
logger.Field("stage", stage),
}
return append(base, fields...)
}
func SubscriptionTraceInfo(log logger.Logger, flow string, stage string, msg string, fields ...logger.LogField) {
log.Infow(msg, SubscriptionTraceFields(flow, stage, fields...)...)
}
func SubscriptionTraceError(log logger.Logger, flow string, stage string, msg string, fields ...logger.LogField) {
log.Errorw(msg, SubscriptionTraceFields(flow, stage, fields...)...)
}
func OrderTraceFields(orderInfo *ordermodel.Order) []logger.LogField {
if orderInfo == nil {
return nil
}
effectiveUserID := orderInfo.UserId
if orderInfo.SubscriptionUserId > 0 {
effectiveUserID = orderInfo.SubscriptionUserId
}
fields := []logger.LogField{
logger.Field("order_id", orderInfo.Id),
logger.Field("order_no", orderInfo.OrderNo),
logger.Field("order_type", orderInfo.Type),
logger.Field("order_status", orderInfo.Status),
logger.Field("user_id", orderInfo.UserId),
logger.Field("subscription_user_id", orderInfo.SubscriptionUserId),
logger.Field("effective_user_id", effectiveUserID),
logger.Field("order_subscribe_id", orderInfo.SubscribeId),
logger.Field("payment_id", orderInfo.PaymentId),
logger.Field("payment_method", orderInfo.Method),
logger.Field("parent_order_id", orderInfo.ParentId),
logger.Field("quantity", orderInfo.Quantity),
logger.Field("is_new_order", orderInfo.IsNew),
}
if tail := SensitiveTail(orderInfo.SubscribeToken); tail != "" {
fields = append(fields, logger.Field("subscribe_token_tail", tail))
}
if tail := SensitiveTail(orderInfo.TradeNo); tail != "" {
fields = append(fields, logger.Field("trade_no_tail", tail))
}
if tail := SensitiveTail(orderInfo.AppAccountToken); tail != "" {
fields = append(fields, logger.Field("app_account_token_tail", tail))
}
return fields
}
func UserSubscribeTraceFields(userSub *usermodel.Subscribe) []logger.LogField {
if userSub == nil {
return nil
}
fields := []logger.LogField{
logger.Field("user_subscribe_id", userSub.Id),
logger.Field("subscribe_owner_user_id", userSub.UserId),
logger.Field("user_subscribe_plan_id", userSub.SubscribeId),
logger.Field("subscribe_order_id", userSub.OrderId),
logger.Field("subscribe_status", userSub.Status),
logger.Field("expire_time", userSub.ExpireTime),
}
if tail := SensitiveTail(userSub.Token); tail != "" {
fields = append(fields, logger.Field("subscribe_token_tail", tail))
}
if tail := SensitiveTail(userSub.UUID); tail != "" {
fields = append(fields, logger.Field("subscribe_uuid_tail", tail))
}
return fields
}
func SensitiveTail(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if len(value) <= 8 {
return value
}
return value[len(value)-8:]
}

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/xerr" "github.com/perfect-panel/server/pkg/xerr"
@ -56,6 +57,12 @@ func (l *AlipayNotifyLogic) AlipayNotify(r *http.Request) error {
l.Logger.Error("[AlipayNotify] Decode notification failed", logger.Field("error", err.Error())) l.Logger.Error("[AlipayNotify] Decode notification failed", logger.Field("error", err.Error()))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_notify_received",
"[SubscriptionFlow] alipay notify received",
logger.Field("order_no", notify.OrderNo),
logger.Field("payment_platform", data.Platform),
logger.Field("notify_status", string(notify.Status)),
)
if notify.Status == alipay.Success { if notify.Status == alipay.Success {
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, notify.OrderNo) orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, notify.OrderNo)
if err != nil { if err != nil {
@ -73,6 +80,12 @@ func (l *AlipayNotifyLogic) AlipayNotify(r *http.Request) error {
l.Logger.Error("[AlipayNotify] Update order status failed", logger.Field("error", err.Error()), logger.Field("orderNo", notify.OrderNo)) l.Logger.Error("[AlipayNotify] Update order status failed", logger.Field("error", err.Error()), logger.Field("orderNo", notify.OrderNo))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] alipay notify marked order as paid",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", data.Platform),
)...,
)
l.Logger.Info("[AlipayNotify] Notify status success", logger.Field("orderNo", notify.OrderNo)) l.Logger.Info("[AlipayNotify] Notify status success", logger.Field("orderNo", notify.OrderNo))
payload := types.ForthwithActivateOrderPayload{ payload := types.ForthwithActivateOrderPayload{
OrderNo: notify.OrderNo, OrderNo: notify.OrderNo,
@ -88,6 +101,13 @@ func (l *AlipayNotifyLogic) AlipayNotify(r *http.Request) error {
l.Logger.Error("[AlipayNotify] Enqueue task failed", logger.Field("error", err.Error())) l.Logger.Error("[AlipayNotify] Enqueue task failed", logger.Field("error", err.Error()))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "activation_task_enqueued",
"[SubscriptionFlow] activation task enqueued from alipay notify",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", data.Platform),
logger.Field("queue_task_id", taskInfo.ID),
)...,
)
l.Logger.Info("[AlipayNotify] Enqueue task success", logger.Field("taskInfo", taskInfo)) l.Logger.Info("[AlipayNotify] Enqueue task success", logger.Field("taskInfo", taskInfo))
} else { } else {
l.Logger.Error("[AlipayNotify] Notify status failed", logger.Field("status", string(notify.Status))) l.Logger.Error("[AlipayNotify] Notify status failed", logger.Field("status", string(notify.Status)))

View File

@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"strings" "strings"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" iapmodel "github.com/perfect-panel/server/internal/model/iap/apple"
"github.com/perfect-panel/server/internal/model/subscribe" "github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/model/user"
@ -57,6 +58,13 @@ func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error {
} }
// 验签通过,记录通知类型与关键交易标识 // 验签通过,记录通知类型与关键交易标识
l.Infow("iap notify verified", logger.Field("type", ntype), logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId)) l.Infow("iap notify verified", logger.Field("type", ntype), logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_notify_received",
"[SubscriptionFlow] apple iap server notification received",
logger.Field("notify_type", ntype),
logger.Field("product_id", txPayload.ProductId),
logger.Field("original_transaction_tail", commonLogic.SensitiveTail(txPayload.OriginalTransactionId)),
logger.Field("transaction_id_tail", commonLogic.SensitiveTail(txPayload.TransactionId)),
)
return l.svcCtx.DB.Transaction(func(db *gorm.DB) error { return l.svcCtx.DB.Transaction(func(db *gorm.DB) error {
var existing *iapmodel.Transaction var existing *iapmodel.Transaction
existing, _ = iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByOriginalId(l.ctx, txPayload.OriginalTransactionId) existing, _ = iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByOriginalId(l.ctx, txPayload.OriginalTransactionId)
@ -201,6 +209,13 @@ func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error {
return err return err
} }
l.Infow("iap notify fallback updated subscribe", logger.Field("userSubscribeId", candidate.Id), logger.Field("status", candidate.Status)) l.Infow("iap notify fallback updated subscribe", logger.Field("userSubscribeId", candidate.Id), logger.Field("status", candidate.Status))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "subscription_updated_from_notify",
"[SubscriptionFlow] apple iap notify updated fallback subscription candidate",
append(commonLogic.UserSubscribeTraceFields(candidate),
logger.Field("notify_type", ntype),
logger.Field("product_id", txPayload.ProductId),
)...,
)
break break
} }
} }
@ -226,6 +241,13 @@ func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error {
} }
// 更新成功,输出订阅状态 // 更新成功,输出订阅状态
l.Infow("iap notify updated subscribe", logger.Field("userSubscribeId", sub.Id), logger.Field("status", sub.Status)) l.Infow("iap notify updated subscribe", logger.Field("userSubscribeId", sub.Id), logger.Field("status", sub.Status))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "subscription_updated_from_notify",
"[SubscriptionFlow] apple iap notify updated subscription",
append(commonLogic.UserSubscribeTraceFields(sub),
logger.Field("notify_type", ntype),
logger.Field("product_id", txPayload.ProductId),
)...,
)
} }
return nil return nil
}) })

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/url" "net/url"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/xerr" "github.com/perfect-panel/server/pkg/xerr"
@ -44,12 +45,18 @@ func (l *EPayNotifyLogic) EPayNotify(req *types.EPayNotifyRequest) error {
l.Logger.Error("[EPayNotify] Payment not found in context") l.Logger.Error("[EPayNotify] Payment not found in context")
return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment config not found") return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment config not found")
} }
l.Infof("[EPayNotify] Payment config: %+v", data)
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OutTradeNo) orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OutTradeNo)
if err != nil { if err != nil {
l.Logger.Error("[EPayNotify] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OutTradeNo)) l.Logger.Error("[EPayNotify] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OutTradeNo))
return errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OutTradeNo) return errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OutTradeNo)
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_notify_received",
"[SubscriptionFlow] epay notify received",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", data.Platform),
logger.Field("trade_status", req.TradeStatus),
)...,
)
var config payment.EPayConfig var config payment.EPayConfig
if err := json.Unmarshal([]byte(data.Config), &config); err != nil { if err := json.Unmarshal([]byte(data.Config), &config); err != nil {
@ -75,6 +82,12 @@ func (l *EPayNotifyLogic) EPayNotify(req *types.EPayNotifyRequest) error {
l.Logger.Error("[EPayNotify] Update order status failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OutTradeNo)) l.Logger.Error("[EPayNotify] Update order status failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OutTradeNo))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] epay notify marked order as paid",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", data.Platform),
)...,
)
// Create activate order task // Create activate order task
payload := queueType.ForthwithActivateOrderPayload{ payload := queueType.ForthwithActivateOrderPayload{
OrderNo: req.OutTradeNo, OrderNo: req.OutTradeNo,
@ -90,6 +103,13 @@ func (l *EPayNotifyLogic) EPayNotify(req *types.EPayNotifyRequest) error {
l.Logger.Error("[EPayNotify] Enqueue task failed", logger.Field("error", err.Error())) l.Logger.Error("[EPayNotify] Enqueue task failed", logger.Field("error", err.Error()))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "activation_task_enqueued",
"[SubscriptionFlow] activation task enqueued from epay notify",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", data.Platform),
logger.Field("queue_task_id", taskInfo.ID),
)...,
)
l.Logger.Info("[EPayNotify] Enqueue task success", logger.Field("taskInfo", taskInfo)) l.Logger.Info("[EPayNotify] Enqueue task success", logger.Field("taskInfo", taskInfo))
return nil return nil
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/xerr" "github.com/perfect-panel/server/pkg/xerr"
@ -67,6 +68,13 @@ func (l *StripeNotifyLogic) StripeNotify(r *http.Request, w http.ResponseWriter)
l.Logger.Error("[StripeNotify] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", notify.OrderNo)) l.Logger.Error("[StripeNotify] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", notify.OrderNo))
return errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", notify.OrderNo) return errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", notify.OrderNo)
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_notify_received",
"[SubscriptionFlow] stripe notify received",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", stripeConfig.Platform),
logger.Field("stripe_event_type", notify.EventType),
)...,
)
if notify.EventType == "payment_intent.succeeded" { if notify.EventType == "payment_intent.succeeded" {
if orderInfo.Status == 5 { if orderInfo.Status == 5 {
return nil return nil
@ -76,6 +84,13 @@ func (l *StripeNotifyLogic) StripeNotify(r *http.Request, w http.ResponseWriter)
if err != nil { if err != nil {
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] stripe notify marked order as paid",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", stripeConfig.Platform),
logger.Field("stripe_event_type", notify.EventType),
)...,
)
// create ActivateOrder task // create ActivateOrder task
payload := types.ForthwithActivateOrderPayload{ payload := types.ForthwithActivateOrderPayload{
OrderNo: notify.OrderNo, OrderNo: notify.OrderNo,
@ -86,11 +101,19 @@ func (l *StripeNotifyLogic) StripeNotify(r *http.Request, w http.ResponseWriter)
return err return err
} }
task := asynq.NewTask(types.ForthwithActivateOrder, bytes, asynq.MaxRetry(5)) task := asynq.NewTask(types.ForthwithActivateOrder, bytes, asynq.MaxRetry(5))
_, err = l.svcCtx.Queue.Enqueue(task) taskInfo, err := l.svcCtx.Queue.Enqueue(task)
if err != nil { if err != nil {
l.Errorw("[StripeNotify] Enqueue error", logger.Field("errors", err.Error())) l.Errorw("[StripeNotify] Enqueue error", logger.Field("errors", err.Error()))
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "activation_task_enqueued",
"[SubscriptionFlow] activation task enqueued from stripe notify",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", stripeConfig.Platform),
logger.Field("stripe_event_type", notify.EventType),
logger.Field("queue_task_id", taskInfo.ID),
)...,
)
l.Infow("[StripeNotify] success", logger.Field("orderNo", notify.OrderNo)) l.Infow("[StripeNotify] success", logger.Field("orderNo", notify.OrderNo))
} }
return nil return nil

View File

@ -82,6 +82,13 @@ func (l *AttachTransactionLogic) Attach(req *types.AttachAppleTransactionRequest
l.Errorw("订单与当前用户不匹配", logger.Field("orderNo", req.OrderNo), logger.Field("orderUserId", orderInfo.UserId), logger.Field("userId", u.Id)) l.Errorw("订单与当前用户不匹配", logger.Field("orderNo", req.OrderNo), logger.Field("orderUserId", orderInfo.UserId), logger.Field("userId", u.Id))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "order owner mismatch") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "order owner mismatch")
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "iap_attach_start",
"[SubscriptionFlow] apple iap attach flow started",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("request_user_id", u.Id),
logger.Field("effective_user_id", entitlement.EffectiveUserID),
)...,
)
isNewPurchaseOrder := orderInfo.Type == orderTypeSubscribe isNewPurchaseOrder := orderInfo.Type == orderTypeSubscribe
if isNewPurchaseOrder { if isNewPurchaseOrder {
l.Infow("首购订单将只由订单激活流程创建订阅", logger.Field("orderNo", req.OrderNo), logger.Field("orderType", orderInfo.Type)) l.Infow("首购订单将只由订单激活流程创建订阅", logger.Field("orderNo", req.OrderNo), logger.Field("orderType", orderInfo.Type))
@ -93,6 +100,14 @@ func (l *AttachTransactionLogic) Attach(req *types.AttachAppleTransactionRequest
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "invalid jws") return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "invalid jws")
} }
l.Infow("JWS 验签成功", logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId), logger.Field("purchaseAt", txPayload.PurchaseDate)) l.Infow("JWS 验签成功", logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId), logger.Field("purchaseAt", txPayload.PurchaseDate))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "iap_attach_verified",
"[SubscriptionFlow] apple iap transaction verified",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("product_id", txPayload.ProductId),
logger.Field("original_transaction_tail", commonLogic.SensitiveTail(txPayload.OriginalTransactionId)),
logger.Field("transaction_id_tail", commonLogic.SensitiveTail(txPayload.TransactionId)),
)...,
)
tradeNoCandidates := l.getAppleTradeNoCandidates(txPayload) tradeNoCandidates := l.getAppleTradeNoCandidates(txPayload)
existingOrderNo, validateErr := l.validateOrderTradeNoBinding(orderInfo, tradeNoCandidates) existingOrderNo, validateErr := l.validateOrderTradeNoBinding(orderInfo, tradeNoCandidates)
if validateErr != nil { if validateErr != nil {
@ -390,6 +405,12 @@ func (l *AttachTransactionLogic) Attach(req *types.AttachAppleTransactionRequest
return e return e
} }
l.Infow("写入用户订阅成功", logger.Field("userId", u.Id), logger.Field("subscribeId", subscribeId), logger.Field("expireUnix", exp.Unix())) l.Infow("写入用户订阅成功", logger.Field("userId", u.Id), logger.Field("subscribeId", subscribeId), logger.Field("expireUnix", exp.Unix()))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "subscription_created",
"[SubscriptionFlow] apple iap attach created a subscription placeholder before queue activation",
append(commonLogic.OrderTraceFields(orderInfo),
commonLogic.UserSubscribeTraceFields(&userSub)...,
)...,
)
} }
} else { } else {
l.Infow("首购订单跳过 attach 阶段订阅写入", logger.Field("orderNo", orderInfo.OrderNo), logger.Field("orderType", orderInfo.Type)) l.Infow("首购订单跳过 attach 阶段订阅写入", logger.Field("orderNo", orderInfo.OrderNo), logger.Field("orderType", orderInfo.Type))
@ -453,6 +474,12 @@ func (l *AttachTransactionLogic) syncOrderStatusAndEnqueue(orderInfo *ordermodel
} }
orderInfo.Status = orderStatusPaid orderInfo.Status = orderStatusPaid
l.Infow("更新订单状态成功", logger.Field("orderNo", orderInfo.OrderNo), logger.Field("status", orderStatusPaid)) l.Infow("更新订单状态成功", logger.Field("orderNo", orderInfo.OrderNo), logger.Field("status", orderStatusPaid))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] apple iap attach marked order as paid",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("iap_expire_at", iapExpireAt),
)...,
)
} }
// enqueue activation regardless (idempotent handler downstream) // enqueue activation regardless (idempotent handler downstream)
payload := queueType.ForthwithActivateOrderPayload{OrderNo: orderInfo.OrderNo, IAPExpireAt: iapExpireAt} payload := queueType.ForthwithActivateOrderPayload{OrderNo: orderInfo.OrderNo, IAPExpireAt: iapExpireAt}
@ -463,6 +490,12 @@ func (l *AttachTransactionLogic) syncOrderStatusAndEnqueue(orderInfo *ordermodel
l.Errorw("enqueue activate task error", logger.Field("error", err.Error())) l.Errorw("enqueue activate task error", logger.Field("error", err.Error()))
} else { } else {
l.Infow("已加入订单激活队列", logger.Field("orderNo", orderInfo.OrderNo)) l.Infow("已加入订单激活队列", logger.Field("orderNo", orderInfo.OrderNo))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "activation_task_enqueued",
"[SubscriptionFlow] apple iap attach enqueued activation task",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("iap_expire_at", iapExpireAt),
)...,
)
} }
return nil return nil
} }

View File

@ -12,6 +12,9 @@ func getDiscount(discounts []types.SubscribeDiscount, inputMonths int64, isNewUs
if d.Quantity != inputMonths || d.Discount <= 0 || d.Discount >= 100 { if d.Quantity != inputMonths || d.Discount <= 0 || d.Discount >= 100 {
continue continue
} }
if d.NewUserOnly && !isNewUser {
continue
}
if isNewUser { if isNewUser {
// lowest discount value = biggest saving // lowest discount value = biggest saving
if best < 0 || d.Discount < best { if best < 0 || d.Discount < best {
@ -50,4 +53,3 @@ func isNewUserOnlyForQuantity(discounts []types.SubscribeDiscount, inputQuantity
} }
return hasNewUserOnly && !hasFallback return hasNewUserOnly && !hasFallback
} }

View File

@ -0,0 +1,16 @@
package order
import (
"testing"
"github.com/perfect-panel/server/internal/types"
"github.com/stretchr/testify/require"
)
func TestGetDiscount_SkipsNewUserOnlyTierForExistingUser(t *testing.T) {
discount := getDiscount([]types.SubscribeDiscount{
{Quantity: 1, Discount: 90, NewUserOnly: true},
}, 1, false)
require.Equal(t, float64(1), discount)
}

View File

@ -47,7 +47,7 @@ func resolveNewUserDiscountEligibility(
ctx, ctx,
db, db,
eligibility.ScopeUserIDs, eligibility.ScopeUserIDs,
subscribeID, 0,
[]int64{2, 5}, []int64{2, 5},
"", "",
) )

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"math" "math"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -62,6 +63,17 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
return nil, entErr return nil, entErr
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_create_start",
"[SubscriptionFlow] purchase order creation started",
logger.Field("order_kind", "purchase"),
logger.Field("user_id", u.Id),
logger.Field("effective_user_id", entitlement.EffectiveUserID),
logger.Field("requested_subscribe_id", req.SubscribeId),
logger.Field("quantity", req.Quantity),
logger.Field("payment_id", req.Payment),
logger.Field("coupon", req.Coupon),
)
if req.Quantity <= 0 { if req.Quantity <= 0 {
l.Debugf("[Purchase] Quantity is less than or equal to 0, setting to 1") l.Debugf("[Purchase] Quantity is less than or equal to 0, setting to 1")
req.Quantity = 1 req.Quantity = 1
@ -101,12 +113,42 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
parentOrderID = decision.Anchor.OrderId parentOrderID = decision.Anchor.OrderId
subscribeToken = decision.Anchor.Token subscribeToken = decision.Anchor.Token
anchorUserSubscribeID = decision.Anchor.Id anchorUserSubscribeID = decision.Anchor.Id
l.Infow("[Purchase] single mode purchase routed to renewal", commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_route_selected",
logger.Field("mode", "single"), "[SubscriptionFlow] purchase routed to renewal before order creation",
logger.Field("route_mode", "single"),
logger.Field("route", "purchase_to_renewal"), logger.Field("route", "purchase_to_renewal"),
logger.Field("anchor_user_subscribe_id", decision.Anchor.Id), logger.Field("anchor_user_subscribe_id", decision.Anchor.Id),
logger.Field("order_no", "pending"),
logger.Field("user_id", u.Id), logger.Field("user_id", u.Id),
logger.Field("effective_user_id", entitlement.EffectiveUserID),
logger.Field("requested_subscribe_id", req.SubscribeId),
logger.Field("resolved_subscribe_id", targetSubscribeID),
)
}
}
// 全局单订阅口径:若用户已有任意付费订阅(含过期),提前路由为续费/换套餐,
// 防止不同套餐购买创建第二条订阅。
if !l.svcCtx.Config.Subscribe.SingleModel && orderType == 1 {
var existSub user.Subscribe
if e := l.svcCtx.DB.WithContext(l.ctx).
Model(&user.Subscribe{}).
Where("user_id = ? AND token != '' AND (order_id > 0 OR token LIKE 'iap:%')", entitlement.EffectiveUserID).
Order("expire_time DESC").
Order("updated_at DESC").
Order("id DESC").
First(&existSub).Error; e == nil && existSub.Id > 0 && existSub.Token != "" {
orderType = 2
parentOrderID = existSub.OrderId
subscribeToken = existSub.Token
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_route_selected",
"[SubscriptionFlow] purchase routed to renewal because an existing subscription was found",
logger.Field("route_mode", "global_single_subscription"),
logger.Field("route", "purchase_to_existing_subscription"),
logger.Field("existing_subscribe_id", existSub.Id),
logger.Field("existing_status", existSub.Status),
logger.Field("user_id", u.Id),
logger.Field("effective_user_id", entitlement.EffectiveUserID),
logger.Field("resolved_subscribe_id", targetSubscribeID),
) )
} }
} }
@ -277,13 +319,13 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
AppAccountToken: uuid.New().String(), AppAccountToken: uuid.New().String(),
} }
if isSingleModeRenewal { if isSingleModeRenewal {
l.Infow("[Purchase] single mode purchase order created as renewal", commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_created",
logger.Field("mode", "single"), "[SubscriptionFlow] purchase order persisted as renewal",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("route_mode", "single"),
logger.Field("route", "purchase_to_renewal"), logger.Field("route", "purchase_to_renewal"),
logger.Field("anchor_user_subscribe_id", anchorUserSubscribeID), logger.Field("anchor_user_subscribe_id", anchorUserSubscribeID),
logger.Field("order_no", orderInfo.OrderNo), )...,
logger.Field("parent_id", orderInfo.ParentId),
logger.Field("user_id", u.Id),
) )
} }
// Database transaction // Database transaction
@ -291,13 +333,13 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
// check subscribe plan quota limit inside transaction to prevent race condition // check subscribe plan quota limit inside transaction to prevent race condition
if orderInfo.Type == 1 && sub.Quota > 0 { if orderInfo.Type == 1 && sub.Quota > 0 {
var currentUserSub []user.Subscribe var currentUserSub []user.Subscribe
if e := db.Model(&user.Subscribe{}).Where("user_id = ?", u.Id).Find(&currentUserSub).Error; e != nil { if e := db.Model(&user.Subscribe{}).Where("user_id = ?", entitlement.EffectiveUserID).Find(&currentUserSub).Error; e != nil {
l.Errorw("[Purchase] Database query error", logger.Field("error", e.Error()), logger.Field("user_id", u.Id)) l.Errorw("[Purchase] Database query error", logger.Field("error", e.Error()), logger.Field("user_id", u.Id))
return e return e
} }
var count int64 var count int64
for _, v := range currentUserSub { for _, v := range currentUserSub {
if v.SubscribeId == targetSubscribeID { if v.OrderId > 0 || strings.HasPrefix(v.Token, "iap:") {
count++ count++
} }
} }
@ -380,6 +422,16 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert order error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert order error: %v", err.Error())
} }
if !isSingleModeRenewal {
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_created",
"[SubscriptionFlow] purchase order persisted",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("route_mode", "standard"),
logger.Field("resolved_subscribe_id", targetSubscribeID),
)...,
)
}
// Deferred task // Deferred task
payload := queue.DeferCloseOrderPayload{ payload := queue.DeferCloseOrderPayload{
OrderNo: orderInfo.OrderNo, OrderNo: orderInfo.OrderNo,

View File

@ -0,0 +1,766 @@
package order
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/hibiken/asynq"
modelOrder "github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/payment"
subModel "github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// setupNewUserOnlyDB 创建带必要表的 SQLite 内存数据库
func setupNewUserOnlyDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err, "failed to open in-memory SQLite")
db.Exec("PRAGMA foreign_keys = OFF")
sqls := []string{
`CREATE TABLE IF NOT EXISTS "subscribe" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL DEFAULT '',
language VARCHAR(255) NOT NULL DEFAULT '',
description TEXT,
unit_price INTEGER NOT NULL DEFAULT 0,
unit_time VARCHAR(255) NOT NULL DEFAULT '',
discount TEXT,
replacement INTEGER NOT NULL DEFAULT 0,
inventory INTEGER NOT NULL DEFAULT -1,
traffic INTEGER NOT NULL DEFAULT 0,
speed_limit INTEGER NOT NULL DEFAULT 0,
device_limit INTEGER NOT NULL DEFAULT 0,
quota INTEGER NOT NULL DEFAULT 0,
new_user_only TINYINT DEFAULT 0,
nodes VARCHAR(255),
node_tags VARCHAR(255),
show TINYINT NOT NULL DEFAULT 0,
sell TINYINT NOT NULL DEFAULT 1,
sort INTEGER NOT NULL DEFAULT 0,
deduction_ratio INTEGER DEFAULT 0,
allow_deduction TINYINT DEFAULT 1,
reset_cycle INTEGER DEFAULT 0,
renewal_reset TINYINT DEFAULT 0,
show_original_price TINYINT NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "order" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
parent_id INTEGER DEFAULT NULL,
user_id INTEGER NOT NULL DEFAULT 0,
subscription_user_id INTEGER NOT NULL DEFAULT 0,
order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
type TINYINT NOT NULL DEFAULT 1,
quantity INTEGER NOT NULL DEFAULT 1,
price INTEGER NOT NULL DEFAULT 0,
amount INTEGER NOT NULL DEFAULT 0,
gift_amount INTEGER NOT NULL DEFAULT 0,
discount INTEGER NOT NULL DEFAULT 0,
coupon VARCHAR(255) DEFAULT NULL,
coupon_discount INTEGER NOT NULL DEFAULT 0,
commission INTEGER NOT NULL DEFAULT 0,
payment_id INTEGER NOT NULL DEFAULT 0,
method VARCHAR(255) NOT NULL DEFAULT '',
fee_amount INTEGER NOT NULL DEFAULT 0,
trade_no VARCHAR(255) DEFAULT NULL,
app_account_token VARCHAR(255) DEFAULT NULL,
status TINYINT NOT NULL DEFAULT 1,
subscribe_id INTEGER NOT NULL DEFAULT 0,
subscribe_token VARCHAR(255) DEFAULT NULL,
is_new TINYINT NOT NULL DEFAULT 0,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
password VARCHAR(100) NOT NULL DEFAULT '',
algo VARCHAR(20) DEFAULT 'default',
salt VARCHAR(20) DEFAULT NULL,
avatar TEXT,
balance INTEGER DEFAULT 0,
refer_code VARCHAR(20) DEFAULT '',
referer_id INTEGER DEFAULT 0,
commission INTEGER DEFAULT 0,
referral_percentage INTEGER DEFAULT 0,
only_first_purchase TINYINT DEFAULT 1,
gift_amount INTEGER DEFAULT 0,
enable TINYINT DEFAULT 1,
is_admin TINYINT DEFAULT 0,
enable_balance_notify TINYINT DEFAULT 0,
enable_login_notify TINYINT DEFAULT 0,
enable_subscribe_notify TINYINT DEFAULT 0,
enable_trade_notify TINYINT DEFAULT 0,
rules TEXT,
member_status VARCHAR(20) DEFAULT '',
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "payment" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(100) NOT NULL DEFAULT '',
platform VARCHAR(100) NOT NULL DEFAULT '',
icon VARCHAR(255) DEFAULT '',
domain VARCHAR(255) DEFAULT '',
config TEXT NOT NULL DEFAULT '{}',
description TEXT,
fee_mode TINYINT NOT NULL DEFAULT 0,
fee_percent INTEGER DEFAULT 0,
fee_amount INTEGER DEFAULT 0,
enable TINYINT NOT NULL DEFAULT 1,
token VARCHAR(255) NOT NULL DEFAULT '' UNIQUE
)`,
`CREATE TABLE IF NOT EXISTS "user_subscribe" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
order_id INTEGER NOT NULL DEFAULT 0,
subscribe_id INTEGER NOT NULL DEFAULT 0,
start_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
expire_time DATETIME DEFAULT NULL,
finished_at DATETIME DEFAULT NULL,
traffic INTEGER DEFAULT 0,
download INTEGER DEFAULT 0,
upload INTEGER DEFAULT 0,
token VARCHAR(255) DEFAULT '' UNIQUE,
uuid VARCHAR(255) DEFAULT '' UNIQUE,
status TINYINT DEFAULT 0,
note VARCHAR(500) DEFAULT '',
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_device" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip VARCHAR(255) NOT NULL DEFAULT '',
user_id INTEGER NOT NULL DEFAULT 0,
user_agent TEXT,
identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
short_code VARCHAR(255) NOT NULL DEFAULT '',
online TINYINT NOT NULL DEFAULT 0,
enabled TINYINT NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_family" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner_user_id INTEGER NOT NULL DEFAULT 0,
max_members INTEGER NOT NULL DEFAULT 2,
status TINYINT DEFAULT 0,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_family_member" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
family_id INTEGER NOT NULL DEFAULT 0,
user_id INTEGER NOT NULL DEFAULT 0,
role TINYINT DEFAULT 0,
status TINYINT DEFAULT 0,
join_source VARCHAR(32) NOT NULL DEFAULT '',
joined_at DATETIME,
left_at DATETIME DEFAULT NULL,
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
}
for _, sql := range sqls {
require.NoError(t, db.Exec(sql).Error)
}
return db
}
// setupNewUserOnlyRedis 启动 miniredis返回 redis.Client 和 miniredis 句柄
func setupNewUserOnlyRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err)
t.Cleanup(mr.Close)
rds := redis.NewClient(&redis.Options{Addr: mr.Addr()})
return rds, mr
}
// buildNewUserOnlySvcCtx 组装最小 ServiceContext含 asynq Queue 使用 miniredis
func buildNewUserOnlySvcCtx(db *gorm.DB, rds *redis.Client, mr *miniredis.Miniredis) *svc.ServiceContext {
queue := asynq.NewClient(asynq.RedisClientOpt{Addr: mr.Addr()})
return &svc.ServiceContext{
DB: db,
Redis: rds,
UserModel: user.NewModel(db, rds),
OrderModel: modelOrder.NewModel(db, rds),
SubscribeModel: subModel.NewModel(db, rds),
PaymentModel: payment.NewModel(db, rds),
Queue: queue,
}
}
// insertTestSubscribe 直接用 SQL 插入 subscribe 行(绕过 GORM hook 的 MySQL 方言)
// new_user_only=true 时同时写入 discount JSON使代码里的 discount 检查生效
func insertTestSubscribe(t *testing.T, db *gorm.DB, id int64, newUserOnly bool) {
t.Helper()
nuOnly := 0
discount := ""
if newUserOnly {
nuOnly = 1
// discount JSON 包含一个 new_user_only=true 的 tier匹配 quantity=1
discount = `[{"quantity":1,"discount":90,"new_user_only":true}]`
}
err := db.Exec(`INSERT INTO "subscribe"
(id, name, unit_price, inventory, sell, sort, new_user_only, discount, created_at, updated_at)
VALUES (?, 'Test Plan', 1000, -1, 1, ?, ?, ?, datetime('now'), datetime('now'))`,
id, id, nuOnly, discount).Error
require.NoError(t, err)
}
// insertTestPayment 插入支付方式行
func insertTestPayment(t *testing.T, db *gorm.DB, id int64) {
t.Helper()
err := db.Exec(`INSERT INTO "payment"
(id, name, platform, config, enable, fee_mode, token)
VALUES (?, 'Balance', 'balance', '{}', 1, 0, ?)`,
id, "test-token").Error
require.NoError(t, err)
}
// insertTestUser 插入用户行createdAt 可控
func insertTestUser(t *testing.T, db *gorm.DB, id int64, createdAt time.Time) *user.User {
t.Helper()
err := db.Exec(`INSERT INTO "user"
(id, password, balance, gift_amount, enable, created_at, updated_at)
VALUES (?, '', 0, 0, 1, ?, datetime('now'))`,
id, createdAt.UTC().Format("2006-01-02 15:04:05")).Error
require.NoError(t, err)
return &user.User{
Id: id,
GiftAmount: 0,
CreatedAt: createdAt,
}
}
func insertTestDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) {
t.Helper()
err := db.Exec(`INSERT INTO "user_device"
(user_id, ip, user_agent, identifier, short_code, online, enabled, created_at, updated_at)
VALUES (?, '127.0.0.1', 'test-agent', ?, '', 0, 1, ?, datetime('now'))`,
userID,
identifier,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error
require.NoError(t, err)
}
func insertTestFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) {
t.Helper()
err := db.Exec(`INSERT INTO "user_family"
(id, owner_user_id, max_members, status, created_at, updated_at)
VALUES (?, ?, 3, 1, datetime('now'), datetime('now'))`,
familyID,
ownerUserID,
).Error
require.NoError(t, err)
}
func insertTestFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) {
t.Helper()
err := db.Exec(`INSERT INTO "user_family_member"
(family_id, user_id, role, status, join_source, joined_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'), datetime('now'))`,
familyID,
userID,
role,
status,
joinSource,
).Error
require.NoError(t, err)
}
// insertTestOrder 插入一条历史订单status=2 表示已支付)
func insertTestOrder(t *testing.T, db *gorm.DB, userID, subscribeID int64, status uint8) {
t.Helper()
err := db.Exec(`INSERT INTO "order"
(user_id, order_no, type, status, subscribe_id, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`,
userID, "existing-order-no", status, subscribeID).Error
require.NoError(t, err)
}
func insertScopedTestOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) {
t.Helper()
err := db.Exec(`INSERT INTO "order"
(user_id, order_no, type, status, subscribe_id, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, datetime('now'), datetime('now'))`,
userID, orderNo, status, subscribeID).Error
require.NoError(t, err)
}
// buildPurchaseCtx 把 user 放入 context模拟中间件行为
func buildPurchaseCtx(u *user.User) context.Context {
return context.WithValue(context.Background(), constant.CtxKeyUser, u)
}
// TestPurchase_NewUserOnly_UserTooOld 验证new_user_only=true用户注册超过 24h → 返回 SubscribeNewUserOnly
func TestPurchase_NewUserOnly_UserTooOld(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(1)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true) // new_user_only = true
insertTestPayment(t, db, payID)
// 用户注册 48 小时前 → 超出 24h 限制
u := insertTestUser(t, db, 100, time.Now().Add(-48*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
_, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode(),
"注册超过24h应返回 SubscribeNewUserOnly 错误码")
// 验证订单未被创建
var count int64
db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count)
assert.Equal(t, int64(0), count, "用户注册超时,订单不应被创建")
}
// TestPurchase_NewUserOnly_AlreadyPurchased 验证new_user_only=true用户是新用户但已购买过
// → 允许下单(不拦截),但不享受新人折扣
func TestPurchase_NewUserOnly_AlreadyPurchased(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(2)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户刚注册2h前→ 满足时间条件
u := insertTestUser(t, db, 200, time.Now().Add(-2*time.Hour))
// 但已有一条 status=2 的历史订单(已支付)
insertTestOrder(t, db, u.Id, subID, 2)
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
// 不应被拦截,允许下单
require.NoError(t, err, "24h内已购用户应允许继续下单不应返回错误")
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo)
// 历史订单 +1新增了一条
var count int64
db.Model(&modelOrder.Order{}).Where("user_id = ?", u.Id).Count(&count)
assert.Equal(t, int64(2), count, "应新增一条订单")
// 新订单无折扣Amount=Price=1000
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣Amount 应等于 Price")
assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0")
}
// TestPurchase_NewUserOnly_Success 验证new_user_only=true新用户首次购买 → 成功创建订单
func TestPurchase_NewUserOnly_Success(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(3)
const payID = int64(1)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册(新用户),且没有历史订单
u := insertTestUser(t, db, 300, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "新用户首次购买应成功,返回订单号")
// 验证订单已写入数据库
var o modelOrder.Order
err = db.Where("order_no = ?", resp.OrderNo).First(&o).Error
require.NoError(t, err)
assert.Equal(t, u.Id, o.UserId)
assert.Equal(t, subID, o.SubscribeId)
}
// TestPurchase_NewUserOnly_Disabled 验证new_user_only=false 时,老用户也能正常购买
func TestPurchase_NewUserOnly_Disabled(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(4)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false) // new_user_only = false
insertTestPayment(t, db, payID)
// 注册 30 天的老用户
u := insertTestUser(t, db, 400, time.Now().Add(-30*24*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "new_user_only=false时老用户应能正常购买")
}
// TestPurchase_SingleMode_PendingOldOrderCancelled 验证:单订阅模式下,已有 pending 订单时
// 第二次下单应关闭旧单并创建新单(而非复用旧单)
func TestPurchase_SingleMode_PendingOldOrderCancelled(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
svcCtx.Config.Subscribe.SingleModel = true
const subID = int64(5)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false)
insertTestPayment(t, db, payID)
u := insertTestUser(t, db, 500, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
// 第一次下单pending
logic := NewPurchaseLogic(ctx, svcCtx)
resp1, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp1)
firstOrderNo := resp1.OrderNo
assert.NotEmpty(t, firstOrderNo)
// 确认第一单 pending
var firstOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&firstOrder).Error)
assert.Equal(t, uint8(1), firstOrder.Status, "第一单应为 pending")
// 第二次下单(不同 quantity
logic2 := NewPurchaseLogic(ctx, svcCtx)
resp2, err := logic2.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 3,
})
require.NoError(t, err)
require.NotNil(t, resp2)
secondOrderNo := resp2.OrderNo
// 新单与旧单不同
assert.NotEqual(t, firstOrderNo, secondOrderNo, "第二次下单应创建新订单,不复用旧单")
// 旧单应被关闭status=3
var closedOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", firstOrderNo).First(&closedOrder).Error)
assert.Equal(t, uint8(3), closedOrder.Status, "旧 pending 单应被关闭")
// 新单的 quantity 应为 3
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", secondOrderNo).First(&newOrder).Error)
assert.Equal(t, int64(3), newOrder.Quantity, "新单 quantity 应为 3")
assert.Equal(t, uint8(1), newOrder.Status, "新单应为 pending 状态")
}
// TestPurchase_SingleMode_NoPendingOrder 验证:单订阅模式下,没有旧 pending 单时正常创建
func TestPurchase_SingleMode_NoPendingOrder(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
svcCtx.Config.Subscribe.SingleModel = true
const subID = int64(6)
const payID = int64(1)
insertTestSubscribe(t, db, subID, false)
insertTestPayment(t, db, payID)
u := insertTestUser(t, db, 600, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 2,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo, "无旧 pending 单时应正常创建新单")
var o modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&o).Error)
assert.Equal(t, int64(2), o.Quantity)
assert.Equal(t, uint8(1), o.Status)
}
// TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock 验证new_user_only=true 套餐,
// 24小时内但已购买过 → 允许下单但不享受新人折扣Discount=0Amount=Price
func TestPurchase_NewUserOnly_AlreadyPurchased_NoBlock(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(7)
const payID = int64(1)
// 套餐unit_price=1000discount=[{quantity:1,discount:80,new_user_only:true}]
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册新用户但已有一条成功订单status=2
u := insertTestUser(t, db, 700, time.Now().Add(-1*time.Hour))
insertTestOrder(t, db, u.Id, subID, 2)
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
// 不应被拦截
require.NoError(t, err, "24h内已购用户不应被拦截应允许下单")
require.NotNil(t, resp)
assert.NotEmpty(t, resp.OrderNo)
// 验证订单金额无折扣Amount=Price=1000Discount=0
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000")
assert.Equal(t, int64(1000), newOrder.Amount, "已购用户不享受新人折扣Amount 应等于 Price")
assert.Equal(t, int64(0), newOrder.Discount, "Discount 应为 0")
}
// TestPurchase_NewUserOnly_FirstPurchase_HasDiscount 验证new_user_only=true 套餐,
// 24小时内首次购买 → 允许下单且享受新人折扣
func TestPurchase_NewUserOnly_FirstPurchase_HasDiscount(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const subID = int64(8)
const payID = int64(1)
// 套餐unit_price=1000discount=[{quantity:1,discount:80,new_user_only:true}]8折
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
// 用户 1 小时前注册,无历史订单
u := insertTestUser(t, db, 800, time.Now().Add(-1*time.Hour))
ctx := buildPurchaseCtx(u)
logic := NewPurchaseLogic(ctx, svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Price, "Price 应为原价 1000")
assert.Equal(t, int64(900), newOrder.Amount, "首次购买应享受9折Amount=900")
assert.Equal(t, int64(100), newOrder.Discount, "折扣金额应为 100")
}
func TestPurchase_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(9)
payID = int64(1)
ownerUserID = int64(901)
memberUserID = int64(902)
familyID = int64(99)
)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-72*time.Hour))
insertTestDevice(t, db, memberUserID, "device-eligibility-old", time.Now().Add(-72*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx)
_, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode())
}
func TestPurchase_NewUserOnly_BindEmailScopeSharesHistory(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(10)
payID = int64(1)
ownerUserID = int64(1001)
memberUserID = int64(1002)
familyID = int64(109)
)
insertTestSubscribe(t, db, subID, true)
insertTestPayment(t, db, payID)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-2*time.Hour))
insertTestDevice(t, db, memberUserID, "device-eligibility-shared", time.Now().Add(-2*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
insertScopedTestOrder(t, db, "existing-scope-order", memberUserID, subID, 2)
logic := NewPurchaseLogic(buildPurchaseCtx(owner), svcCtx)
resp, err := logic.Purchase(&types.PurchaseOrderRequest{
SubscribeId: subID,
Payment: payID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
var newOrder modelOrder.Order
require.NoError(t, db.Where("order_no = ?", resp.OrderNo).First(&newOrder).Error)
assert.Equal(t, int64(1000), newOrder.Amount)
assert.Equal(t, int64(0), newOrder.Discount)
}
func TestPreCreateOrder_NewUserOnly_BindEmailScopeUsesEarliestDeviceTime(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(11)
ownerUserID = int64(1101)
memberUserID = int64(1102)
familyID = int64(119)
)
insertTestSubscribe(t, db, subID, true)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour))
insertTestDevice(t, db, memberUserID, "device-precreate-old", time.Now().Add(-96*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx)
_, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{
SubscribeId: subID,
Quantity: 1,
})
require.Error(t, err)
var errCode *xerr.CodeError
require.ErrorAs(t, err, &errCode)
assert.Equal(t, xerr.SubscribeNewUserOnly, errCode.GetErrCode())
}
func TestPreCreateOrder_NewUserOnly_OrdinaryFamilyMemberDoesNotAffectEligibility(t *testing.T) {
db := setupNewUserOnlyDB(t)
rds, mr := setupNewUserOnlyRedis(t)
svcCtx := buildNewUserOnlySvcCtx(db, rds, mr)
const (
subID = int64(12)
ownerUserID = int64(1201)
memberUserID = int64(1202)
familyID = int64(129)
)
insertTestSubscribe(t, db, subID, true)
owner := insertTestUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertTestUser(t, db, memberUserID, time.Now().Add(-96*time.Hour))
insertTestDevice(t, db, memberUserID, "device-precreate-ordinary", time.Now().Add(-96*time.Hour))
insertTestFamily(t, db, familyID, ownerUserID)
insertTestFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertTestFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "manual_invite")
logic := NewPreCreateOrderLogic(buildPurchaseCtx(owner), svcCtx)
resp, err := logic.PreCreateOrder(&types.PurchaseOrderRequest{
SubscribeId: subID,
Quantity: 1,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, int64(900), resp.Amount)
assert.Equal(t, int64(100), resp.Discount)
}

View File

@ -54,6 +54,17 @@ func (l *RenewalLogic) Renewal(req *types.RenewalOrderRequest) (resp *types.Rene
if entErr != nil { if entErr != nil {
return nil, entErr return nil, entErr
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_create_start",
"[SubscriptionFlow] renewal order creation started",
logger.Field("order_kind", "renewal"),
logger.Field("user_id", u.Id),
logger.Field("effective_user_id", entitlement.EffectiveUserID),
logger.Field("requested_user_subscribe_id", req.UserSubscribeID),
logger.Field("quantity", req.Quantity),
logger.Field("payment_id", req.Payment),
logger.Field("coupon", req.Coupon),
)
if req.Quantity <= 0 { if req.Quantity <= 0 {
l.Debugf("[Renewal] Quantity is less than or equal to 0, setting to 1") l.Debugf("[Renewal] Quantity is less than or equal to 0, setting to 1")
req.Quantity = 1 req.Quantity = 1
@ -235,6 +246,14 @@ func (l *RenewalLogic) Renewal(req *types.RenewalOrderRequest) (resp *types.Rene
l.Errorw("[Renewal] Database insert error", logger.Field("error", err.Error()), logger.Field("order", orderInfo)) l.Errorw("[Renewal] Database insert error", logger.Field("error", err.Error()), logger.Field("order", orderInfo))
return nil, errors.Wrapf(err, "insert order error: %v", err.Error()) return nil, errors.Wrapf(err, "insert order error: %v", err.Error())
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "order_created",
"[SubscriptionFlow] renewal order persisted",
append(commonLogic.OrderTraceFields(&orderInfo),
logger.Field("requested_user_subscribe_id", req.UserSubscribeID),
logger.Field("resolved_user_subscribe_id", userSubscribe.Id),
)...,
)
// Deferred task // Deferred task
payload := queue.DeferCloseOrderPayload{ payload := queue.DeferCloseOrderPayload{
OrderNo: orderInfo.OrderNo, OrderNo: orderInfo.OrderNo,

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/report" "github.com/perfect-panel/server/internal/report"
"github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/constant"
@ -75,6 +76,14 @@ func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest
l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method)) l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error()) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error())
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "checkout_start",
"[SubscriptionFlow] checkout started",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", paymentConfig.Platform),
logger.Field("has_return_url", req.ReturnUrl != ""),
)...,
)
// Route to appropriate payment handler based on payment platform // Route to appropriate payment handler based on payment platform
switch paymentPlatform.ParsePlatform(orderInfo.Method) { switch paymentPlatform.ParsePlatform(orderInfo.Method) {
case paymentPlatform.AppleIAP: case paymentPlatform.AppleIAP:
@ -83,6 +92,14 @@ func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest
Type: "apple_iap", Type: "apple_iap",
ProductIds: []string{productId}, ProductIds: []string{productId},
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "checkout_response_ready",
"[SubscriptionFlow] checkout response prepared",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", paymentConfig.Platform),
logger.Field("checkout_type", resp.Type),
logger.Field("product_ids", resp.ProductIds),
)...,
)
return resp, nil return resp, nil
case paymentPlatform.EPay: case paymentPlatform.EPay:
// Process EPay payment - generates payment URL for redirect // Process EPay payment - generates payment URL for redirect
@ -157,6 +174,16 @@ func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest
l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method)) l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found") return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found")
} }
if resp != nil {
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "checkout_response_ready",
"[SubscriptionFlow] checkout response prepared",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("payment_platform", paymentConfig.Platform),
logger.Field("checkout_type", resp.Type),
)...,
)
}
return return
} }
@ -503,6 +530,9 @@ func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount
func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error { func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error {
var userInfo user.User var userInfo user.User
var err error var err error
var giftUsed int64
var balanceUsed int64
paymentPath := "balance"
if o.Amount == 0 { if o.Amount == 0 {
// No payment required for zero-amount orders // No payment required for zero-amount orders
l.Logger.Info( l.Logger.Info(
@ -518,6 +548,13 @@ func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) err
logger.Field("userId", u.Id)) logger.Field("userId", u.Id))
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update order status error: %s", err.Error()) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update order status error: %s", err.Error())
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] order marked paid without external payment",
append(commonLogic.OrderTraceFields(o),
logger.Field("payment_path", "zero_amount"),
)...,
)
paymentPath = "zero_amount"
goto activation goto activation
} }
@ -536,7 +573,6 @@ func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) err
} }
// Calculate payment distribution: prioritize gift amount first // Calculate payment distribution: prioritize gift amount first
var giftUsed, balanceUsed int64
remainingAmount := o.Amount remainingAmount := o.Amount
if userInfo.GiftAmount >= remainingAmount { if userInfo.GiftAmount >= remainingAmount {
@ -621,6 +657,15 @@ func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) err
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "payment_settled",
"[SubscriptionFlow] balance payment settled and order marked paid",
append(commonLogic.OrderTraceFields(o),
logger.Field("payment_path", "balance"),
logger.Field("gift_used", giftUsed),
logger.Field("balance_used", balanceUsed),
)...,
)
activation: activation:
// Enqueue order activation task for immediate processing // Enqueue order activation task for immediate processing
payload := queueType.ForthwithActivateOrderPayload{ payload := queueType.ForthwithActivateOrderPayload{
@ -639,6 +684,13 @@ activation:
return err return err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowOrder, "activation_task_enqueued",
"[SubscriptionFlow] activation task enqueued after checkout payment",
append(commonLogic.OrderTraceFields(o),
logger.Field("payment_path", paymentPath),
)...,
)
l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully", l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully",
logger.Field("orderNo", o.OrderNo), logger.Field("orderNo", o.OrderNo),
logger.Field("userId", u.Id)) logger.Field("userId", u.Id))

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/config"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/internal/types"
@ -43,6 +44,12 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "bind_start",
"[SubscriptionFlow] email bind with verification started",
logger.Field("device_user_id", u.Id),
logger.Field("email", req.Email),
)
type payload struct { type payload struct {
Code string `json:"code"` Code string `json:"code"`
LastAt int64 `json:"lastAt"` LastAt int64 `json:"lastAt"`
@ -69,6 +76,12 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi
return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error or expired") return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error or expired")
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "bind_code_verified",
"[SubscriptionFlow] email verification code accepted",
logger.Field("device_user_id", u.Id),
logger.Field("email", req.Email),
)
familyHelper := newFamilyBindingHelper(l.ctx, l.svcCtx) familyHelper := newFamilyBindingHelper(l.ctx, l.svcCtx)
currentEmailMethod, err := familyHelper.getUserEmailMethod(u.Id) currentEmailMethod, err := familyHelper.getUserEmailMethod(u.Id)
if err != nil { if err != nil {
@ -115,6 +128,13 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi
return nil, txErr return nil, txErr
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "email_owner_created",
"[SubscriptionFlow] new email owner account created for bind flow",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", emailUser.Id),
logger.Field("email", req.Email),
)
// Join family: email user as owner, device user as member // Join family: email user as owner, device user as member
if err = familyHelper.validateJoinFamily(emailUser.Id, u.Id); err != nil { if err = familyHelper.validateJoinFamily(emailUser.Id, u.Id); err != nil {
return nil, err return nil, err
@ -123,11 +143,32 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi
if err != nil { if err != nil {
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "family_joined",
"[SubscriptionFlow] device user joined email owner family",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", emailUser.Id),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
token, err := l.refreshBindSessionToken(u.Id) token, err := l.refreshBindSessionToken(u.Id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_requested",
"[SubscriptionFlow] evaluating trial grant after email bind",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", emailUser.Id),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, emailUser.Id, req.Email) tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, emailUser.Id, req.Email)
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "bind_complete",
"[SubscriptionFlow] email bind with verification completed",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", emailUser.Id),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
return &types.BindEmailWithVerificationResponse{ return &types.BindEmailWithVerificationResponse{
Success: true, Success: true,
Message: "email user created and joined family", Message: "email user created and joined family",
@ -146,16 +187,44 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi
if err = familyHelper.validateJoinFamily(existingMethod.UserId, u.Id); err != nil { if err = familyHelper.validateJoinFamily(existingMethod.UserId, u.Id); err != nil {
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "email_owner_resolved",
"[SubscriptionFlow] existing email owner resolved for bind flow",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", existingMethod.UserId),
logger.Field("email", req.Email),
)
joinResult, err := familyHelper.joinFamily(existingMethod.UserId, u.Id, "bind_email_with_verification") joinResult, err := familyHelper.joinFamily(existingMethod.UserId, u.Id, "bind_email_with_verification")
if err != nil { if err != nil {
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "family_joined",
"[SubscriptionFlow] device user joined existing email owner family",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", existingMethod.UserId),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
token, err := l.refreshBindSessionToken(u.Id) token, err := l.refreshBindSessionToken(u.Id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_requested",
"[SubscriptionFlow] evaluating trial grant after existing email owner bind",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", existingMethod.UserId),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, existingMethod.UserId, req.Email) tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, existingMethod.UserId, req.Email)
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "bind_complete",
"[SubscriptionFlow] email bind with verification completed",
logger.Field("device_user_id", u.Id),
logger.Field("owner_user_id", existingMethod.UserId),
logger.Field("family_id", joinResult.FamilyId),
logger.Field("email", req.Email),
)
return &types.BindEmailWithVerificationResponse{ return &types.BindEmailWithVerificationResponse{
Success: true, Success: true,

View File

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/logic/auth"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/logger"
@ -14,11 +15,28 @@ import (
func tryGrantTrialOnEmailBind(ctx context.Context, svcCtx *svc.ServiceContext, log logger.Logger, ownerUserId int64, email string) { func tryGrantTrialOnEmailBind(ctx context.Context, svcCtx *svc.ServiceContext, log logger.Logger, ownerUserId int64, email string) {
rc := svcCtx.Config.Register rc := svcCtx.Config.Register
if !auth.ShouldGrantTrialForEmail(rc, email) { commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_evaluating",
if rc.EnableTrial && rc.EnableTrialEmailWhitelist { "[SubscriptionFlow] evaluating email bind trial grant",
log.Infow("email domain not in trial whitelist, skip", logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email),
logger.Field("trial_subscribe_id", rc.TrialSubscribe),
)
if !auth.ShouldAutoGrantTrialOnPublicEmailFlows(rc) {
commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_skipped",
"[SubscriptionFlow] auto trial on public email flow disabled",
logger.Field("email", email), logger.Field("email", email),
logger.Field("owner_user_id", ownerUserId), logger.Field("owner_user_id", ownerUserId),
logger.Field("skip_reason", "public_email_trial_disabled"),
)
return
}
if !auth.ShouldGrantTrialForEmail(rc, email) {
if rc.EnableTrial && rc.EnableTrialEmailWhitelist {
commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_skipped",
"[SubscriptionFlow] email domain not in trial whitelist",
logger.Field("email", email),
logger.Field("owner_user_id", ownerUserId),
logger.Field("skip_reason", "trial_whitelist_rejected"),
) )
} }
return return
@ -29,19 +47,45 @@ func tryGrantTrialOnEmailBind(ctx context.Context, svcCtx *svc.ServiceContext, l
Model(&user.Subscribe{}). Model(&user.Subscribe{}).
Where("user_id = ? AND subscribe_id = ?", ownerUserId, rc.TrialSubscribe). Where("user_id = ? AND subscribe_id = ?", ownerUserId, rc.TrialSubscribe).
Count(&count).Error; err != nil { Count(&count).Error; err != nil {
log.Errorw("failed to check existing trial", logger.Field("error", err.Error())) commonLogic.SubscriptionTraceError(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_error",
"[SubscriptionFlow] failed to query existing trial subscription",
logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email),
logger.Field("error", err.Error()),
)
return return
} }
if count > 0 { if count > 0 {
log.Infow("trial already granted, skip", commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_skipped",
"[SubscriptionFlow] trial already exists for owner",
logger.Field("owner_user_id", ownerUserId), logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email),
logger.Field("skip_reason", "trial_already_exists"),
)
return
}
// Cross-user check: prevent the same real inbox (via dot trick / + alias) from
// getting multiple trials across different accounts.
if auth.NormalizedEmailHasTrial(ctx, svcCtx.DB, email, rc.TrialSubscribe) {
commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_skipped",
"[SubscriptionFlow] normalized email already received a trial elsewhere",
logger.Field("email", email),
logger.Field("owner_user_id", ownerUserId),
logger.Field("skip_reason", "normalized_email_has_trial"),
) )
return return
} }
sub, err := svcCtx.SubscribeModel.FindOne(ctx, rc.TrialSubscribe) sub, err := svcCtx.SubscribeModel.FindOne(ctx, rc.TrialSubscribe)
if err != nil { if err != nil {
log.Errorw("failed to find trial subscribe template", logger.Field("error", err.Error())) commonLogic.SubscriptionTraceError(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_error",
"[SubscriptionFlow] failed to load trial subscription template",
logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email),
logger.Field("trial_subscribe_id", rc.TrialSubscribe),
logger.Field("error", err.Error()),
)
return return
} }
@ -59,9 +103,13 @@ func tryGrantTrialOnEmailBind(ctx context.Context, svcCtx *svc.ServiceContext, l
Status: 1, Status: 1,
} }
if err = svcCtx.UserModel.InsertSubscribe(ctx, userSub); err != nil { if err = svcCtx.UserModel.InsertSubscribe(ctx, userSub); err != nil {
log.Errorw("failed to insert trial subscribe", commonLogic.SubscriptionTraceError(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_error",
"[SubscriptionFlow] failed to create trial subscription for email bind",
append(commonLogic.UserSubscribeTraceFields(userSub),
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
logger.Field("owner_user_id", ownerUserId), logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email),
)...,
) )
return return
} }
@ -72,9 +120,12 @@ func tryGrantTrialOnEmailBind(ctx context.Context, svcCtx *svc.ServiceContext, l
} }
} }
log.Infow("trial granted on email bind", commonLogic.SubscriptionTraceInfo(log, commonLogic.SubscriptionTraceFlowEmailBind, "trial_grant_succeeded",
"[SubscriptionFlow] trial subscription granted after email bind",
append(commonLogic.UserSubscribeTraceFields(userSub),
logger.Field("owner_user_id", ownerUserId), logger.Field("owner_user_id", ownerUserId),
logger.Field("email", email), logger.Field("email", email),
logger.Field("subscribe_id", sub.Id), logger.Field("trial_subscribe_id", sub.Id),
)...,
) )
} }

View File

@ -209,10 +209,9 @@ func transferMemberSubscribesToOwner(tx *gorm.DB, memberUserId, ownerUserId int6
if len(subscribes) == 0 { if len(subscribes) == 0 {
return nil, nil return nil, nil
} }
if err := tx.Model(&user.Subscribe{}). // 加入家庭组时,无条件丢弃成员的所有订阅(软删除)
Where("user_id = ?", memberUserId). if err := tx.Where("user_id = ?", memberUserId).Delete(&user.Subscribe{}).Error; err != nil {
Update("user_id", ownerUserId).Error; err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "discard member subscribes failed")
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "transfer member subscribes to owner failed")
} }
return subscribes, nil return subscribes, nil
} }

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/config"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/internal/types"
@ -39,6 +40,10 @@ type CacheKeyPayload struct {
func (l *VerifyEmailLogic) VerifyEmail(req *types.VerifyEmailRequest) error { func (l *VerifyEmailLogic) VerifyEmail(req *types.VerifyEmailRequest) error {
req.Email = strings.ToLower(strings.TrimSpace(req.Email)) req.Email = strings.ToLower(strings.TrimSpace(req.Email))
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "verify_email_start",
"[SubscriptionFlow] email verification started",
logger.Field("email", req.Email),
)
cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Security, req.Email) cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Security, req.Email)
value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result() value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result()
if err != nil { if err != nil {
@ -59,6 +64,10 @@ func (l *VerifyEmailLogic) VerifyEmail(req *types.VerifyEmailRequest) error {
return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code expired") return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code expired")
} }
l.svcCtx.Redis.Del(l.ctx, cacheKey) l.svcCtx.Redis.Del(l.ctx, cacheKey)
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "verify_email_code_verified",
"[SubscriptionFlow] email verification code accepted",
logger.Field("email", req.Email),
)
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
if !ok { if !ok {
@ -77,6 +86,12 @@ func (l *VerifyEmailLogic) VerifyEmail(req *types.VerifyEmailRequest) error {
if err != nil { if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "UpdateUserAuthMethods error") return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "UpdateUserAuthMethods error")
} }
commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "verify_email_completed",
"[SubscriptionFlow] email verification completed and trial evaluation will run",
logger.Field("user_id", u.Id),
logger.Field("owner_user_id", method.UserId),
logger.Field("email", req.Email),
)
tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, method.UserId, req.Email) tryGrantTrialOnEmailBind(l.ctx, l.svcCtx, l.Logger, method.UserId, req.Email)
return nil return nil
} }

View File

@ -194,7 +194,7 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR
val, _ := json.Marshal(resp) val, _ := json.Marshal(resp)
etag := tool.GenerateETag(val) etag := tool.GenerateETag(val)
l.ctx.Header("ETag", etag) l.ctx.Header("ETag", etag)
err = l.svcCtx.Redis.Set(l.ctx, cacheKey, string(val), -1).Err() err = l.svcCtx.Redis.Set(l.ctx, cacheKey, string(val), l.serverUserListCacheTTL()).Err()
if err != nil { if err != nil {
l.Errorw("[ServerUserListCacheKey] redis set error", logger.Field("error", err.Error())) l.Errorw("[ServerUserListCacheKey] redis set error", logger.Field("error", err.Error()))
} }
@ -205,6 +205,18 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR
return resp, nil return resp, nil
} }
func (l *GetServerUserListLogic) serverUserListCacheTTL() time.Duration {
pullInterval := l.svcCtx.Config.Node.NodePullInterval
if pullInterval <= 0 {
pullInterval = 60
}
ttl := time.Duration(pullInterval*2) * time.Second
if ttl < time.Minute {
return time.Minute
}
return ttl
}
func (l *GetServerUserListLogic) shouldIncludeServerUser(userSub *user.Subscribe, serverNodeGroupIds []int64) bool { func (l *GetServerUserListLogic) shouldIncludeServerUser(userSub *user.Subscribe, serverNodeGroupIds []int64) bool {
if userSub == nil { if userSub == nil {
return false return false
@ -295,6 +307,15 @@ func (l *GetServerUserListLogic) canUseExpiredNodeGroup(userSub *user.Subscribe,
// calculateEffectiveSpeedLimit 计算用户的实际限速值(考虑按量限速规则) // calculateEffectiveSpeedLimit 计算用户的实际限速值(考虑按量限速规则)
func (l *GetServerUserListLogic) calculateEffectiveSpeedLimit(sub *subscribe.Subscribe, userSub *user.Subscribe) int64 { func (l *GetServerUserListLogic) calculateEffectiveSpeedLimit(sub *subscribe.Subscribe, userSub *user.Subscribe) int64 {
result := speedlimit.Calculate(l.ctx.Request.Context(), l.svcCtx.DB, userSub.UserId, userSub.Id, sub.SpeedLimit, sub.TrafficLimit) result := speedlimit.CalculateWithCache(
l.ctx.Request.Context(),
l.svcCtx.Redis,
l.svcCtx.DB,
userSub.UserId,
userSub.Id,
sub.SpeedLimit,
sub.TrafficLimit,
30*time.Second,
)
return result.EffectiveSpeed return result.EffectiveSpeed
} }

View File

@ -166,12 +166,11 @@ func (m *customOrderModel) QueryMonthlyOrders(ctx context.Context, date time.Tim
// QueryDateOrders Query orders by date // QueryDateOrders Query orders by date
func (m *customOrderModel) QueryDateOrders(ctx context.Context, date time.Time) (OrdersTotal, error) { func (m *customOrderModel) QueryDateOrders(ctx context.Context, date time.Time) (OrdersTotal, error) {
start := date.Truncate(24 * time.Hour) dateStr := date.Format("2006-01-02")
end := start.Add(24 * time.Hour).Add(-time.Nanosecond)
var result OrdersTotal var result OrdersTotal
err := m.QueryNoCacheCtx(ctx, &result, func(conn *gorm.DB, v interface{}) error { err := m.QueryNoCacheCtx(ctx, &result, func(conn *gorm.DB, v interface{}) error {
return conn.Model(&Order{}). return conn.Model(&Order{}).
Where("status IN ? AND created_at BETWEEN ? AND ? AND method != ?", []int64{2, 5}, start, end, "balance"). Where("status IN ? AND DATE_FORMAT(created_at, '%Y-%m-%d') = ? AND method != ?", []int64{2, 5}, dateStr, "balance").
Select( Select(
"SUM(amount) as amount_total, " + "SUM(amount) as amount_total, " +
"SUM(CASE WHEN is_new = 1 THEN amount ELSE 0 END) as new_order_amount, " + "SUM(CASE WHEN is_new = 1 THEN amount ELSE 0 END) as new_order_amount, " +
@ -222,10 +221,7 @@ func (m *customOrderModel) QueryMonthlyUserCounts(ctx context.Context, date time
return counts.NewUsers, counts.RenewalUsers, err return counts.NewUsers, counts.RenewalUsers, err
} }
func (m *customOrderModel) QueryDateUserCounts(ctx context.Context, date time.Time) (int64, int64, error) { func (m *customOrderModel) QueryDateUserCounts(ctx context.Context, date time.Time) (int64, int64, error) {
// 当天 00:00:00 dateStr := date.Format("2006-01-02")
start := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
// 下一天 00:00:00
nextDay := start.Add(24 * time.Hour)
var counts UserCounts var counts UserCounts
@ -235,8 +231,8 @@ func (m *customOrderModel) QueryDateUserCounts(ctx context.Context, date time.Ti
COUNT(DISTINCT CASE WHEN is_new = 1 THEN user_id END) AS new_users, COUNT(DISTINCT CASE WHEN is_new = 1 THEN user_id END) AS new_users,
COUNT(DISTINCT CASE WHEN is_new = 0 THEN user_id END) AS renewal_users COUNT(DISTINCT CASE WHEN is_new = 0 THEN user_id END) AS renewal_users
`). `).
Where("status IN ? AND created_at >= ? AND created_at < ? AND method != ?", Where("status IN ? AND DATE_FORMAT(created_at, '%Y-%m-%d') = ? AND method != ?",
[]int64{2, 5}, start, nextDay, "balance"). []int64{2, 5}, dateStr, "balance").
Scan(&counts).Error Scan(&counts).Error
}) })
@ -276,7 +272,7 @@ func (m *customOrderModel) QueryDailyOrdersList(ctx context.Context, date time.T
// 当月 1 号 00:00:00 // 当月 1 号 00:00:00
firstDay := time.Date(date.Year(), date.Month(), 1, 0, 0, 0, 0, date.Location()) firstDay := time.Date(date.Year(), date.Month(), 1, 0, 0, 0, 0, date.Location())
// 第二天 00:00:00 // 第二天 00:00:00
nextDay := date.AddDate(0, 0, 1).Truncate(24 * time.Hour) nextDay := time.Date(date.Year(), date.Month(), date.Day()+1, 0, 0, 0, 0, date.Location())
return conn.Model(&Order{}). return conn.Model(&Order{}).
Select(` Select(`

View File

@ -27,8 +27,8 @@ func NewModel(conn *gorm.DB) Model {
func (m *customTrafficModel) QueryServerTrafficByDay(ctx context.Context, serverId int64, date time.Time) (*TotalTraffic, error) { func (m *customTrafficModel) QueryServerTrafficByDay(ctx context.Context, serverId int64, date time.Time) (*TotalTraffic, error) {
var data TotalTraffic var data TotalTraffic
start := date.Truncate(24 * time.Hour) start := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
end := start.Add(24 * time.Hour).Add(-time.Nanosecond) end := start.AddDate(0, 0, 1).Add(-time.Nanosecond)
err := m.Conn.WithContext(ctx).Model(&TrafficLog{}). err := m.Conn.WithContext(ctx).Model(&TrafficLog{}).
Select("sum(download) as download, sum(upload) as upload"). Select("sum(download) as download, sum(upload) as upload").
Where("server_id = ? AND timestamp BETWEEN ? AND ?", serverId, start, end). Where("server_id = ? AND timestamp BETWEEN ? AND ?", serverId, start, end).

View File

@ -314,8 +314,8 @@ func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id
func (m *customUserModel) QueryResisterUserTotalByDate(ctx context.Context, date time.Time) (int64, error) { func (m *customUserModel) QueryResisterUserTotalByDate(ctx context.Context, date time.Time) (int64, error) {
var total int64 var total int64
start := date.Truncate(24 * time.Hour) start := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, date.Location())
end := start.Add(24 * time.Hour).Add(-time.Second) end := start.AddDate(0, 0, 1).Add(-time.Second)
err := m.QueryNoCacheCtx(ctx, &total, func(conn *gorm.DB, v interface{}) error { err := m.QueryNoCacheCtx(ctx, &total, func(conn *gorm.DB, v interface{}) error {
return conn.Model(&User{}).Where("created_at > ? and created_at < ?", start, end).Count(&total).Error return conn.Model(&User{}).Where("created_at > ? and created_at < ?", start, end).Count(&total).Error
}) })
@ -447,7 +447,8 @@ func (m *customUserModel) QueryMonthlyUserStatisticsList(ctx context.Context, da
func (m *customUserModel) FindActiveSubscribe(ctx context.Context, userId int64) (*Subscribe, error) { func (m *customUserModel) FindActiveSubscribe(ctx context.Context, userId int64) (*Subscribe, error) {
var subscribe Subscribe var subscribe Subscribe
err := m.QueryNoCacheCtx(ctx, &subscribe, func(conn *gorm.DB, v interface{}) error { err := m.QueryNoCacheCtx(ctx, &subscribe, func(conn *gorm.DB, v interface{}) error {
return conn.Where("user_id = ? AND status IN (0, 1) AND expire_time > ?", userId, time.Now()). now := time.Now()
return conn.Where("user_id = ? AND status IN (0, 1) AND (expire_time > ? OR expire_time = ?)", userId, now, time.UnixMilli(0)).
Order("expire_time DESC"). Order("expire_time DESC").
First(v).Error First(v).Error
}) })

View File

@ -136,6 +136,7 @@ type Device struct {
UserAgent string `gorm:"default:null;comment:UserAgent."` UserAgent string `gorm:"default:null;comment:UserAgent."`
Identifier string `gorm:"type:varchar(255);unique;index:idx_identifier;default:'';comment:Device Identifier"` Identifier string `gorm:"type:varchar(255);unique;index:idx_identifier;default:'';comment:Device Identifier"`
ShortCode string `gorm:"type:varchar(255);default:'';comment:Short Code"` ShortCode string `gorm:"type:varchar(255);default:'';comment:Short Code"`
BasePayload string `gorm:"type:text;default:null;comment:Base Payload"`
Online bool `gorm:"default:false;not null;comment:Online"` Online bool `gorm:"default:false;not null;comment:Online"`
Enabled bool `gorm:"default:true;not null;comment:Enabled"` Enabled bool `gorm:"default:true;not null;comment:Enabled"`
CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"`

View File

@ -36,9 +36,13 @@ func NewService(svc *svc.ServiceContext) *Service {
} }
func initServer(svc *svc.ServiceContext) *gin.Engine { func initServer(svc *svc.ServiceContext) *gin.Engine {
// start init system config // start init system config
initStart := time.Now()
logger.Info("system initialization start")
initialize.StartInitSystemConfig(svc) initialize.StartInitSystemConfig(svc)
logger.Infow("system initialization complete",
logger.Field("duration", time.Since(initStart).String()),
)
// init gin server // init gin server
r := gin.Default() r := gin.Default()
r.RemoteIPHeaders = []string{"X-Original-Forwarded-For", "X-Forwarded-For", "X-Real-IP"} r.RemoteIPHeaders = []string{"X-Original-Forwarded-For", "X-Forwarded-For", "X-Real-IP"}

View File

@ -642,6 +642,7 @@ type DeviceLoginRequest struct {
UserAgent string `json:"user_agent" validate:"required"` UserAgent string `json:"user_agent" validate:"required"`
CfToken string `json:"cf_token,optional"` CfToken string `json:"cf_token,optional"`
ShortCode string `json:"short_code,optional"` ShortCode string `json:"short_code,optional"`
BasePayload string `json:"base_payload,optional"`
} }
type DissolveFamilyRequest struct { type DissolveFamilyRequest struct {
@ -2655,6 +2656,7 @@ type SubscribeDiscount struct {
Quantity int64 `json:"quantity"` Quantity int64 `json:"quantity"`
Discount float64 `json:"discount"` Discount float64 `json:"discount"`
NewUserOnly bool `json:"new_user_only"` NewUserOnly bool `json:"new_user_only"`
MapApple string `json:"map_apple"`
} }
type SubscribeGroup struct { type SubscribeGroup struct {

View File

@ -9,6 +9,7 @@ import (
) )
type GormLogger struct { type GormLogger struct {
SlowThreshold time.Duration
} }
const TAG = "[GORM]" const TAG = "[GORM]"
@ -27,24 +28,25 @@ func (l *GormLogger) LogMode(logger.LogLevel) logger.Interface {
default: default:
sysLevel = "unknown" sysLevel = "unknown"
} }
Infof("%s System Log Level is %s", TAG, sysLevel) Debugf("%s System Log Level is %s", TAG, sysLevel)
return l return l
} }
func (l *GormLogger) Info(ctx context.Context, str string, args ...interface{}) { func (l *GormLogger) Info(ctx context.Context, str string, args ...interface{}) {
WithContext(ctx).WithCallerSkip(6).Infof("%s Info: %s", TAG, str, args) WithContext(ctx).WithCallerSkip(6).Debugf("%s Info: %s", TAG, fmt.Sprintf(str, args...))
} }
func (l *GormLogger) Warn(ctx context.Context, str string, args ...interface{}) { func (l *GormLogger) Warn(ctx context.Context, str string, args ...interface{}) {
WithContext(ctx).WithCallerSkip(6).Infof("%s Warn: %s", TAG, str, args) WithContext(ctx).WithCallerSkip(6).Debugf("%s Warn: %s", TAG, fmt.Sprintf(str, args...))
} }
func (l *GormLogger) Error(ctx context.Context, str string, args ...interface{}) { func (l *GormLogger) Error(ctx context.Context, str string, args ...interface{}) {
WithContext(ctx).WithCallerSkip(6).Errorf("%s Error: %s", TAG, str, args) WithContext(ctx).WithCallerSkip(6).Errorf("%s Error: %s", TAG, fmt.Sprintf(str, args...))
} }
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, rowsAffected := fc() sql, rowsAffected := fc()
duration := time.Since(begin)
fields := []LogField{ fields := []LogField{
{ {
Key: "sql", Key: "sql",
@ -60,8 +62,16 @@ func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql
Key: "error", Key: "error",
Value: err.Error(), Value: err.Error(),
}) })
WithContext(ctx).WithCallerSkip(6).WithDuration(time.Since(begin)).Errorw(TAG, fields...) WithContext(ctx).WithCallerSkip(6).WithDuration(duration).Errorw(TAG, fields...)
} else { return
WithContext(ctx).WithCallerSkip(6).WithDuration(time.Since(begin)).Infow(fmt.Sprintf("%s SQL Executed", TAG), fields...) }
if l.SlowThreshold > 0 && duration >= l.SlowThreshold {
WithContext(ctx).WithCallerSkip(6).WithDuration(duration).Sloww(fmt.Sprintf("%s SQL Slow", TAG), fields...)
return
}
if shallLog(DebugLevel) {
WithContext(ctx).WithCallerSkip(6).WithDuration(duration).Debugw(fmt.Sprintf("%s SQL Executed", TAG), fields...)
} }
} }

View File

@ -46,7 +46,9 @@ func ConnectMysql(m Mysql) (*gorm.DB, error) {
DSN: m.Dsn(), DSN: m.Dsn(),
} }
db, err := gorm.Open(mysql.New(mysqlCfg), &gorm.Config{ db, err := gorm.Open(mysql.New(mysqlCfg), &gorm.Config{
Logger: new(logger.GormLogger), Logger: &logger.GormLogger{
SlowThreshold: m.GetSlowThreshold(),
},
NamingStrategy: schema.NamingStrategy{ NamingStrategy: schema.NamingStrategy{
SingularTable: true, SingularTable: true,
}, },

View File

@ -2,10 +2,13 @@ package speedlimit
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -28,6 +31,37 @@ type ThrottleResult struct {
ThrottleEnd int64 `json:"throttle_end"` // Window end Unix timestamp (seconds), 0 if not throttled ThrottleEnd int64 `json:"throttle_end"` // Window end Unix timestamp (seconds), 0 if not throttled
} }
// CalculateWithCache computes the effective speed limit with a short Redis cache.
// It is intended for hot read paths such as node user-list pulls where many nodes
// can ask for the same subscription limits in a short period.
func CalculateWithCache(ctx context.Context, cache *redis.Client, db *gorm.DB, userId, subscribeId, baseSpeedLimit int64, trafficLimitJSON string, ttl time.Duration) *ThrottleResult {
if cache == nil || ttl <= 0 || trafficLimitJSON == "" {
return Calculate(ctx, db, userId, subscribeId, baseSpeedLimit, trafficLimitJSON)
}
key := cacheKey(userId, subscribeId, baseSpeedLimit, trafficLimitJSON)
if cached, err := cache.Get(ctx, key).Result(); err == nil && cached != "" {
var result ThrottleResult
if err := json.Unmarshal([]byte(cached), &result); err == nil {
return &result
}
}
result := Calculate(ctx, db, userId, subscribeId, baseSpeedLimit, trafficLimitJSON)
if payload, err := json.Marshal(result); err == nil {
_ = cache.Set(ctx, key, string(payload), ttl).Err()
}
return result
}
// ClearCache removes a cached speed-limit calculation for a user subscription.
func ClearCache(ctx context.Context, cache *redis.Client, userId, subscribeId, baseSpeedLimit int64, trafficLimitJSON string) error {
if cache == nil || trafficLimitJSON == "" {
return nil
}
return cache.Del(ctx, cacheKey(userId, subscribeId, baseSpeedLimit, trafficLimitJSON)).Err()
}
// Calculate computes the effective speed limit for a user subscription, // Calculate computes the effective speed limit for a user subscription,
// considering traffic-based throttling rules. // considering traffic-based throttling rules.
func Calculate(ctx context.Context, db *gorm.DB, userId, subscribeId, baseSpeedLimit int64, trafficLimitJSON string) *ThrottleResult { func Calculate(ctx context.Context, db *gorm.DB, userId, subscribeId, baseSpeedLimit int64, trafficLimitJSON string) *ThrottleResult {
@ -107,3 +141,8 @@ func Calculate(ctx context.Context, db *gorm.DB, userId, subscribeId, baseSpeedL
return result return result
} }
func cacheKey(userId, subscribeId, baseSpeedLimit int64, trafficLimitJSON string) string {
sum := sha256.Sum256([]byte(trafficLimitJSON))
return fmt.Sprintf("speedlimit:%d:%d:%d:%s", userId, subscribeId, baseSpeedLimit, hex.EncodeToString(sum[:8]))
}

View File

@ -1,6 +1,21 @@
package main package main
import "github.com/perfect-panel/server/cmd" import (
"time"
"github.com/perfect-panel/server/cmd"
)
func init() {
// Ensure time.Local matches DSN loc=Asia/Shanghai.
// In scratch Docker images, LoadLocation may fail due to missing zoneinfo,
// so fall back to FixedZone as a guaranteed alternative.
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil {
loc = time.FixedZone("CST", 8*3600)
}
time.Local = loc
}
func main() { func main() {
cmd.Execute() cmd.Execute()

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/logic/admin/group"
commonLogic "github.com/perfect-panel/server/internal/logic/common"
"github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/logger"
@ -27,6 +28,7 @@ import (
"github.com/perfect-panel/server/pkg/uuidx" "github.com/perfect-panel/server/pkg/uuidx"
queueTypes "github.com/perfect-panel/server/queue/types" queueTypes "github.com/perfect-panel/server/queue/types"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
// Order type constants define the different types of orders that can be processed // Order type constants define the different types of orders that can be processed
@ -44,6 +46,7 @@ const (
OrderStatusPaid = 2 // Order paid and ready for processing OrderStatusPaid = 2 // Order paid and ready for processing
OrderStatusClose = 3 // Order closed/cancelled OrderStatusClose = 3 // Order closed/cancelled
OrderStatusFailed = 4 // Order processing failed OrderStatusFailed = 4 // Order processing failed
OrderStatusClaimed = 4 // Internal transient claim while a worker processes the order
OrderStatusFinished = 5 // Order successfully completed OrderStatusFinished = 5 // Order successfully completed
) )
@ -69,8 +72,10 @@ func NewActivateOrderLogic(svc *svc.ServiceContext) *ActivateOrderLogic {
// It handles the complete workflow of activating a paid order including validation, // It handles the complete workflow of activating a paid order including validation,
// processing based on order type, and finalization. // processing based on order type, and finalization.
func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) error { func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) error {
logger.WithContext(ctx).Info("[ActivateOrderLogic] 开始处理订单激活任务", commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "activation_task_received",
logger.Field("payload", string(task.Payload()))) "[SubscriptionFlow] activation task received",
logger.Field("payload", string(task.Payload())),
)
payload, err := l.parsePayload(ctx, task.Payload()) payload, err := l.parsePayload(ctx, task.Payload())
if err != nil { if err != nil {
@ -79,10 +84,13 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task)
return nil // payload 解析失败不重试,因为重试也会失败 return nil // payload 解析失败不重试,因为重试也会失败
} }
logger.WithContext(ctx).Info("[ActivateOrderLogic] 正在验证订单", commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "activation_order_lookup",
logger.Field("order_no", payload.OrderNo)) "[SubscriptionFlow] activation task is loading order",
logger.Field("order_no", payload.OrderNo),
logger.Field("iap_expire_at", payload.IAPExpireAt),
)
orderInfo, err := l.validateAndGetOrder(ctx, payload.OrderNo) orderInfo, err := l.claimAndGetOrder(ctx, payload.OrderNo)
if err != nil { if err != nil {
// 如果订单不存在或状态不对,不重试 // 如果订单不存在或状态不对,不重试
if errors.Is(err, ErrInvalidOrderStatus) { if errors.Is(err, ErrInvalidOrderStatus) {
@ -102,12 +110,13 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task)
return nil return nil
} }
logger.WithContext(ctx).Info("[ActivateOrderLogic] 订单验证通过,开始处理", commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "activation_order_claimed",
logger.Field("order_no", orderInfo.OrderNo), "[SubscriptionFlow] activation worker claimed paid order",
logger.Field("order_type", orderInfo.Type), commonLogic.OrderTraceFields(orderInfo)...,
logger.Field("user_id", orderInfo.UserId)) )
if err = l.processOrderByType(ctx, orderInfo, payload.IAPExpireAt); err != nil { if err = l.processOrderByType(ctx, orderInfo, payload.IAPExpireAt); err != nil {
l.releaseClaim(ctx, orderInfo.OrderNo)
logger.WithContext(ctx).Error("[ActivateOrderLogic] 处理订单失败,将重试", logger.WithContext(ctx).Error("[ActivateOrderLogic] 处理订单失败,将重试",
logger.Field("order_no", orderInfo.OrderNo), logger.Field("order_no", orderInfo.OrderNo),
logger.Field("order_type", orderInfo.Type), logger.Field("order_type", orderInfo.Type),
@ -115,12 +124,21 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task)
return err // 返回 err 允许 asynq 重试 return err // 返回 err 允许 asynq 重试
} }
l.finalizeCouponAndOrder(ctx, orderInfo) if err = l.reconcilePostOrderSubscriptions(ctx, orderInfo); err != nil {
l.releaseClaim(ctx, orderInfo.OrderNo)
logger.WithContext(ctx).Info("[ActivateOrderLogic] 订单激活成功", logger.WithContext(ctx).Error("[ActivateOrderLogic] 订单订阅兜底合并失败,将重试",
logger.Field("order_no", orderInfo.OrderNo), logger.Field("order_no", orderInfo.OrderNo),
logger.Field("order_type", orderInfo.Type), logger.Field("order_type", orderInfo.Type),
logger.Field("user_id", orderInfo.UserId)) logger.Field("error", err.Error()))
return err
}
l.finalizeCouponAndOrder(ctx, orderInfo)
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "activation_finished",
"[SubscriptionFlow] order activation completed",
commonLogic.OrderTraceFields(orderInfo)...,
)
return nil return nil
} }
@ -137,10 +155,11 @@ func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (
return &p, nil return &p, nil
} }
// validateAndGetOrder retrieves an order by order number and validates its status // claimAndGetOrder retrieves an order by order number and atomically claims paid orders.
// Returns error if order is not found or not in paid status // Returns error if order is not found or not in paid status
func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) { func (l *ActivateOrderLogic) claimAndGetOrder(ctx context.Context, orderNo string) (*order.Order, error) {
orderInfo, err := l.svc.OrderModel.FindOneByOrderNo(ctx, orderNo) var orderInfo order.Order
err := l.svc.DB.WithContext(ctx).Model(&order.Order{}).Where("order_no = ?", orderNo).First(&orderInfo).Error
if err != nil { if err != nil {
logger.WithContext(ctx).Error("Find order failed", logger.WithContext(ctx).Error("Find order failed",
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
@ -165,7 +184,33 @@ func (l *ActivateOrderLogic) validateAndGetOrder(ctx context.Context, orderNo st
return nil, ErrInvalidOrderStatus return nil, ErrInvalidOrderStatus
} }
return orderInfo, nil result := l.svc.DB.WithContext(ctx).
Model(&order.Order{}).
Where("order_no = ? AND status = ?", orderNo, OrderStatusPaid).
Update("status", OrderStatusClaimed)
if result.Error != nil {
return nil, result.Error
}
if result.RowsAffected == 0 {
logger.WithContext(ctx).Info("Order already claimed by another worker, skip processing",
logger.Field("order_no", orderNo),
)
return nil, nil
}
orderInfo.Status = OrderStatusClaimed
return &orderInfo, nil
}
func (l *ActivateOrderLogic) releaseClaim(ctx context.Context, orderNo string) {
if err := l.svc.DB.WithContext(ctx).
Model(&order.Order{}).
Where("order_no = ? AND status = ?", orderNo, OrderStatusClaimed).
Update("status", OrderStatusPaid).Error; err != nil {
logger.WithContext(ctx).Error("Release order claim failed",
logger.Field("error", err.Error()),
logger.Field("order_no", orderNo),
)
}
} }
// processOrderByType routes order processing based on the order type // processOrderByType routes order processing based on the order type
@ -187,6 +232,324 @@ func (l *ActivateOrderLogic) processOrderByType(ctx context.Context, orderInfo *
} }
} }
func (l *ActivateOrderLogic) reconcilePostOrderSubscriptions(ctx context.Context, orderInfo *order.Order) error {
if !shouldReconcilePostOrderSubscriptions(orderInfo) {
return nil
}
effectiveUserID := orderInfo.UserId
if orderInfo.SubscriptionUserId > 0 {
effectiveUserID = orderInfo.SubscriptionUserId
}
if effectiveUserID == 0 || orderInfo.Id == 0 {
return nil
}
var (
survivor user.Subscribe
survivorBefore user.Subscribe
losers []user.Subscribe
mergedIDs []int64
subscribeIDsToClear = make(map[int64]struct{})
missingSurvivor bool
ownerMismatchSkipped bool
identitySourceID int64
)
err := l.svc.DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Model(&user.Subscribe{}).
Where("order_id = ?", orderInfo.Id).
First(&survivor).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
missingSurvivor = true
return nil
}
return err
}
survivorBefore = survivor
var ownerSubs []user.Subscribe
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Model(&user.Subscribe{}).
Where("user_id = ?", effectiveUserID).
Order("id ASC").
Find(&ownerSubs).Error; err != nil {
return err
}
if survivor.UserId != effectiveUserID {
if len(ownerSubs) == 0 {
ownerMismatchSkipped = true
return nil
}
if err := tx.Model(&user.Subscribe{}).
Where("id = ?", survivor.Id).
Update("user_id", effectiveUserID).Error; err != nil {
return err
}
survivor.UserId = effectiveUserID
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Model(&user.Subscribe{}).
Where("user_id = ?", effectiveUserID).
Order("id ASC").
Find(&ownerSubs).Error; err != nil {
return err
}
}
if len(ownerSubs) <= 1 {
return nil
}
now := time.Now()
accumulatedExpire := now
for i := range ownerSubs {
item := ownerSubs[i]
if (item.Id == survivor.Id || orderMergeRemainingTimeStatus(item.Status)) && item.ExpireTime.After(now) {
accumulatedExpire = accumulatedExpire.Add(item.ExpireTime.Sub(now))
}
if item.Id != survivor.Id {
losers = append(losers, item)
mergedIDs = append(mergedIDs, item.Id)
}
if item.SubscribeId > 0 {
subscribeIDsToClear[item.SubscribeId] = struct{}{}
}
}
if len(losers) == 0 {
return nil
}
if survivor.SubscribeId > 0 {
subscribeIDsToClear[survivor.SubscribeId] = struct{}{}
}
identitySource := pickSubscriptionIdentitySource(losers)
if identitySource != nil {
identitySourceID = identitySource.Id
}
updateFields := map[string]interface{}{
"status": 1,
"finished_at": nil,
}
if accumulatedExpire.After(survivor.ExpireTime) {
survivor.ExpireTime = accumulatedExpire
updateFields["expire_time"] = accumulatedExpire
}
if identitySource != nil {
if identitySource.Token != "" {
survivor.Token = identitySource.Token
updateFields["token"] = identitySource.Token
}
if identitySource.UUID != "" {
survivor.UUID = identitySource.UUID
updateFields["uuid"] = identitySource.UUID
}
}
loserIDs := make([]int64, 0, len(losers))
for i := range losers {
loserIDs = append(loserIDs, losers[i].Id)
}
if len(loserIDs) == 0 {
return nil
}
// user_subscribe 当前没有 deleted_at 字段,这里沿用项目现有删除语义清理 loser 记录。
if err := tx.Where("id IN ?", loserIDs).Delete(&user.Subscribe{}).Error; err != nil {
return err
}
if err := tx.Model(&user.Subscribe{}).
Where("id = ?", survivor.Id).
Updates(updateFields).Error; err != nil {
return err
}
survivor.Status = 1
survivor.FinishedAt = nil
return nil
})
if err != nil {
return err
}
if missingSurvivor {
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "post_order_reconcile_skipped",
"[SubscriptionFlow] post-order reconcile skipped because survivor subscription was not found",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("reason", "post_order_reconcile"),
logger.Field("effective_user_id", effectiveUserID),
)...,
)
return nil
}
if ownerMismatchSkipped {
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "post_order_reconcile_skipped",
"[SubscriptionFlow] post-order reconcile skipped because survivor owner mismatch had no duplicates",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("reason", "post_order_reconcile"),
logger.Field("effective_user_id", effectiveUserID),
logger.Field("survivor_subscribe_id", survivor.Id),
logger.Field("survivor_user_id", survivorBefore.UserId),
)...,
)
return nil
}
if len(losers) == 0 {
return nil
}
l.clearPostOrderReconcileCache(ctx, &survivorBefore, &survivor, losers, subscribeIDsToClear)
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "post_order_reconciled",
"[SubscriptionFlow] post-order reconcile merged duplicate subscriptions",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("reason", "post_order_reconcile"),
logger.Field("effective_user_id", effectiveUserID),
logger.Field("survivor_subscribe_id", survivor.Id),
logger.Field("identity_source_subscribe_id", identitySourceID),
logger.Field("merged_subscribe_ids", mergedIDs),
logger.Field("merged_count", len(mergedIDs)),
)...,
)
return nil
}
func shouldReconcilePostOrderSubscriptions(orderInfo *order.Order) bool {
if orderInfo == nil {
return false
}
switch orderInfo.Type {
case OrderTypeSubscribe, OrderTypeRenewal, OrderTypeRedemption:
return true
default:
return false
}
}
func orderMergeRemainingTimeStatus(status uint8) bool {
switch status {
case 0, 1, 2:
return true
default:
return false
}
}
func subscriptionRenewalBaseTime(now time.Time, userSub *user.Subscribe) time.Time {
if userSub != nil && orderMergeRemainingTimeStatus(userSub.Status) && userSub.ExpireTime.After(now) {
return userSub.ExpireTime
}
return now
}
func pickSubscriptionIdentitySource(candidates []user.Subscribe) *user.Subscribe {
if len(candidates) == 0 {
return nil
}
best := &candidates[0]
for i := 1; i < len(candidates); i++ {
candidate := &candidates[i]
if subscriptionIdentityPriority(candidate, best) {
best = candidate
}
}
return best
}
func subscriptionIdentityPriority(candidate *user.Subscribe, current *user.Subscribe) bool {
if candidate == nil {
return false
}
if current == nil {
return true
}
candidateUsable := candidate.Token != "" || candidate.UUID != ""
currentUsable := current.Token != "" || current.UUID != ""
if candidateUsable != currentUsable {
return candidateUsable
}
if candidate.ExpireTime.After(current.ExpireTime) {
return true
}
if current.ExpireTime.After(candidate.ExpireTime) {
return false
}
if candidate.UpdatedAt.After(current.UpdatedAt) {
return true
}
if current.UpdatedAt.After(candidate.UpdatedAt) {
return false
}
return candidate.Id > current.Id
}
func (l *ActivateOrderLogic) clearPostOrderReconcileCache(
ctx context.Context,
survivorBefore *user.Subscribe,
survivorAfter *user.Subscribe,
losers []user.Subscribe,
subscribeIDs map[int64]struct{},
) {
cacheModels := make([]*user.Subscribe, 0, len(losers)+2)
if survivorBefore != nil {
cacheModels = append(cacheModels, survivorBefore)
}
if survivorAfter != nil {
cacheModels = append(cacheModels, survivorAfter)
}
for i := range losers {
loser := losers[i]
cacheModels = append(cacheModels, &loser)
}
if len(cacheModels) > 0 {
if err := l.svc.UserModel.ClearSubscribeCache(ctx, cacheModels...); err != nil {
logger.WithContext(ctx).Error("Post-order reconcile clear subscribe cache failed",
logger.Field("reason", "post_order_reconcile"),
logger.Field("error", err.Error()),
)
}
}
if l.svc.SubscribeModel != nil {
for subscribeID := range subscribeIDs {
if err := l.svc.SubscribeModel.ClearCache(ctx, subscribeID); err != nil {
logger.WithContext(ctx).Error("Post-order reconcile clear plan cache failed",
logger.Field("reason", "post_order_reconcile"),
logger.Field("subscribe_id", subscribeID),
logger.Field("error", err.Error()),
)
}
}
}
if l.svc.NodeModel != nil {
if err := l.svc.NodeModel.ClearServerAllCache(ctx); err != nil {
logger.WithContext(ctx).Error("Post-order reconcile clear node cache failed",
logger.Field("reason", "post_order_reconcile"),
logger.Field("error", err.Error()),
)
}
}
}
// finalizeCouponAndOrder handles post-processing tasks including coupon updates // finalizeCouponAndOrder handles post-processing tasks including coupon updates
// and order status finalization // and order status finalization
func (l *ActivateOrderLogic) finalizeCouponAndOrder(ctx context.Context, orderInfo *order.Order) { func (l *ActivateOrderLogic) finalizeCouponAndOrder(ctx context.Context, orderInfo *order.Order) {
@ -208,6 +571,10 @@ func (l *ActivateOrderLogic) finalizeCouponAndOrder(ctx context.Context, orderIn
) )
} }
orderInfo.Status = OrderStatusFinished orderInfo.Status = OrderStatusFinished
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "order_status_finished",
"[SubscriptionFlow] order status updated to finished",
commonLogic.OrderTraceFields(orderInfo)...,
)
} }
// NewPurchase handles new subscription purchase including user creation, // NewPurchase handles new subscription purchase including user creation,
@ -218,6 +585,13 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
return err return err
} }
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "activation_user_resolved",
"[SubscriptionFlow] activation resolved subscription recipient user",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("resolved_user_id", userInfo.Id),
)...,
)
sub, err := l.getSubscribeInfo(ctx, orderInfo.SubscribeId) sub, err := l.getSubscribeInfo(ctx, orderInfo.SubscribeId)
if err != nil { if err != nil {
return err return err
@ -258,12 +632,14 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
) )
} else { } else {
userSub = anchorSub userSub = anchorSub
logger.WithContext(ctx).Infow("Single mode purchase routed to renewal in activation", commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_reused",
logger.Field("mode", "single"), "[SubscriptionFlow] activation reused single-mode anchor subscription",
logger.Field("route", "purchase_to_renewal"), append(commonLogic.OrderTraceFields(orderInfo),
append(commonLogic.UserSubscribeTraceFields(anchorSub),
logger.Field("reuse_reason", "single_mode_purchase_to_renewal"),
logger.Field("plan_changed", anchorSub.SubscribeId != orderInfo.SubscribeId), logger.Field("plan_changed", anchorSub.SubscribeId != orderInfo.SubscribeId),
logger.Field("anchor_user_subscribe_id", anchorSub.Id), )...,
logger.Field("order_no", orderInfo.OrderNo), )...,
) )
} }
case errors.Is(anchorErr, gorm.ErrRecordNotFound): case errors.Is(anchorErr, gorm.ErrRecordNotFound):
@ -274,25 +650,76 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
) )
} }
// 如果没有合并已购订阅再尝试合并赠送订阅order_id=0 }
// 如果没有合并已购订阅再尝试合并赠送订阅order_id=0
// 全局单订阅口径下,非 SingleModel 也不能让试用订阅和付费订阅并存。
if userSub == nil { if userSub == nil {
giftSub, giftErr := l.findGiftSubscription(ctx, singleModeUserId, orderInfo.SubscribeId) effectiveOwner := orderInfo.UserId
if orderInfo.SubscriptionUserId > 0 {
effectiveOwner = orderInfo.SubscriptionUserId
}
giftSub, giftErr := l.findGiftSubscription(ctx, effectiveOwner, orderInfo.SubscribeId)
if giftErr == nil && giftSub != nil { if giftErr == nil && giftSub != nil {
// 在赠送订阅上延长时间,保持 token 不变
userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub) userSub, err = l.extendGiftSubscription(ctx, giftSub, orderInfo, sub)
if err != nil { if err != nil {
logger.WithContext(ctx).Error("Extend gift subscription failed", logger.WithContext(ctx).Error("Extend gift subscription failed",
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
logger.Field("gift_subscribe_id", giftSub.Id), logger.Field("gift_subscribe_id", giftSub.Id),
) )
// 合并失败时回退到创建新订阅
userSub = nil userSub = nil
} }
} }
} }
// 兜底:创建新订阅前,查找用户是否已有同套餐的订阅记录(含过期/赠送),
// 有则复用旧记录续期,避免出现重复订阅。
// 需要同时检查 UserId 和 SubscriptionUserId因为家庭组绑定前后 owner 可能不同。
if userSub == nil {
candidateUserIds := []int64{orderInfo.UserId}
if orderInfo.SubscriptionUserId > 0 && orderInfo.SubscriptionUserId != orderInfo.UserId {
candidateUserIds = append(candidateUserIds, orderInfo.SubscriptionUserId)
}
var existingSub user.Subscribe
if findErr := l.svc.DB.Model(&user.Subscribe{}).
Where("user_id IN ? AND token != ''", candidateUserIds).
Order("expire_time DESC").
Order("updated_at DESC").
Order("id DESC").
First(&existingSub).Error; findErr == nil {
// 家庭组场景:订阅 owner 可能变更(如成员注册的试用 → 被家主收归),
// 续期前把 user_id 校正为当前订单的 SubscriptionUserId
effectiveOwner := orderInfo.UserId
if orderInfo.SubscriptionUserId > 0 {
effectiveOwner = orderInfo.SubscriptionUserId
}
if existingSub.UserId != effectiveOwner {
existingSub.UserId = effectiveOwner
}
// 找到已有记录,走续期逻辑
if renewErr := l.updateSubscriptionForRenewal(ctx, &existingSub, sub, orderInfo); renewErr != nil {
logger.WithContext(ctx).Error("Fallback renew existing subscription failed, will create new",
logger.Field("error", renewErr.Error()),
logger.Field("existing_subscribe_id", existingSub.Id),
logger.Field("order_no", orderInfo.OrderNo),
)
} else {
userSub = &existingSub
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_reused",
"[SubscriptionFlow] activation renewed an existing subscription instead of creating a duplicate",
append(commonLogic.OrderTraceFields(orderInfo),
append(commonLogic.UserSubscribeTraceFields(&existingSub),
logger.Field("reuse_reason", "fallback_existing_subscription"),
logger.Field("candidate_user_ids", candidateUserIds),
logger.Field("owner_corrected_to", effectiveOwner),
)...,
)...,
)
}
}
} }
// 如果没有合并赠送订阅,则正常创建新订阅 // 如果仍然没有可复用的订阅,才创建新订阅
if userSub == nil { if userSub == nil {
userSub, err = l.createUserSubscription(ctx, orderInfo, sub) userSub, err = l.createUserSubscription(ctx, orderInfo, sub)
if err != nil { if err != nil {
@ -309,7 +736,12 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O
// Clear cache // Clear cache
l.clearServerCache(ctx, sub) l.clearServerCache(ctx, sub)
logger.WithContext(ctx).Info("Insert user subscribe success") commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_issued",
"[SubscriptionFlow] activation finished issuing subscription entitlement",
append(commonLogic.OrderTraceFields(orderInfo),
commonLogic.UserSubscribeTraceFields(userSub)...,
)...,
)
return nil return nil
} }
@ -380,6 +812,14 @@ func (l *ActivateOrderLogic) createGuestUser(ctx context.Context, orderInfo *ord
logger.Field("identifier", tempOrder.Identifier), logger.Field("identifier", tempOrder.Identifier),
logger.Field("auth_type", tempOrder.AuthType), logger.Field("auth_type", tempOrder.AuthType),
) )
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "guest_user_created",
"[SubscriptionFlow] guest user created during order activation",
append(commonLogic.OrderTraceFields(orderInfo),
logger.Field("created_user_id", userInfo.Id),
logger.Field("identifier", tempOrder.Identifier),
logger.Field("auth_type", tempOrder.AuthType),
)...,
)
return userInfo, nil return userInfo, nil
} }
@ -473,7 +913,7 @@ func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderIn
// Check quota limit before creating subscription (final safeguard) // Check quota limit before creating subscription (final safeguard)
if sub.Quota > 0 { if sub.Quota > 0 {
var count int64 var count int64
if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ? AND subscribe_id = ?", orderInfo.UserId, orderInfo.SubscribeId).Count(&count).Error; err != nil { if err := l.svc.DB.Model(&user.Subscribe{}).Where("user_id = ?", subscriptionUserId).Count(&count).Error; err != nil {
logger.WithContext(ctx).Error("Count user subscribe failed", logger.Field("error", err.Error())) logger.WithContext(ctx).Error("Count user subscribe failed", logger.Field("error", err.Error()))
return nil, err return nil, err
} }
@ -493,6 +933,13 @@ func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderIn
return nil, err return nil, err
} }
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_created",
"[SubscriptionFlow] new user subscription record created",
append(commonLogic.OrderTraceFields(orderInfo),
commonLogic.UserSubscribeTraceFields(userSub)...,
)...,
)
return userSub, nil return userSub, nil
} }
@ -500,13 +947,13 @@ func (l *ActivateOrderLogic) patchOrderParentID(ctx context.Context, orderID int
return l.svc.DB.WithContext(ctx).Model(&order.Order{}).Where("id = ? AND (parent_id = 0 OR parent_id IS NULL)", orderID).Update("parent_id", parentID).Error return l.svc.DB.WithContext(ctx).Model(&order.Order{}).Where("id = ? AND (parent_id = 0 OR parent_id IS NULL)", orderID).Update("parent_id", parentID).Error
} }
// findGiftSubscription 查找用户指定套餐的赠送订阅order_id=0包括已过期的 // findGiftSubscription 查找用户的赠送订阅order_id=0包括已过期的
// 返回找到的赠送订阅记录,如果没有则返回 nil // 单订阅模式下,用户若以不同套餐首次购买,需要将赠送订阅合并为付费订阅,
func (l *ActivateOrderLogic) findGiftSubscription(ctx context.Context, userId int64, subscribeId int64) (*user.Subscribe, error) { // 因此不再过滤 subscribe_id避免套餐不同时绕过合并路径创建重复订阅。
// 直接查询数据库,查找 order_id=0赠送且同套餐的订阅不限制过期状态 func (l *ActivateOrderLogic) findGiftSubscription(ctx context.Context, userId int64, _ int64) (*user.Subscribe, error) {
var giftSub user.Subscribe var giftSub user.Subscribe
err := l.svc.DB.Model(&user.Subscribe{}). err := l.svc.DB.Model(&user.Subscribe{}).
Where("user_id = ? AND order_id = 0 AND subscribe_id = ?", userId, subscribeId). Where("user_id = ? AND order_id = 0", userId).
Order("created_at DESC"). Order("created_at DESC").
First(&giftSub).Error First(&giftSub).Error
if err != nil { if err != nil {
@ -515,23 +962,25 @@ func (l *ActivateOrderLogic) findGiftSubscription(ctx context.Context, userId in
return &giftSub, nil return &giftSub, nil
} }
// extendGiftSubscription 在现有赠送订阅上延长到期时间,保持 token 不变 // extendGiftSubscription 在现有赠送订阅上延长到期时间,保持 token/UUID 不变
// 将购买的天数叠加到赠送订阅的到期时间上,并更新 order_id 为新订单 ID // 若购买套餐与赠送套餐不同,同步更新套餐 ID 和流量配额并重置已用量(套餐变更语义)。
func (l *ActivateOrderLogic) extendGiftSubscription(ctx context.Context, giftSub *user.Subscribe, orderInfo *order.Order, sub *subscribe.Subscribe) (*user.Subscribe, error) { func (l *ActivateOrderLogic) extendGiftSubscription(ctx context.Context, giftSub *user.Subscribe, orderInfo *order.Order, sub *subscribe.Subscribe) (*user.Subscribe, error) {
now := time.Now() now := time.Now()
// 计算基准时间:取赠送订阅到期时间和当前时间的较大值 baseTime := subscriptionRenewalBaseTime(now, giftSub)
baseTime := giftSub.ExpireTime
if baseTime.Before(now) {
baseTime = now
}
// 在基准时间上增加购买的天数
newExpireTime := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime) newExpireTime := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
// 更新赠送订阅的信息
giftSub.OrderId = orderInfo.Id giftSub.OrderId = orderInfo.Id
giftSub.ExpireTime = newExpireTime giftSub.ExpireTime = newExpireTime
giftSub.Status = 1 giftSub.Status = 1
// 套餐变更:更新套餐 ID 和流量配额,重置已用流量(与 updateSubscriptionForRenewal 逻辑一致)
if giftSub.SubscribeId != orderInfo.SubscribeId {
giftSub.SubscribeId = orderInfo.SubscribeId
giftSub.Traffic = sub.Traffic
giftSub.Download = 0
giftSub.Upload = 0
}
if err := l.svc.UserModel.UpdateSubscribe(ctx, giftSub); err != nil { if err := l.svc.UserModel.UpdateSubscribe(ctx, giftSub); err != nil {
logger.WithContext(ctx).Error("Update gift subscription failed", logger.WithContext(ctx).Error("Update gift subscription failed",
logger.Field("error", err.Error()), logger.Field("error", err.Error()),
@ -540,11 +989,15 @@ func (l *ActivateOrderLogic) extendGiftSubscription(ctx context.Context, giftSub
return nil, err return nil, err
} }
logger.WithContext(ctx).Info("Extended gift subscription successfully", commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_reused",
logger.Field("subscribe_id", giftSub.Id), "[SubscriptionFlow] paid order extended an existing gift subscription",
append(commonLogic.OrderTraceFields(orderInfo),
append(commonLogic.UserSubscribeTraceFields(giftSub),
logger.Field("reuse_reason", "gift_subscription_promoted"),
logger.Field("old_expire_time", baseTime), logger.Field("old_expire_time", baseTime),
logger.Field("new_expire_time", newExpireTime), logger.Field("new_expire_time", newExpireTime),
logger.Field("order_id", orderInfo.Id), )...,
)...,
) )
return giftSub, nil return giftSub, nil
@ -556,7 +1009,7 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use
if !l.shouldProcessCommission(userInfo, orderInfo.IsNew) { if !l.shouldProcessCommission(userInfo, orderInfo.IsNew) {
// 普通用户路径(佣金比例=0只有首单才双方赠N天 // 普通用户路径(佣金比例=0只有首单才双方赠N天
if orderInfo.IsNew { if orderInfo.IsNew {
l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo.OrderNo) l.grantGiftDaysToBothParties(ctx, userInfo, orderInfo)
} }
return return
} }
@ -643,14 +1096,25 @@ func (l *ActivateOrderLogic) handleCommission(ctx context.Context, userInfo *use
logger.Field("user_id", referer.Id), logger.Field("user_id", referer.Id),
) )
} }
// 有佣金路径:邀请人拿佣金,被邀请用户(首单)拿天数
if orderInfo.IsNew {
giftTarget := l.resolveGiftTargetUser(ctx, userInfo, orderInfo.SubscriptionUserId)
if giftErr := l.grantGiftDays(ctx, giftTarget, int(l.svc.Config.Invite.GiftDays), orderInfo.OrderNo, "邀请赠送"); giftErr != nil {
l.logGiftDaysError(ctx, giftErr, giftTarget, orderInfo, "commission_path_referee")
}
}
} }
func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderNo string) { func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, referee *user.User, orderInfo *order.Order) {
giftDays := l.svc.Config.Invite.GiftDays giftDays := l.svc.Config.Invite.GiftDays
if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 { if giftDays <= 0 || referee == nil || referee.Id == 0 || referee.RefererId == 0 || orderInfo == nil {
return return
} }
_ = l.grantGiftDays(ctx, referee, int(giftDays), orderNo, "邀请赠送") refereeTarget := l.resolveGiftTargetUser(ctx, referee, orderInfo.SubscriptionUserId)
if err := l.grantGiftDays(ctx, refereeTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送"); err != nil {
l.logGiftDaysError(ctx, err, refereeTarget, orderInfo, "no_commission_referee")
}
if referee.RefererId == 0 { if referee.RefererId == 0 {
return return
} }
@ -658,7 +1122,54 @@ func (l *ActivateOrderLogic) grantGiftDaysToBothParties(ctx context.Context, ref
if err != nil || referer == nil { if err != nil || referer == nil {
return return
} }
_ = l.grantGiftDays(ctx, referer, int(giftDays), orderNo, "邀请赠送") refererTarget := l.resolveGiftTargetUser(ctx, referer, 0)
if err = l.grantGiftDays(ctx, refererTarget, int(giftDays), orderInfo.OrderNo, "邀请赠送"); err != nil {
l.logGiftDaysError(ctx, err, refererTarget, orderInfo, "no_commission_referer")
}
}
func (l *ActivateOrderLogic) resolveGiftTargetUser(ctx context.Context, source *user.User, forcedOwnerID int64) *user.User {
if source == nil || source.Id == 0 {
return source
}
targetID := source.Id
if forcedOwnerID > 0 {
targetID = forcedOwnerID
} else if entitlement, err := commonLogic.ResolveEntitlementUser(ctx, l.svc.DB, source.Id); err == nil && entitlement != nil && entitlement.EffectiveUserID > 0 {
targetID = entitlement.EffectiveUserID
}
if targetID == source.Id {
return source
}
target, err := l.svc.UserModel.FindOne(ctx, targetID)
if err != nil || target == nil {
logger.WithContext(ctx).Error("Resolve gift target owner failed",
logger.Field("source_user_id", source.Id),
logger.Field("target_user_id", targetID),
)
return source
}
return target
}
func (l *ActivateOrderLogic) logGiftDaysError(ctx context.Context, err error, target *user.User, orderInfo *order.Order, stage string) {
if err == nil {
return
}
var targetID int64
if target != nil {
targetID = target.Id
}
var orderNo string
if orderInfo != nil {
orderNo = orderInfo.OrderNo
}
logger.WithContext(ctx).Error("Grant invite gift days failed",
logger.Field("error", err.Error()),
logger.Field("stage", stage),
logger.Field("target_user_id", targetID),
logger.Field("order_no", orderNo),
)
} }
func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, days int, orderNo string, remark string) error { func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, days int, orderNo string, remark string) error {
@ -685,11 +1196,28 @@ func (l *ActivateOrderLogic) grantGiftDays(ctx context.Context, u *user.User, da
activeSubscribe, err := l.svc.UserModel.FindActiveSubscribe(ctx, u.Id) activeSubscribe, err := l.svc.UserModel.FindActiveSubscribe(ctx, u.Id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil giftLog := &log.Gift{
Type: log.GiftTypeIncrease,
OrderNo: orderNo,
SubscribeId: 0,
Amount: int64(days),
Balance: u.Balance,
Remark: remark + " skipped: no active subscription",
Timestamp: time.Now().UnixMilli(),
}
content, _ := giftLog.Marshal()
return l.svc.LogModel.Insert(ctx, &log.SystemLog{
Type: log.TypeGift.Uint8(),
Date: time.Now().Format("2006-01-02"),
ObjectID: u.Id,
Content: string(content),
})
} }
return err return err
} }
if !activeSubscribe.ExpireTime.Equal(time.UnixMilli(0)) {
activeSubscribe.ExpireTime = activeSubscribe.ExpireTime.Add(time.Duration(days) * 24 * time.Hour) activeSubscribe.ExpireTime = activeSubscribe.ExpireTime.Add(time.Duration(days) * 24 * time.Hour)
}
err = l.svc.UserModel.UpdateSubscribe(ctx, activeSubscribe) err = l.svc.UserModel.UpdateSubscribe(ctx, activeSubscribe)
if err != nil { if err != nil {
return err return err
@ -849,6 +1377,9 @@ func (l *ActivateOrderLogic) Renewal(ctx context.Context, orderInfo *order.Order
} }
} }
// Trigger user group recalculation (needed when renewing an expired subscription)
l.triggerUserGroupRecalculation(ctx, userInfo.Id)
// Clear user subscription cache // Clear user subscription cache
err = l.svc.UserModel.ClearSubscribeCache(ctx, userSub) err = l.svc.UserModel.ClearSubscribeCache(ctx, userSub)
if err != nil { if err != nil {
@ -865,6 +1396,15 @@ func (l *ActivateOrderLogic) Renewal(ctx context.Context, orderInfo *order.Order
// Handle commission // Handle commission
go l.handleCommission(context.Background(), userInfo, orderInfo) go l.handleCommission(context.Background(), userInfo, orderInfo)
commonLogic.SubscriptionTraceInfo(logger.WithContext(ctx), commonLogic.SubscriptionTraceFlowOrder, "subscription_renewed",
"[SubscriptionFlow] renewal order updated existing subscription",
append(commonLogic.OrderTraceFields(orderInfo),
append(commonLogic.UserSubscribeTraceFields(userSub),
logger.Field("iap_expire_at", iapExpireAt),
)...,
)...,
)
return nil return nil
} }
@ -881,10 +1421,7 @@ func (l *ActivateOrderLogic) getUserSubscription(ctx context.Context, token stri
// updateSubscriptionWithIAPExpire 用于 Apple IAP 续费:按累计加时语义更新到期时间。 // updateSubscriptionWithIAPExpire 用于 Apple IAP 续费:按累计加时语义更新到期时间。
func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order, iapExpireAt int64) error { func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order, iapExpireAt int64) error {
now := time.Now() now := time.Now()
baseTime := userSub.ExpireTime baseTime := subscriptionRenewalBaseTime(now, userSub)
if baseTime.Before(now) {
baseTime = now
}
newExpire := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime) newExpire := tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
if iapExpireAt > 0 { if iapExpireAt > 0 {
appleExpire := time.Unix(iapExpireAt, 0) appleExpire := time.Unix(iapExpireAt, 0)
@ -904,6 +1441,7 @@ func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context
userSub.FinishedAt = nil userSub.FinishedAt = nil
} }
userSub.OrderId = orderInfo.Id
userSub.ExpireTime = newExpire userSub.ExpireTime = newExpire
userSub.Status = 1 userSub.Status = 1
@ -918,11 +1456,9 @@ func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context
// expiration time extension and traffic reset if configured // expiration time extension and traffic reset if configured
func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order) error { func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, userSub *user.Subscribe, sub *subscribe.Subscribe, orderInfo *order.Order) error {
now := time.Now() now := time.Now()
if userSub.ExpireTime.Before(now) { baseTime := subscriptionRenewalBaseTime(now, userSub)
userSub.ExpireTime = now today := now.Day()
} resetDay := baseTime.Day()
today := time.Now().Day()
resetDay := userSub.ExpireTime.Day()
// 套餐变更更新套餐ID和流量配额并重置已用流量 // 套餐变更更新套餐ID和流量配额并重置已用流量
if userSub.SubscribeId != orderInfo.SubscribeId { if userSub.SubscribeId != orderInfo.SubscribeId {
@ -949,7 +1485,7 @@ func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, u
} }
userSub.OrderId = orderInfo.Id userSub.OrderId = orderInfo.Id
userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, userSub.ExpireTime) userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, baseTime)
userSub.Status = 1 userSub.Status = 1
// 续费时重置过期流量字段 // 续费时重置过期流量字段
userSub.ExpiredDownload = 0 userSub.ExpiredDownload = 0

View File

@ -0,0 +1,806 @@
package orderLogic
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
"github.com/perfect-panel/server/internal/config"
userLogic "github.com/perfect-panel/server/internal/logic/public/user"
modelLog "github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/redis/go-redis/v9"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// 普通用户 + 首单 → 双方赠N天
func TestHandleCommission_GrantGiftDaysWhenCommissionDisabled_FirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-GIFT-001",
Type: OrderTypeSubscribe,
IsNew: true, // 首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 2)
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 2 {
t.Fatalf("expected 2 gift logs, got %d", giftCount)
}
}
// 普通用户 + 非首单 → 不赠送
func TestHandleCommission_NoGiftDaysWhenCommissionDisabled_NotFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-GIFT-002",
Type: OrderTypeSubscribe,
IsNew: false, // 非首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 到期时间不应延长
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0)
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs for non-first order, got %d", giftCount)
}
}
// 渠道 + 首单 → 被邀请人赠N天 + 邀请人获佣金
func TestHandleCommission_GiftDaysAndCommissionWhenChannelFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-COMM-001",
Type: OrderTypeSubscribe,
IsNew: true, // 首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 被邀请人(首单)应获得赠送天数
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2)
// 邀请人应获得佣金
var refererAfter user.User
if err := db.First(&refererAfter, referer.Id).Error; err != nil {
t.Fatalf("query referer failed: %v", err)
}
if refererAfter.Commission != 10 {
t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission)
}
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 1 {
t.Fatalf("expected 1 gift log for referee on first order with commission, got %d", giftCount)
}
}
// 渠道 + 非首单 → 只给邀请人佣金,不赠天
func TestHandleCommission_OnlyCommissionWhenChannelNotFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-COMM-002",
Type: OrderTypeSubscribe,
IsNew: false, // 非首单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 被邀请人不应获得赠送天数
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
// 邀请人应获得佣金
var refererAfter user.User
if err := db.First(&refererAfter, referer.Id).Error; err != nil {
t.Fatalf("query referer failed: %v", err)
}
if refererAfter.Commission != 10 {
t.Fatalf("expected referer commission=10, got %d", refererAfter.Commission)
}
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs when channel non-first order, got %d", giftCount)
}
}
func TestHandleCommission_NoGiftDaysWhenNoInviteRelation(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
// 没有邀请人的独立用户
loneUser := seedUser(t, db, 0, false)
// RefererId == 0无邀请关系
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
loneSub := seedActiveSubscribe(t, db, loneUser.Id, baseExpire)
logic.handleCommission(context.Background(), loneUser, &order.Order{
OrderNo: "ORD-LONE-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
// 订阅到期时间不应该被延长
var subAfter user.Subscribe
if err := db.First(&subAfter, loneSub.Id).Error; err != nil {
t.Fatalf("query subscribe failed: %v", err)
}
if !subAfter.ExpireTime.Equal(baseExpire) {
t.Fatalf("expected no gift days for user without inviter, before=%v after=%v", baseExpire, subAfter.ExpireTime)
}
// 不应产生赠天日志
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ?", modelLog.TypeGift.Uint8()).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 0 {
t.Fatalf("expected 0 gift logs for user without inviter, got %d", giftCount)
}
}
// 先绑码后首单 → 双方赠N天
func TestInviteFlow_BindThenFirstOrder_GrantGiftDays(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referer.ReferCode = fmt.Sprintf("REF-%d", referer.Id)
if err := db.Model(&user.User{}).Where("id = ?", referer.Id).Update("refer_code", referer.ReferCode).Error; err != nil {
t.Fatalf("update referer code failed: %v", err)
}
refereeBaseExpire := time.Now().Add(48 * time.Hour).Truncate(time.Second)
refererBaseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, refereeBaseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, refererBaseExpire)
ctx := context.WithValue(context.Background(), constant.CtxKeyUser, referee)
bindLogic := userLogic.NewBindInviteCodeLogic(ctx, logic.svc)
if err := bindLogic.BindInviteCode(&types.BindInviteCodeRequest{InviteCode: referer.ReferCode}); err != nil {
t.Fatalf("bind invite code failed: %v", err)
}
var refereeAfterBind user.User
if err := db.First(&refereeAfterBind, referee.Id).Error; err != nil {
t.Fatalf("query referee after bind failed: %v", err)
}
if refereeAfterBind.RefererId != referer.Id {
t.Fatalf("bind invite failed, expected referer_id=%d got=%d", referer.Id, refereeAfterBind.RefererId)
}
// 首单 IsNew=true → 双方赠N天
logic.handleCommission(context.Background(), &refereeAfterBind, &order.Order{
OrderNo: "ORD-FLOW-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, refereeBaseExpire, 2)
assertExpireIncreasedByDays(t, db, refererSub.Id, refererBaseExpire, 2)
}
// 先买订单后绑码再续费 → 不赠送IsNew=false
func TestInviteFlow_OrderThenBind_NoGiftDays(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
refererSub := seedActiveSubscribe(t, db, referer.Id, baseExpire)
// 先前已有订单IsNew=false模拟先买订单后绑码的场景
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-FLOW-002",
Type: OrderTypeSubscribe,
IsNew: false, // 已有历史订单
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertExpireIncreasedByDays(t, db, refererSub.Id, baseExpire, 0)
}
func TestHandleCommission_GiftDaysToRefereeFamilyOwnerWhenChannelFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
refereeOwner := seedUser(t, db, 0, false)
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
seedFamily(t, db, refereeOwner.Id, referee.Id)
referee.RefererId = referer.Id
ownerBaseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
ownerSub := seedActiveSubscribe(t, db, refereeOwner.Id, ownerBaseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-FAMILY-REFEREE-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
SubscriptionUserId: refereeOwner.Id,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, ownerSub.Id, ownerBaseExpire, 2)
assertUserCommission(t, db, referer.Id, 10)
}
func TestHandleCommission_GiftDaysToRefererFamilyOwnerWhenCommissionDisabled(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
refererOwner := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
seedFamily(t, db, refererOwner.Id, referer.Id)
referee.RefererId = referer.Id
refereeBaseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refererOwnerBaseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, refereeBaseExpire)
refererOwnerSub := seedActiveSubscribe(t, db, refererOwner.Id, refererOwnerBaseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-FAMILY-REFERER-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, refereeBaseExpire, 2)
assertExpireIncreasedByDays(t, db, refererOwnerSub.Id, refererOwnerBaseExpire, 2)
}
func TestHandleCommission_RefererFamilyMemberCommissionBehaviorUnchanged(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
refererOwner := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
seedFamily(t, db, refererOwner.Id, referer.Id)
referee.RefererId = referer.Id
refereeBaseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, refereeBaseExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-FAMILY-COMMISSION-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, refereeBaseExpire, 2)
assertUserCommission(t, db, referer.Id, 10)
assertUserCommission(t, db, refererOwner.Id, 0)
}
func TestHandleCommission_GiftDaysRecognizesUnlimitedSubscription(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
unlimitedExpire := time.UnixMilli(0)
refereeSub := seedActiveSubscribe(t, db, referee.Id, unlimitedExpire)
logic.handleCommission(context.Background(), referee, &order.Order{
OrderNo: "ORD-UNLIMITED-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
})
assertExpireIncreasedByDays(t, db, refereeSub.Id, unlimitedExpire, 0)
var giftCount int64
if err := db.Model(&modelLog.SystemLog{}).Where("type = ? AND object_id = ?", modelLog.TypeGift.Uint8(), referee.Id).Count(&giftCount).Error; err != nil {
t.Fatalf("count gift logs failed: %v", err)
}
if giftCount != 1 {
t.Fatalf("expected 1 gift log for unlimited subscription, got %d", giftCount)
}
}
func TestHandleCommission_IdempotentForRepeatedActivation(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
orderInfo := &order.Order{
OrderNo: "ORD-IDEMPOTENT-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
}
logic.handleCommission(context.Background(), referee, orderInfo)
logic.handleCommission(context.Background(), referee, orderInfo)
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 2)
assertUserCommission(t, db, referer.Id, 10)
assertLogCountForOrder(t, db, modelLog.TypeCommission.Uint8(), orderInfo.OrderNo, 1)
assertLogCountForOrder(t, db, modelLog.TypeGift.Uint8(), orderInfo.OrderNo, 1)
}
func TestHandleCommission_NoActiveSubscriptionWritesSkippedGiftLog(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
orderInfo := &order.Order{
OrderNo: "ORD-SKIPPED-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
}
logic.handleCommission(context.Background(), referee, orderInfo)
assertUserCommission(t, db, referer.Id, 10)
assertLogCountForOrder(t, db, modelLog.TypeCommission.Uint8(), orderInfo.OrderNo, 1)
assertLogCountForOrder(t, db, modelLog.TypeGift.Uint8(), orderInfo.OrderNo, 1)
assertGiftLogRemarkContains(t, db, orderInfo.OrderNo, "skipped: no active subscription")
}
func TestHandleCommission_GiftDaysZeroDoesNotWriteGiftLog(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: false,
GiftDays: 0,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
orderInfo := &order.Order{
OrderNo: "ORD-GIFT-ZERO-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
}
logic.handleCommission(context.Background(), referee, orderInfo)
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertUserCommission(t, db, referer.Id, 10)
assertLogCountForOrder(t, db, modelLog.TypeCommission.Uint8(), orderInfo.OrderNo, 1)
assertLogCountForOrder(t, db, modelLog.TypeGift.Uint8(), orderInfo.OrderNo, 0)
}
func TestHandleCommission_GlobalOnlyFirstPurchaseBlocksCommissionAndGiftForNonFirstOrder(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 10,
OnlyFirstPurchase: true,
GiftDays: 2,
})
defer cleanup()
referee := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
referee.RefererId = referer.Id
baseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeSub := seedActiveSubscribe(t, db, referee.Id, baseExpire)
orderInfo := &order.Order{
OrderNo: "ORD-NONFIRST-GLOBAL-001",
Type: OrderTypeRenewal,
IsNew: false,
Amount: 100,
FeeAmount: 0,
CreatedAt: time.Now(),
}
logic.handleCommission(context.Background(), referee, orderInfo)
assertExpireIncreasedByDays(t, db, refereeSub.Id, baseExpire, 0)
assertUserCommission(t, db, referer.Id, 0)
assertLogCountForOrder(t, db, modelLog.TypeCommission.Uint8(), orderInfo.OrderNo, 0)
assertLogCountForOrder(t, db, modelLog.TypeGift.Uint8(), orderInfo.OrderNo, 0)
}
func TestHandleCommission_BothFamilySidesUseCorrectGiftOwners(t *testing.T) {
logic, db, cleanup := setupInviteTestLogic(t, config.InviteConfig{
ReferralPercentage: 0,
OnlyFirstPurchase: false,
GiftDays: 2,
})
defer cleanup()
refereeOwner := seedUser(t, db, 0, false)
referee := seedUser(t, db, 0, false)
refererOwner := seedUser(t, db, 0, false)
referer := seedUser(t, db, 0, false)
seedFamily(t, db, refereeOwner.Id, referee.Id)
seedFamily(t, db, refererOwner.Id, referer.Id)
referee.RefererId = referer.Id
refereeOwnerBaseExpire := time.Now().Add(72 * time.Hour).Truncate(time.Second)
refererOwnerBaseExpire := time.Now().Add(96 * time.Hour).Truncate(time.Second)
refereeOwnerSub := seedActiveSubscribe(t, db, refereeOwner.Id, refereeOwnerBaseExpire)
refererOwnerSub := seedActiveSubscribe(t, db, refererOwner.Id, refererOwnerBaseExpire)
orderInfo := &order.Order{
OrderNo: "ORD-BOTH-FAMILIES-001",
Type: OrderTypeSubscribe,
IsNew: true,
Amount: 100,
FeeAmount: 0,
SubscriptionUserId: refereeOwner.Id,
CreatedAt: time.Now(),
}
logic.handleCommission(context.Background(), referee, orderInfo)
assertExpireIncreasedByDays(t, db, refereeOwnerSub.Id, refereeOwnerBaseExpire, 2)
assertExpireIncreasedByDays(t, db, refererOwnerSub.Id, refererOwnerBaseExpire, 2)
assertNoSubscribeForUser(t, db, referee.Id)
assertNoSubscribeForUser(t, db, referer.Id)
assertLogCountForOrder(t, db, modelLog.TypeGift.Uint8(), orderInfo.OrderNo, 2)
}
func setupInviteTestLogic(t *testing.T, inviteCfg config.InviteConfig) (*ActivateOrderLogic, *gorm.DB, func()) {
t.Helper()
mysqlAddr := getenvDefault("TEST_MYSQL_ADDR", "127.0.0.1:3306")
mysqlUser := getenvDefault("TEST_MYSQL_USER", "root")
mysqlPassword := getenvDefault("TEST_MYSQL_PASSWORD", "rootpassword")
adminDSN := fmt.Sprintf("%s:%s@tcp(%s)/?charset=utf8mb4&parseTime=true&loc=Local&multiStatements=true", mysqlUser, mysqlPassword, mysqlAddr)
adminDB, err := gorm.Open(mysql.Open(adminDSN), &gorm.Config{})
if err != nil {
t.Fatalf("open mysql admin connection failed: %v", err)
}
dbName := fmt.Sprintf("ppanel_test_invite_%d", time.Now().UnixNano())
if err := adminDB.Exec(fmt.Sprintf("CREATE DATABASE `%s` CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci", dbName)).Error; err != nil {
t.Fatalf("create test database failed: %v", err)
}
testDSN := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", mysqlUser, mysqlPassword, mysqlAddr, dbName)
db, err := gorm.Open(mysql.Open(testDSN), &gorm.Config{})
if err != nil {
t.Fatalf("open test database failed: %v", err)
}
if err := db.AutoMigrate(&user.User{}, &user.Device{}, &user.AuthMethods{}, &user.Subscribe{}, &user.UserFamily{}, &user.UserFamilyMember{}, &modelLog.SystemLog{}); err != nil {
t.Fatalf("auto migrate failed: %v", err)
}
redisAddr := getenvDefault("TEST_REDIS_ADDR", "127.0.0.1:6379")
redisPassword := getenvDefault("TEST_REDIS_PASSWORD", "")
rdb := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPassword,
DB: 0,
})
if err := rdb.Ping(context.Background()).Err(); err != nil {
t.Fatalf("connect redis failed: %v", err)
}
_ = rdb.FlushDB(context.Background()).Err()
svcCtx := &svc.ServiceContext{
DB: db,
Redis: rdb,
UserModel: user.NewModel(db, rdb),
LogModel: modelLog.NewModel(db),
Config: config.Config{
Invite: inviteCfg,
},
}
return NewActivateOrderLogic(svcCtx), db, func() {
_ = rdb.Close()
sqlDB, _ := db.DB()
if sqlDB != nil {
_ = sqlDB.Close()
}
_ = adminDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)).Error
}
}
func seedUser(t *testing.T, db *gorm.DB, referralPercentage uint8, onlyFirstPurchase bool) *user.User {
t.Helper()
u := &user.User{
Password: "pwd",
Algo: "default",
ReferralPercentage: referralPercentage,
OnlyFirstPurchase: boolPtr(onlyFirstPurchase),
Enable: boolPtr(true),
IsAdmin: boolPtr(false),
EnableBalanceNotify: boolPtr(false),
EnableLoginNotify: boolPtr(false),
EnableSubscribeNotify: boolPtr(false),
EnableTradeNotify: boolPtr(false),
}
if err := db.Create(u).Error; err != nil {
t.Fatalf("seed user failed: %v", err)
}
return u
}
func seedFamily(t *testing.T, db *gorm.DB, ownerID int64, memberID int64) {
t.Helper()
family := &user.UserFamily{
OwnerUserId: ownerID,
MaxMembers: 3,
Status: user.FamilyStatusActive,
}
if err := db.Create(family).Error; err != nil {
t.Fatalf("seed family failed: %v", err)
}
now := time.Now()
members := []user.UserFamilyMember{
{
FamilyId: family.Id,
UserId: ownerID,
Role: user.FamilyRoleOwner,
Status: user.FamilyMemberActive,
JoinSource: "test",
JoinedAt: now,
},
{
FamilyId: family.Id,
UserId: memberID,
Role: user.FamilyRoleMember,
Status: user.FamilyMemberActive,
JoinSource: "test",
JoinedAt: now,
},
}
if err := db.Create(&members).Error; err != nil {
t.Fatalf("seed family members failed: %v", err)
}
}
func seedActiveSubscribe(t *testing.T, db *gorm.DB, userID int64, expireAt time.Time) *user.Subscribe {
t.Helper()
sub := &user.Subscribe{
UserId: userID,
OrderId: 1,
SubscribeId: 1,
StartTime: time.Now().Add(-24 * time.Hour),
ExpireTime: expireAt,
Traffic: 1024,
Token: fmt.Sprintf("token-%d-%d", userID, time.Now().UnixNano()),
UUID: fmt.Sprintf("uuid-%d-%d", userID, time.Now().UnixNano()),
Status: 1,
}
if err := db.Create(sub).Error; err != nil {
t.Fatalf("seed subscribe failed: %v", err)
}
return sub
}
func assertExpireIncreasedByDays(t *testing.T, db *gorm.DB, subscribeID int64, before time.Time, days int) {
t.Helper()
var after user.Subscribe
if err := db.First(&after, subscribeID).Error; err != nil {
t.Fatalf("query subscribe failed: %v", err)
}
expected := before.Add(time.Duration(days) * 24 * time.Hour)
if !after.ExpireTime.Equal(expected) {
t.Fatalf("expire time mismatch, expected=%v got=%v", expected, after.ExpireTime)
}
}
func assertUserCommission(t *testing.T, db *gorm.DB, userID int64, expected int64) {
t.Helper()
var u user.User
if err := db.First(&u, userID).Error; err != nil {
t.Fatalf("query user failed: %v", err)
}
if u.Commission != expected {
t.Fatalf("expected user %d commission=%d, got %d", userID, expected, u.Commission)
}
}
func assertLogCountForOrder(t *testing.T, db *gorm.DB, logType uint8, orderNo string, expected int64) {
t.Helper()
var count int64
if err := db.Model(&modelLog.SystemLog{}).
Where("type = ? AND content LIKE ?", logType, "%"+orderNo+"%").
Count(&count).Error; err != nil {
t.Fatalf("count logs failed: %v", err)
}
if count != expected {
t.Fatalf("expected log type %d count=%d for order %s, got %d", logType, expected, orderNo, count)
}
}
func assertGiftLogRemarkContains(t *testing.T, db *gorm.DB, orderNo string, want string) {
t.Helper()
var row modelLog.SystemLog
if err := db.Model(&modelLog.SystemLog{}).
Where("type = ? AND content LIKE ?", modelLog.TypeGift.Uint8(), "%"+orderNo+"%").
First(&row).Error; err != nil {
t.Fatalf("query gift log failed: %v", err)
}
if !strings.Contains(row.Content, want) {
t.Fatalf("expected gift log content to contain %q, got %s", want, row.Content)
}
}
func assertNoSubscribeForUser(t *testing.T, db *gorm.DB, userID int64) {
t.Helper()
var count int64
if err := db.Model(&user.Subscribe{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
t.Fatalf("count subscribes failed: %v", err)
}
if count != 0 {
t.Fatalf("expected user %d to have no direct subscriptions, got %d", userID, count)
}
}
func boolPtr(v bool) *bool {
return &v
}
func getenvDefault(key, fallback string) string {
v := os.Getenv(key)
if v == "" {
return fallback
}
return v
}

View File

@ -0,0 +1,199 @@
package orderLogic
import (
"context"
"testing"
"time"
modelOrder "github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func setupActivationEligibilityDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
sqls := []string{
`CREATE TABLE IF NOT EXISTS "user" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
password VARCHAR(100) NOT NULL DEFAULT '',
created_at DATETIME,
updated_at DATETIME,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_device" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
identifier VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
created_at DATETIME,
updated_at DATETIME
)`,
`CREATE TABLE IF NOT EXISTS "user_family" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
owner_user_id INTEGER NOT NULL DEFAULT 0,
status TINYINT NOT NULL DEFAULT 1,
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "user_family_member" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
family_id INTEGER NOT NULL DEFAULT 0,
user_id INTEGER NOT NULL DEFAULT 0,
role TINYINT NOT NULL DEFAULT 0,
status TINYINT NOT NULL DEFAULT 0,
join_source VARCHAR(32) NOT NULL DEFAULT '',
deleted_at DATETIME DEFAULT NULL
)`,
`CREATE TABLE IF NOT EXISTS "order" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL DEFAULT 0,
order_no VARCHAR(255) NOT NULL DEFAULT '' UNIQUE,
type TINYINT NOT NULL DEFAULT 1,
status TINYINT NOT NULL DEFAULT 1,
subscribe_id INTEGER NOT NULL DEFAULT 0,
quantity INTEGER NOT NULL DEFAULT 1,
created_at DATETIME,
updated_at DATETIME
)`,
}
for _, sql := range sqls {
require.NoError(t, db.Exec(sql).Error)
}
return db
}
func insertActivationUser(t *testing.T, db *gorm.DB, userID int64, createdAt time.Time) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user" (id, created_at, updated_at) VALUES (?, ?, datetime('now'))`,
userID,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error)
}
func insertActivationDevice(t *testing.T, db *gorm.DB, userID int64, identifier string, createdAt time.Time) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_device" (user_id, identifier, created_at, updated_at) VALUES (?, ?, ?, datetime('now'))`,
userID,
identifier,
createdAt.UTC().Format("2006-01-02 15:04:05"),
).Error)
}
func insertActivationFamily(t *testing.T, db *gorm.DB, familyID, ownerUserID int64) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_family" (id, owner_user_id, status) VALUES (?, ?, 1)`,
familyID,
ownerUserID,
).Error)
}
func insertActivationFamilyMember(t *testing.T, db *gorm.DB, familyID, userID int64, role, status uint8, joinSource string) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "user_family_member" (family_id, user_id, role, status, join_source) VALUES (?, ?, ?, ?, ?)`,
familyID,
userID,
role,
status,
joinSource,
).Error)
}
func insertActivationOrder(t *testing.T, db *gorm.DB, orderNo string, userID, subscribeID int64, status uint8) {
t.Helper()
require.NoError(t, db.Exec(
`INSERT INTO "order" (user_id, order_no, type, status, subscribe_id, quantity, created_at, updated_at)
VALUES (?, ?, 1, ?, ?, 1, datetime('now'), datetime('now'))`,
userID,
orderNo,
status,
subscribeID,
).Error)
}
func TestValidateNewUserOnlyEligibilityAtActivation_UsesEarliestBoundDeviceTime(t *testing.T) {
db := setupActivationEligibilityDB(t)
const (
ownerUserID = int64(1)
memberUserID = int64(2)
familyID = int64(10)
subscribeID = int64(100)
)
insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertActivationUser(t, db, memberUserID, time.Now().Add(-72*time.Hour))
insertActivationDevice(t, db, memberUserID, "activation-old-device", time.Now().Add(-72*time.Hour))
insertActivationFamily(t, db, familyID, ownerUserID)
insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
err := validateNewUserOnlyEligibilityAtActivation(
context.Background(),
db,
&modelOrder.Order{
UserId: ownerUserID,
OrderNo: "activation-check-old-device",
Type: OrderTypeSubscribe,
Quantity: 1,
SubscribeId: subscribeID,
},
&subscribe.Subscribe{
Id: subscribeID,
Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`,
},
)
require.Error(t, err)
require.Contains(t, err.Error(), "is not a new user")
}
func TestValidateNewUserOnlyEligibilityAtActivation_SharesHistoryAcrossBoundScope(t *testing.T) {
db := setupActivationEligibilityDB(t)
const (
ownerUserID = int64(11)
memberUserID = int64(12)
familyID = int64(20)
subscribeID = int64(200)
)
insertActivationUser(t, db, ownerUserID, time.Now().Add(-1*time.Hour))
insertActivationUser(t, db, memberUserID, time.Now().Add(-2*time.Hour))
insertActivationDevice(t, db, memberUserID, "activation-shared-device", time.Now().Add(-2*time.Hour))
insertActivationFamily(t, db, familyID, ownerUserID)
insertActivationFamilyMember(t, db, familyID, ownerUserID, user.FamilyRoleOwner, user.FamilyMemberActive, "owner_init")
insertActivationFamilyMember(t, db, familyID, memberUserID, user.FamilyRoleMember, user.FamilyMemberActive, "bind_email_with_verification")
insertActivationOrder(t, db, "previous-finished-order", memberUserID, subscribeID, OrderStatusFinished)
err := validateNewUserOnlyEligibilityAtActivation(
context.Background(),
db,
&modelOrder.Order{
UserId: ownerUserID,
OrderNo: "current-paid-order",
Type: OrderTypeSubscribe,
Quantity: 1,
SubscribeId: subscribeID,
},
&subscribe.Subscribe{
Id: subscribeID,
Discount: `[{"quantity":1,"discount":90,"new_user_only":true}]`,
},
)
require.Error(t, err)
require.Contains(t, err.Error(), "already activated")
}

View File

@ -43,7 +43,7 @@ func validateNewUserOnlyEligibilityAtActivation(
ctx, ctx,
db, db,
eligibility.ScopeUserIDs, eligibility.ScopeUserIDs,
orderInfo.SubscribeId, 0,
[]int64{OrderStatusFinished}, []int64{OrderStatusFinished},
orderInfo.OrderNo, orderInfo.OrderNo,
) )
@ -51,7 +51,7 @@ func validateNewUserOnlyEligibilityAtActivation(
return fmt.Errorf("new user only: check history error: %w", err) return fmt.Errorf("new user only: check history error: %w", err)
} }
if historyCount >= 1 { if historyCount >= 1 {
return fmt.Errorf("new user only: user %d already activated subscribe %d", orderInfo.UserId, orderInfo.SubscribeId) return fmt.Errorf("new user only: user %d already activated an order", orderInfo.UserId)
} }
return nil return nil

View File

@ -131,9 +131,15 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta
// 写完流量后检查是否触发按量限速,若触发则清除节点缓存使限速立即生效 // 写完流量后检查是否触发按量限速,若触发则清除节点缓存使限速立即生效
if planSub, planErr := l.svc.SubscribeModel.FindOne(ctx, sub.SubscribeId); planErr == nil && if planSub, planErr := l.svc.SubscribeModel.FindOne(ctx, sub.SubscribeId); planErr == nil &&
(planSub.SpeedLimit > 0 || planSub.TrafficLimit != "") { planSub.TrafficLimit != "" {
throttle := speedlimit.Calculate(ctx, l.svc.DB, sub.UserId, sub.Id, planSub.SpeedLimit, planSub.TrafficLimit) throttle := speedlimit.Calculate(ctx, l.svc.DB, sub.UserId, sub.Id, planSub.SpeedLimit, planSub.TrafficLimit)
if throttle.IsThrottled { if throttle.IsThrottled {
if delErr := speedlimit.ClearCache(ctx, l.svc.Redis, sub.UserId, sub.Id, planSub.SpeedLimit, planSub.TrafficLimit); delErr != nil {
logger.WithContext(ctx).Error("[TrafficStatistics] Clear speed limit cache failed",
logger.Field("subscribeId", sub.Id),
logger.Field("error", delErr.Error()),
)
}
cacheKey := fmt.Sprintf("%s%d", node.ServerUserListCacheKey, payload.ServerId) cacheKey := fmt.Sprintf("%s%d", node.ServerUserListCacheKey, payload.ServerId)
if delErr := l.svc.Redis.Del(ctx, cacheKey).Err(); delErr != nil { if delErr := l.svc.Redis.Del(ctx, cacheKey).Err(); delErr != nil {
logger.WithContext(ctx).Error("[TrafficStatistics] Clear server user cache failed", logger.WithContext(ctx).Error("[TrafficStatistics] Clear server user cache failed",

View File

@ -0,0 +1,185 @@
//go:build ignore
package main
import (
"bytes"
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"time"
"github.com/forgoer/openssl"
)
// ===== AES 加解密(与 pkg/aes/aes.go 一致)=====
func generateKey(key string) []byte {
hash := sha256.Sum256([]byte(key))
return hash[:32]
}
func generateIv(iv, key string) []byte {
h := md5.New()
h.Write([]byte(iv))
return generateKey(hex.EncodeToString(h.Sum(nil)) + key)
}
func aesEncrypt(plainText []byte, keyStr string) (string, string, error) {
nonce := fmt.Sprintf("%x", time.Now().UnixNano())
key := generateKey(keyStr)
iv := generateIv(nonce, keyStr)
dst, err := openssl.AesCBCEncrypt(plainText, key, iv, openssl.PKCS7_PADDING)
if err != nil {
return "", "", err
}
return base64.StdEncoding.EncodeToString(dst), nonce, nil
}
func aesDecrypt(cipherText, keyStr, ivStr string) (string, error) {
decode, err := base64.StdEncoding.DecodeString(cipherText)
if err != nil {
return "", err
}
key := generateKey(keyStr)
iv := generateIv(ivStr, keyStr)
dst, err := openssl.AesCBCDecrypt(decode, key, iv, openssl.PKCS7_PADDING)
return string(dst), err
}
// ===== 主逻辑 =====
func main() {
deviceID := flag.String("id", "", "设备 ID (identifier)")
secret := flag.String("secret", "", "security_secret (device.security_secret)")
host := flag.String("host", "https://api.hifast.biz", "API 地址")
flag.Parse()
if *deviceID == "" || *secret == "" {
fmt.Println("用法: go run scripts/debug_device_login.go -id <设备ID> -secret <security_secret>")
return
}
// 1. 构造登录请求体
loginBody := map[string]interface{}{
"identifier": *deviceID,
"user_agent": "DebugScript/1.0",
}
loginJSON, _ := json.Marshal(loginBody)
// 2. AES 加密请求体
encData, nonce, err := aesEncrypt(loginJSON, *secret)
if err != nil {
fmt.Printf("❌ 加密失败: %v\n", err)
return
}
encBody := map[string]interface{}{
"data": encData,
"time": nonce,
}
encBodyJSON, _ := json.Marshal(encBody)
fmt.Printf("📤 登录请求体(加密): %s\n\n", encBodyJSON)
// 3. 发起设备登录请求
loginURL := *host + "/v1/auth/login/device"
req, _ := http.NewRequest("POST", loginURL, bytes.NewReader(encBodyJSON))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Login-Type", "device")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("❌ 登录请求失败: %v\n", err)
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
fmt.Printf("📥 登录响应(原始): %s\n\n", respBody)
// 4. 解密响应
var respMap map[string]interface{}
if err := json.Unmarshal(respBody, &respMap); err != nil {
fmt.Printf("❌ 解析响应 JSON 失败: %v\n", err)
return
}
var token string
if dataField, ok := respMap["data"]; ok {
switch d := dataField.(type) {
case map[string]interface{}:
// 加密响应
encResp, _ := d["data"].(string)
ivResp, _ := d["time"].(string)
if encResp != "" && ivResp != "" {
decrypted, err := aesDecrypt(encResp, *secret, ivResp)
if err != nil {
fmt.Printf("❌ 解密响应失败: %v\n", err)
return
}
fmt.Printf("📥 登录响应(解密): %s\n\n", decrypted)
var loginData map[string]interface{}
if err := json.Unmarshal([]byte(decrypted), &loginData); err == nil {
token, _ = loginData["token"].(string)
}
}
case string:
// 未加密直接是 token 字符串
token = d
}
}
if token == "" {
fmt.Println("❌ 未获取到 token登录失败")
return
}
fmt.Printf("✅ Token: %s\n\n", token)
// 5. 查询订阅
subURL := *host + "/v1/public/user/subscribe"
subReq, _ := http.NewRequest("GET", subURL, nil)
subReq.Header.Set("Authorization", "Bearer "+token)
subReq.Header.Set("Login-Type", "device")
subReq.Header.Set("X-App-Id", "debug")
subResp, err := client.Do(subReq)
if err != nil {
fmt.Printf("❌ 查询订阅失败: %v\n", err)
return
}
defer subResp.Body.Close()
subBody, _ := io.ReadAll(subResp.Body)
fmt.Printf("📥 订阅响应(原始): %s\n\n", subBody)
// 6. 解密订阅响应
var subRespMap map[string]interface{}
if err := json.Unmarshal(subBody, &subRespMap); err == nil {
if dataField, ok := subRespMap["data"]; ok {
if d, ok := dataField.(map[string]interface{}); ok {
encResp, _ := d["data"].(string)
ivResp, _ := d["time"].(string)
if encResp != "" && ivResp != "" {
decrypted, err := aesDecrypt(encResp, *secret, ivResp)
if err != nil {
fmt.Printf("❌ 解密订阅响应失败: %v\n", err)
return
}
// 格式化输出
var pretty interface{}
json.Unmarshal([]byte(decrypted), &pretty)
out, _ := json.MarshalIndent(pretty, "", " ")
fmt.Printf("📋 订阅信息(解密):\n%s\n", out)
}
}
}
}
}

View File

@ -0,0 +1,197 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strings"
_ "github.com/go-sql-driver/mysql"
)
func main() {
dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN")
flag.Parse()
if strings.TrimSpace(*dsn) == "" {
log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN")
}
db, err := sql.Open("mysql", *dsn)
if err != nil {
log.Fatal(err)
}
defer db.Close()
if err = db.Ping(); err != nil {
log.Fatal(err)
}
mustPrintRows(db, "db/info", `
SELECT NOW() AS db_now,
(SELECT COUNT(*) FROM user) AS users,
(SELECT COUNT(*) FROM user_subscribe) AS user_subscribes,
(SELECT COUNT(*) FROM `+"`order`"+`) AS orders`)
mustPrintRows(db, "bug1/confusable-email-trials", `
SELECT uam.user_id,
uam.auth_identifier,
us.id AS user_subscribe_id,
us.order_id,
us.status,
us.expire_time,
us.created_at
FROM user_auth_methods uam
JOIN user_subscribe us ON us.user_id = uam.user_id
WHERE uam.auth_type = 'email'
AND us.order_id = 0
AND (
uam.auth_identifier LIKE '%@gmaial.com'
OR uam.auth_identifier LIKE '%@gmial.com'
OR uam.auth_identifier LIKE '%@gamil.com'
OR uam.auth_identifier LIKE '%+%@%'
OR uam.auth_identifier REGEXP '^[^@]*\\.[^@]*@gmail\\.com$'
)
ORDER BY us.created_at DESC
LIMIT 50`)
mustPrintRows(db, "bug2-visible-duplicate-subscriptions", `
SELECT scoped.owner_user_id,
COUNT(*) AS visible_subscribe_count,
GROUP_CONCAT(scoped.user_subscribe_id ORDER BY scoped.expire_time DESC) AS user_subscribe_ids,
GROUP_CONCAT(scoped.subscribe_id ORDER BY scoped.expire_time DESC) AS subscribe_ids,
MAX(scoped.expire_time) AS max_expire_time
FROM (
SELECT us.id AS user_subscribe_id,
us.user_id,
COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id,
us.subscribe_id,
us.status,
us.expire_time,
us.finished_at
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE us.token <> ''
AND us.status IN (0,1,2,3,4)
AND (us.expire_time > NOW()
OR us.finished_at >= DATE_SUB(NOW(), INTERVAL 7 DAY)
OR us.expire_time = FROM_UNIXTIME(0))
) scoped
GROUP BY scoped.owner_user_id
HAVING COUNT(*) > 1
ORDER BY visible_subscribe_count DESC, owner_user_id
LIMIT 50`)
mustPrintRows(db, "bug2-order-subscription-owner-mismatch", `
SELECT us.id AS user_subscribe_id,
us.user_id AS subscribe_user_id,
o.id AS order_id,
o.order_no,
o.user_id AS order_user_id,
o.subscription_user_id,
us.status,
us.expire_time,
us.created_at AS subscribe_created_at,
o.created_at AS order_created_at
FROM user_subscribe us
JOIN `+"`order`"+` o ON o.id = us.order_id
WHERE us.user_id <> o.subscription_user_id
AND us.token <> ''
AND us.status IN (0,1,2,3,4)
ORDER BY us.updated_at DESC
LIMIT 50`)
mustPrintRows(db, "bug3-invite-first-orders-missing-gift-days", `
SELECT first_orders.user_id AS referee_id,
referee.referer_id,
first_orders.id AS order_id,
first_orders.order_no,
first_orders.amount,
first_orders.created_at,
referer.referral_percentage AS referer_referral_percentage,
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34
AND sl.object_id = first_orders.user_id
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referee_gift_logs,
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34
AND sl.object_id = referee.referer_id
AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) AS referer_gift_logs
FROM (
SELECT o.*
FROM `+"`order`"+` o
JOIN (
SELECT user_id, MIN(id) AS first_order_id
FROM `+"`order`"+`
WHERE type IN (1,2)
AND status IN (2,5)
AND amount > 0
GROUP BY user_id
) fo ON fo.first_order_id = o.id
) first_orders
JOIN user referee ON referee.id = first_orders.user_id AND referee.referer_id <> 0
JOIN user referer ON referer.id = referee.referer_id
WHERE (
referer.referral_percentage = 0
AND (
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
OR
(SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = referee.referer_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
)
)
OR (
referer.referral_percentage > 0
AND (SELECT COUNT(*) FROM system_logs sl
WHERE sl.type = 34 AND sl.object_id = first_orders.user_id AND sl.content LIKE CONCAT('%', first_orders.order_no, '%')) = 0
)
ORDER BY first_orders.created_at DESC
LIMIT 50`)
}
func mustPrintRows(db *sql.DB, title string, query string) {
fmt.Printf("\n== %s ==\n", title)
rows, err := db.Query(query)
if err != nil {
log.Fatalf("%s: %v", title, err)
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
log.Fatalf("%s columns: %v", title, err)
}
fmt.Println(strings.Join(cols, "\t"))
values := make([]sql.NullString, len(cols))
args := make([]any, len(cols))
for i := range values {
args[i] = &values[i]
}
count := 0
for rows.Next() {
if err := rows.Scan(args...); err != nil {
log.Fatalf("%s scan: %v", title, err)
}
out := make([]string, len(cols))
for i, value := range values {
if value.Valid {
out[i] = value.String
} else {
out[i] = "NULL"
}
}
fmt.Println(strings.Join(out, "\t"))
count++
}
if err := rows.Err(); err != nil {
log.Fatalf("%s rows: %v", title, err)
}
if count == 0 {
fmt.Println("(none)")
}
}

View File

@ -0,0 +1,204 @@
package main
import (
"database/sql"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
)
type duplicateGroup struct {
OwnerUserID int64 `json:"owner_user_id"`
Count int64 `json:"count"`
}
type subscriptionRow struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
OrderID int64 `json:"order_id"`
SubscribeID int64 `json:"subscribe_id"`
ExpireTime time.Time `json:"expire_time"`
Traffic int64 `json:"traffic"`
Download int64 `json:"download"`
Upload int64 `json:"upload"`
ExpiredDownload int64 `json:"expired_download"`
ExpiredUpload int64 `json:"expired_upload"`
Status uint8 `json:"status"`
UpdatedAt time.Time `json:"updated_at"`
}
type mergePlan struct {
OwnerUserID int64 `json:"owner_user_id"`
Keep subscriptionRow `json:"keep"`
Merge []subscriptionRow `json:"merge"`
}
func main() {
dsn := flag.String("dsn", os.Getenv("PPANEL_MYSQL_DSN"), "MySQL DSN; defaults to PPANEL_MYSQL_DSN")
execute := flag.Bool("execute", false, "apply changes; default is dry-run")
flag.Parse()
if strings.TrimSpace(*dsn) == "" {
log.Fatal("missing DSN: pass -dsn or set PPANEL_MYSQL_DSN")
}
db, err := sql.Open("mysql", *dsn)
if err != nil {
log.Fatal(err)
}
defer db.Close()
groups, err := findDuplicateGroups(db)
if err != nil {
log.Fatal(err)
}
plans := make([]mergePlan, 0, len(groups))
for _, group := range groups {
plan, err := buildPlan(db, group.OwnerUserID)
if err != nil {
log.Fatal(err)
}
if len(plan.Merge) > 0 {
plans = append(plans, plan)
}
}
enc := json.NewEncoder(os.Stdout)
enc.SetIndent("", " ")
if err := enc.Encode(plans); err != nil {
log.Fatal(err)
}
if !*execute {
fmt.Fprintf(os.Stderr, "dry-run only: %d duplicate owner groups found\n", len(plans))
return
}
for _, plan := range plans {
if err := applyPlan(db, plan); err != nil {
log.Fatal(err)
}
}
fmt.Fprintf(os.Stderr, "merged %d duplicate owner groups\n", len(plans))
}
func findDuplicateGroups(db *sql.DB) ([]duplicateGroup, error) {
rows, err := db.Query(`
SELECT owner_user_id, COUNT(1) AS cnt
FROM (
SELECT us.id,
COALESCE(uf.owner_user_id, us.user_id) AS owner_user_id
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE us.token <> ''
AND us.status IN (0, 1, 2, 3, 4)
) scoped
GROUP BY owner_user_id
HAVING COUNT(1) > 1
ORDER BY owner_user_id`)
if err != nil {
return nil, err
}
defer rows.Close()
var groups []duplicateGroup
for rows.Next() {
var g duplicateGroup
if err := rows.Scan(&g.OwnerUserID, &g.Count); err != nil {
return nil, err
}
groups = append(groups, g)
}
return groups, rows.Err()
}
func buildPlan(db *sql.DB, ownerUserID int64) (mergePlan, error) {
rows, err := db.Query(`
SELECT us.id, us.user_id, us.order_id, us.subscribe_id, us.expire_time, us.traffic,
us.download, us.upload, us.expired_download, us.expired_upload, us.status, us.updated_at
FROM user_subscribe us
LEFT JOIN user_family_member ufm
ON ufm.user_id = us.user_id AND ufm.deleted_at IS NULL AND ufm.status = 1
LEFT JOIN user_family uf
ON uf.id = ufm.family_id AND uf.deleted_at IS NULL AND uf.status = 1
WHERE COALESCE(uf.owner_user_id, us.user_id) = ?
AND us.token <> ''
AND us.status IN (0, 1, 2, 3, 4)
ORDER BY us.expire_time DESC, us.updated_at DESC, us.id DESC`, ownerUserID)
if err != nil {
return mergePlan{}, err
}
defer rows.Close()
var all []subscriptionRow
for rows.Next() {
var r subscriptionRow
if err := rows.Scan(&r.ID, &r.UserID, &r.OrderID, &r.SubscribeID, &r.ExpireTime, &r.Traffic, &r.Download, &r.Upload, &r.ExpiredDownload, &r.ExpiredUpload, &r.Status, &r.UpdatedAt); err != nil {
return mergePlan{}, err
}
all = append(all, r)
}
if err := rows.Err(); err != nil {
return mergePlan{}, err
}
if len(all) == 0 {
return mergePlan{OwnerUserID: ownerUserID}, nil
}
keep := all[0]
for _, r := range all[1:] {
keep.Download += r.Download
keep.Upload += r.Upload
keep.ExpiredDownload += r.ExpiredDownload
keep.ExpiredUpload += r.ExpiredUpload
if r.Traffic > keep.Traffic {
keep.Traffic = r.Traffic
}
}
for _, r := range all {
if r.UpdatedAt.After(keep.UpdatedAt) {
keep.SubscribeID = r.SubscribeID
}
}
return mergePlan{OwnerUserID: ownerUserID, Keep: keep, Merge: all[1:]}, nil
}
func applyPlan(db *sql.DB, plan mergePlan) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err = tx.Exec(`
UPDATE user_subscribe
SET user_id = ?, subscribe_id = ?, traffic = ?, download = ?, upload = ?,
expired_download = ?, expired_upload = ?, status = 1, note = CONCAT(COALESCE(note, ''), ' [merged duplicate subscriptions]')
WHERE id = ?`,
plan.OwnerUserID, plan.Keep.SubscribeID, plan.Keep.Traffic, plan.Keep.Download, plan.Keep.Upload,
plan.Keep.ExpiredDownload, plan.Keep.ExpiredUpload, plan.Keep.ID); err != nil {
return err
}
for _, r := range plan.Merge {
if _, err = tx.Exec(`
UPDATE user_subscribe
SET status = 5, note = CONCAT(COALESCE(note, ''), ' [merged into subscription #', ?, ']')
WHERE id = ?`, plan.Keep.ID, r.ID); err != nil {
return err
}
}
return tx.Commit()
}

View File

@ -0,0 +1,787 @@
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"os"
"os/exec"
"strings"
"time"
"github.com/hibiken/asynq"
"github.com/perfect-panel/server/internal/config"
authlogic "github.com/perfect-panel/server/internal/logic/auth"
modelLog "github.com/perfect-panel/server/internal/model/log"
modelOrder "github.com/perfect-panel/server/internal/model/order"
modelSubscribe "github.com/perfect-panel/server/internal/model/subscribe"
modelUser "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/pkg/conf"
"github.com/perfect-panel/server/pkg/orm"
"github.com/perfect-panel/server/pkg/uuidx"
orderLogic "github.com/perfect-panel/server/queue/logic/order"
queueTypes "github.com/perfect-panel/server/queue/types"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
const marker = "codex-replay-business-bugs"
func main() {
var (
configPath = flag.String("config", "etc/ppanel.yaml", "ppanel config path for test server DB/Redis")
dsn = flag.String("dsn", "", "optional MySQL DSN override: user:pass@tcp(host:3306)/db?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai")
writeDB = flag.Bool("write-db", false, "create isolated test rows and execute activation replay against the configured test DB")
force = flag.Bool("force", false, "allow -write-db even when the config name does not clearly look like test/dev/staging")
keep = flag.Bool("keep", false, "keep replay rows for manual inspection")
cleanupOnly = flag.Bool("cleanup-only", false, "delete leftover replay rows by marker and exit")
skipCodeTests = flag.Bool("skip-code-tests", false, "skip go test checks")
)
flag.Parse()
ctx := context.Background()
started := time.Now()
fmt.Println("== replay business bug tests ==")
fmt.Printf("marker: %s\n", marker)
if !*skipCodeTests {
must(runCodeTests())
}
cfg := loadConfig(*configPath, *dsn)
runEmailTrialAssertions(cfg)
if *cleanupOnly {
env := mustNewReplayEnv(ctx, cfg)
env.cleanupByMarker(ctx)
return
}
if !*writeDB {
fmt.Println("\nDB replay skipped. Add -write-db to create isolated rows in the TEST database and run activation flows.")
fmt.Println("Example:")
fmt.Printf(" go run scripts/replay_business_bugs.go -config %s -write-db\n", *configPath)
return
}
if looksLikeProduction(cfg) && !*force {
fatalf("refusing to write DB because config does not look like a test environment: db=%s host=%s; add -force only on the test server", cfg.MySQL.Dbname, cfg.Site.Host)
}
env := mustNewReplayEnv(ctx, cfg)
if !*keep {
defer env.cleanup(ctx)
}
must(env.replaySingleSubscription(ctx))
must(env.replayInviteRulesMatrix(ctx))
must(env.replayFamilyInviteGiftToOwner(ctx))
fmt.Printf("\nPASS all replay checks in %s\n", time.Since(started).Round(time.Millisecond))
if *keep {
fmt.Println("Replay rows kept for inspection. Delete rows with remark/name/order_no containing:", marker)
}
}
func runCodeTests() error {
fmt.Println("\n-- code-level tests --")
args := []string{"test",
"./internal/logic/auth",
"./internal/logic/common",
"./internal/logic/public/order",
"./queue/logic/order",
}
cmd := exec.Command("go", args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("go test failed: %w", err)
}
fmt.Println("PASS code-level tests")
return nil
}
func loadConfig(path, dsn string) config.Config {
var cfg config.Config
conf.MustLoad(path, &cfg)
if dsn != "" {
cfg.MySQL = parseDSN(dsn)
}
return cfg
}
func parseDSN(dsn string) orm.Config {
cfg := orm.ParseDSN(dsn)
if cfg == nil {
fatalf("invalid dsn")
}
return *cfg
}
func runEmailTrialAssertions(cfg config.Config) {
fmt.Println("\n-- bug1 email trial whitelist assertions --")
cfg.Register.EnableTrial = true
cfg.Register.EnableTrialEmailWhitelist = true
if cfg.Register.TrialEmailDomainWhitelist == "" {
cfg.Register.TrialEmailDomainWhitelist = "gmail.com,163.com"
}
cases := []struct {
email string
want bool
}{
{"1.2.3.4xxx@gmaial.com", false},
{"a.b.c@gmail.com", false},
{"user+tag@gmail.com", false},
{"user@fake.gmail.com", false},
{"normaluser@gmail.com", true},
}
for _, tc := range cases {
got := authlogic.ShouldGrantTrialForEmail(cfg.Register, tc.email)
if got != tc.want {
fatalf("email trial assertion failed: email=%s got=%v want=%v", tc.email, got, tc.want)
}
fmt.Printf("PASS %-32s grant=%v\n", tc.email, got)
}
}
type replayEnv struct {
db *gorm.DB
rds *redis.Client
cfg config.Config
svcCtx *svc.ServiceContext
ids struct {
users []int64
subscribes []int64
plans []int64
orders []int64
logs []int64
}
}
func mustNewReplayEnv(ctx context.Context, cfg config.Config) *replayEnv {
fmt.Println("\n-- connecting test DB/Redis --")
db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL})
must(err)
rds := redis.NewClient(&redis.Options{
Addr: cfg.Redis.Host,
Password: cfg.Redis.Pass,
DB: cfg.Redis.DB,
PoolSize: cfg.Redis.PoolSize,
MinIdleConns: cfg.Redis.MinIdleConns,
})
must(rds.Ping(ctx).Err())
svcCtx := &svc.ServiceContext{
DB: db,
Redis: rds,
Config: cfg,
UserModel: modelUser.NewModel(db, rds),
OrderModel: modelOrder.NewModel(db, rds),
SubscribeModel: modelSubscribe.NewModel(db, rds),
LogModel: modelLog.NewModel(db),
}
fmt.Printf("connected: mysql=%s/%s redis=%s\n", cfg.MySQL.Addr, cfg.MySQL.Dbname, cfg.Redis.Host)
return &replayEnv{db: db, rds: rds, cfg: cfg, svcCtx: svcCtx}
}
func (e *replayEnv) replaySingleSubscription(ctx context.Context) error {
fmt.Println("\n-- bug2 replay: paid purchase must reuse existing subscription --")
planA, planB, err := e.createPlans(ctx, "bug2")
if err != nil {
return err
}
owner, err := e.createUser(ctx, "bug2-owner", 0, 0)
if err != nil {
return err
}
existing, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, time.Now().Add(7*24*time.Hour))
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, owner.Id, owner.Id, planB.Id, true, "bug2")
if err != nil {
return err
}
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
var rows []modelUser.Subscribe
if err = e.db.WithContext(ctx).
Where("user_id = ? AND token <> '' AND status IN ?", owner.Id, []int{0, 1, 2, 3, 4}).
Order("id ASC").
Find(&rows).Error; err != nil {
return err
}
if len(rows) != 1 {
return fmt.Errorf("bug2 failed: expected one visible subscription, got %d", len(rows))
}
if rows[0].Id != existing.Id {
return fmt.Errorf("bug2 failed: expected original subscription id=%d to be reused, got id=%d", existing.Id, rows[0].Id)
}
if rows[0].SubscribeId != planB.Id || rows[0].OrderId != order.Id {
return fmt.Errorf("bug2 failed: reused subscription not updated, subscribe_id=%d order_id=%d", rows[0].SubscribeId, rows[0].OrderId)
}
fmt.Printf("PASS user=%d user_subscribe=%d plan %d -> %d order=%s\n", owner.Id, rows[0].Id, planA.Id, planB.Id, order.OrderNo)
return nil
}
func (e *replayEnv) replayInviteGiftDays(ctx context.Context) error {
fmt.Println("\n-- bug3 replay: commission=0 invite should grant gift days to both users --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
e.svcCtx.Config.Invite.GiftDays = giftDays
}
e.svcCtx.Config.Invite.ReferralPercentage = 0
e.svcCtx.Config.Invite.OnlyFirstPurchase = true
planA, _, err := e.createPlans(ctx, "bug3")
if err != nil {
return err
}
referer, err := e.createUser(ctx, "bug3-referer", 0, 0)
if err != nil {
return err
}
referee, err := e.createUser(ctx, "bug3-referee", referer.Id, 0)
if err != nil {
return err
}
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond)
refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
refereeSub, err := e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, true, "bug3")
if err != nil {
return err
}
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: order.OrderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
if err = e.waitForGiftLogs(ctx, order.OrderNo, referer.Id, referee.Id); err != nil {
return err
}
var refererAfter, refereeAfter modelUser.Subscribe
if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&refereeAfter, refereeSub.Id).Error; err != nil {
return err
}
minRefererExpire := baseExpire.Add(time.Duration(giftDays) * 24 * time.Hour)
if refererAfter.ExpireTime.Before(minRefererExpire.Add(-time.Second)) {
return fmt.Errorf("bug3 failed: referer expire not increased by gift days, got=%s want>=%s", refererAfter.ExpireTime, minRefererExpire)
}
if !refereeAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("bug3 failed: referee expire did not increase, got=%s base=%s", refereeAfter.ExpireTime, baseExpire)
}
// Idempotency: repeat the same order task and make sure gift logs are still one per user.
if err = worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
var giftCount int64
if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), []int64{referer.Id, referee.Id}, "%"+order.OrderNo+"%").
Count(&giftCount).Error; err != nil {
return err
}
if giftCount != 2 {
return fmt.Errorf("bug3 failed: expected 2 gift logs after duplicate task, got %d", giftCount)
}
fmt.Printf("PASS referer=%d referee=%d order=%s gift_days=%d logs=%d\n", referer.Id, referee.Id, order.OrderNo, giftDays, giftCount)
return nil
}
func (e *replayEnv) replayInviteRulesMatrix(ctx context.Context) error {
fmt.Println("\n-- bug3 replay matrix: invite gift/commission rules --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
}
e.svcCtx.Config.Invite.GiftDays = giftDays
e.svcCtx.Config.Invite.OnlyFirstPurchase = false
planA, _, err := e.createPlans(ctx, "bug3-matrix")
if err != nil {
return err
}
cases := []struct {
name string
hasReferer bool
globalReferralPct int64
isNewOrder bool
wantGiftLogs int64
wantCommissionLogs int64
wantCommission int64
}{
{
name: "no invite relation first order no gift",
hasReferer: false,
isNewOrder: true,
wantGiftLogs: 0,
},
{
name: "ordinary invite commission 0 first order gifts both",
hasReferer: true,
isNewOrder: true,
wantGiftLogs: 2,
},
{
name: "ordinary invite commission 0 non-first order no gift",
hasReferer: true,
isNewOrder: false,
wantGiftLogs: 0,
},
{
name: "channel commission positive first order gifts referee only",
hasReferer: true,
globalReferralPct: 10,
isNewOrder: true,
wantGiftLogs: 1,
wantCommissionLogs: 1,
wantCommission: 59,
},
{
name: "channel commission positive non-first order commission only",
hasReferer: true,
globalReferralPct: 10,
isNewOrder: false,
wantGiftLogs: 0,
wantCommissionLogs: 1,
wantCommission: 59,
},
}
for idx, tc := range cases {
e.svcCtx.Config.Invite.ReferralPercentage = tc.globalReferralPct
scope := fmt.Sprintf("bug3-rule-%d", idx+1)
var referer *modelUser.User
if tc.hasReferer {
referer, err = e.createUser(ctx, scope+"-referer", 0, 0)
if err != nil {
return err
}
if _, err = e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil {
return err
}
}
var refererID int64
if referer != nil {
refererID = referer.Id
}
referee, err := e.createUser(ctx, scope+"-referee", refererID, 0)
if err != nil {
return err
}
if _, err = e.createUserSubscribe(ctx, referee.Id, 0, planA.Id, time.Now().Add(10*24*time.Hour)); err != nil {
return err
}
order, err := e.createPaidOrder(ctx, referee.Id, referee.Id, planA.Id, tc.isNewOrder, scope)
if err != nil {
return err
}
if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil {
return fmt.Errorf("%s: %w", tc.name, err)
}
if err = e.waitForLogCounts(ctx, order.OrderNo, tc.wantGiftLogs, tc.wantCommissionLogs); err != nil {
return fmt.Errorf("%s: %w", tc.name, err)
}
giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), order.OrderNo)
if err != nil {
return err
}
commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), order.OrderNo)
if err != nil {
return err
}
if giftLogs != tc.wantGiftLogs {
return fmt.Errorf("%s: expected gift logs=%d got=%d", tc.name, tc.wantGiftLogs, giftLogs)
}
if commissionLogs != tc.wantCommissionLogs {
return fmt.Errorf("%s: expected commission logs=%d got=%d", tc.name, tc.wantCommissionLogs, commissionLogs)
}
if referer != nil && tc.wantCommission > 0 {
var after modelUser.User
if err = e.db.WithContext(ctx).First(&after, referer.Id).Error; err != nil {
return err
}
if after.Commission != tc.wantCommission {
return fmt.Errorf("%s: expected referer commission=%d got=%d", tc.name, tc.wantCommission, after.Commission)
}
}
fmt.Printf("PASS %-58s gifts=%d commission_logs=%d\n", tc.name, giftLogs, commissionLogs)
}
return nil
}
func (e *replayEnv) replayFamilyInviteGiftToOwner(ctx context.Context) error {
fmt.Println("\n-- bug3 family replay: member purchase gift days go to owner --")
giftDays := e.cfg.Invite.GiftDays
if giftDays <= 0 {
giftDays = 2
}
e.svcCtx.Config.Invite.GiftDays = giftDays
e.svcCtx.Config.Invite.ReferralPercentage = 0
e.svcCtx.Config.Invite.OnlyFirstPurchase = false
planA, _, err := e.createPlans(ctx, "bug3-family")
if err != nil {
return err
}
referer, err := e.createUser(ctx, "bug3-family-referer", 0, 0)
if err != nil {
return err
}
owner, err := e.createUser(ctx, "bug3-family-owner", 0, 0)
if err != nil {
return err
}
member, err := e.createUser(ctx, "bug3-family-member", referer.Id, 0)
if err != nil {
return err
}
if err = e.createFamily(ctx, owner.Id, member.Id); err != nil {
return err
}
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Millisecond)
ownerSub, err := e.createUserSubscribe(ctx, owner.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
memberSub, err := e.createUserSubscribe(ctx, member.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
refererSub, err := e.createUserSubscribe(ctx, referer.Id, 0, planA.Id, baseExpire)
if err != nil {
return err
}
order, err := e.createPaidOrder(ctx, member.Id, owner.Id, planA.Id, true, "bug3-family")
if err != nil {
return err
}
if err = e.activateOrderTwice(ctx, order.OrderNo); err != nil {
return err
}
if err = e.waitForLogCounts(ctx, order.OrderNo, 2, 0); err != nil {
return err
}
var ownerAfter, memberAfter, refererAfter modelUser.Subscribe
if err = e.db.WithContext(ctx).First(&ownerAfter, ownerSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&memberAfter, memberSub.Id).Error; err != nil {
return err
}
if err = e.db.WithContext(ctx).First(&refererAfter, refererSub.Id).Error; err != nil {
return err
}
if !ownerAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("family gift failed: owner expire not increased")
}
if !refererAfter.ExpireTime.After(baseExpire) {
return fmt.Errorf("family gift failed: referer expire not increased")
}
if memberAfter.ExpireTime.After(baseExpire.Add(time.Second)) {
return fmt.Errorf("family gift failed: member subscription should not receive gift days")
}
var memberGiftLogs int64
if err = e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id = ? AND content LIKE ?", modelLog.TypeGift.Uint8(), member.Id, "%"+order.OrderNo+"%").
Count(&memberGiftLogs).Error; err != nil {
return err
}
if memberGiftLogs != 0 {
return fmt.Errorf("family gift failed: expected no member gift logs, got %d", memberGiftLogs)
}
fmt.Printf("PASS family member purchase gift target owner owner=%d member=%d referer=%d gift_days=%d\n", owner.Id, member.Id, referer.Id, giftDays)
return nil
}
func (e *replayEnv) activateOrderTwice(ctx context.Context, orderNo string) error {
payload, _ := json.Marshal(queueTypes.ForthwithActivateOrderPayload{OrderNo: orderNo})
worker := orderLogic.NewActivateOrderLogic(e.svcCtx)
if err := worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload)); err != nil {
return err
}
return worker.ProcessTask(ctx, asynq.NewTask(queueTypes.ForthwithActivateOrder, payload))
}
func (e *replayEnv) waitForLogCounts(ctx context.Context, orderNo string, wantGiftLogs, wantCommissionLogs int64) error {
deadline := time.Now().Add(8 * time.Second)
for {
giftLogs, err := e.countLogs(ctx, modelLog.TypeGift.Uint8(), orderNo)
if err != nil {
return err
}
commissionLogs, err := e.countLogs(ctx, modelLog.TypeCommission.Uint8(), orderNo)
if err != nil {
return err
}
if giftLogs >= wantGiftLogs && commissionLogs >= wantCommissionLogs {
if wantGiftLogs == 0 && wantCommissionLogs == 0 {
time.Sleep(500 * time.Millisecond)
}
return nil
}
if time.Now().After(deadline) {
return fmt.Errorf("timed out waiting for logs: order=%s gift=%d/%d commission=%d/%d", orderNo, giftLogs, wantGiftLogs, commissionLogs, wantCommissionLogs)
}
time.Sleep(100 * time.Millisecond)
}
}
func (e *replayEnv) countLogs(ctx context.Context, logType uint8, orderNo string) (int64, error) {
var count int64
err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND content LIKE ?", logType, "%"+orderNo+"%").
Count(&count).Error
return count, err
}
func (e *replayEnv) waitForGiftLogs(ctx context.Context, orderNo string, userIDs ...int64) error {
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
var count int64
if err := e.db.WithContext(ctx).Model(&modelLog.SystemLog{}).
Where("type = ? AND object_id IN ? AND content LIKE ?", modelLog.TypeGift.Uint8(), userIDs, "%"+orderNo+"%").
Count(&count).Error; err != nil {
return err
}
if count == int64(len(userIDs)) {
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timed out waiting for gift logs for order=%s", orderNo)
}
func (e *replayEnv) createPlans(ctx context.Context, scope string) (*modelSubscribe.Subscribe, *modelSubscribe.Subscribe, error) {
a := &modelSubscribe.Subscribe{
Name: marker + "-" + scope + "-A",
Language: "en",
UnitPrice: 599,
UnitTime: "Month",
Traffic: 1024 * 1024 * 1024,
Inventory: -1,
Quota: 0,
NodeGroupIds: modelSubscribe.JSONInt64Slice{},
}
b := &modelSubscribe.Subscribe{
Name: marker + "-" + scope + "-B",
Language: "en",
UnitPrice: 699,
UnitTime: "Month",
Traffic: 2 * 1024 * 1024 * 1024,
Inventory: -1,
Quota: 0,
NodeGroupIds: modelSubscribe.JSONInt64Slice{},
}
if err := e.db.WithContext(ctx).Create(a).Error; err != nil {
return nil, nil, err
}
if err := e.db.WithContext(ctx).Create(b).Error; err != nil {
return nil, nil, err
}
e.ids.plans = append(e.ids.plans, a.Id, b.Id)
return a, b, nil
}
func (e *replayEnv) createUser(ctx context.Context, scope string, refererID int64, referralPercentage uint8) (*modelUser.User, error) {
onlyFirst := true
enable := true
isAdmin := false
u := &modelUser.User{
Password: marker,
Algo: "default",
Salt: "default",
RefererId: refererID,
ReferralPercentage: referralPercentage,
OnlyFirstPurchase: &onlyFirst,
Enable: &enable,
IsAdmin: &isAdmin,
EnableBalanceNotify: &enable,
EnableLoginNotify: &enable,
EnableSubscribeNotify: &enable,
EnableTradeNotify: &enable,
Remark: marker + "-" + scope,
}
if err := e.db.WithContext(ctx).Create(u).Error; err != nil {
return nil, err
}
u.ReferCode = uuidx.UserInviteCode(u.Id)
if err := e.db.WithContext(ctx).Model(&modelUser.User{}).Where("id = ?", u.Id).Update("refer_code", u.ReferCode).Error; err != nil {
return nil, err
}
e.ids.users = append(e.ids.users, u.Id)
return u, nil
}
func (e *replayEnv) createFamily(ctx context.Context, ownerID, memberID int64) error {
now := time.Now()
family := &modelUser.UserFamily{
OwnerUserId: ownerID,
MaxMembers: modelUser.DefaultFamilyMaxSize,
Status: modelUser.FamilyStatusActive,
}
if err := e.db.WithContext(ctx).Create(family).Error; err != nil {
return err
}
members := []modelUser.UserFamilyMember{
{
FamilyId: family.Id,
UserId: ownerID,
Role: modelUser.FamilyRoleOwner,
Status: modelUser.FamilyMemberActive,
JoinSource: marker,
JoinedAt: now,
},
{
FamilyId: family.Id,
UserId: memberID,
Role: modelUser.FamilyRoleMember,
Status: modelUser.FamilyMemberActive,
JoinSource: marker,
JoinedAt: now,
},
}
return e.db.WithContext(ctx).Create(&members).Error
}
func (e *replayEnv) createUserSubscribe(ctx context.Context, userID, orderID, planID int64, expire time.Time) (*modelUser.Subscribe, error) {
groupLocked := false
sub := &modelUser.Subscribe{
UserId: userID,
OrderId: orderID,
SubscribeId: planID,
GroupLocked: &groupLocked,
StartTime: time.Now().Add(-time.Hour),
ExpireTime: expire,
Traffic: 1024 * 1024 * 1024,
Token: marker + "-" + uuidx.NewUUID().String(),
UUID: uuidx.NewUUID().String(),
Status: 1,
Note: marker,
}
if err := e.db.WithContext(ctx).Create(sub).Error; err != nil {
return nil, err
}
e.ids.subscribes = append(e.ids.subscribes, sub.Id)
return sub, nil
}
func (e *replayEnv) createPaidOrder(ctx context.Context, userID, subscriptionUserID, planID int64, isNew bool, scope string) (*modelOrder.Order, error) {
orderNo := fmt.Sprintf("%s-%s-%d", marker, scope, time.Now().UnixNano())
order := &modelOrder.Order{
UserId: userID,
SubscriptionUserId: subscriptionUserID,
OrderNo: orderNo,
Type: 1,
Quantity: 1,
Price: 599,
Amount: 599,
Status: 2,
SubscribeId: planID,
Method: "replay",
IsNew: isNew,
}
if err := e.db.WithContext(ctx).Create(order).Error; err != nil {
return nil, err
}
e.ids.orders = append(e.ids.orders, order.Id)
return order, nil
}
func (e *replayEnv) cleanup(ctx context.Context) {
fmt.Println("\n-- cleanup replay rows --")
e.cleanupByMarker(ctx)
if len(e.ids.subscribes) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.subscribes).Delete(&modelUser.Subscribe{}).Error
}
if len(e.ids.orders) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.orders).Delete(&modelOrder.Order{}).Error
}
if len(e.ids.plans) > 0 {
_ = e.db.WithContext(ctx).Where("id IN ?", e.ids.plans).Delete(&modelSubscribe.Subscribe{}).Error
}
if len(e.ids.users) > 0 {
_ = e.db.WithContext(ctx).Unscoped().Where("id IN ?", e.ids.users).Delete(&modelUser.User{}).Error
}
fmt.Println("cleanup done")
}
func (e *replayEnv) cleanupByMarker(ctx context.Context) {
_ = e.db.WithContext(ctx).
Where("join_source = ?", marker).
Delete(&modelUser.UserFamilyMember{}).Error
_ = e.db.WithContext(ctx).
Where("owner_user_id IN (SELECT id FROM `user` WHERE remark LIKE ?)", marker+"%").
Delete(&modelUser.UserFamily{}).Error
_ = e.db.WithContext(ctx).
Where("type IN (33, 34) AND content LIKE ?", "%"+marker+"%").
Delete(&modelLog.SystemLog{}).Error
_ = e.db.WithContext(ctx).
Where("order_no LIKE ?", marker+"%").
Delete(&modelOrder.Order{}).Error
_ = e.db.WithContext(ctx).
Where("note = ? OR token LIKE ?", marker, marker+"%").
Delete(&modelUser.Subscribe{}).Error
_ = e.db.WithContext(ctx).
Where("name LIKE ?", marker+"%").
Delete(&modelSubscribe.Subscribe{}).Error
_ = e.db.WithContext(ctx).Unscoped().
Where("remark LIKE ?", marker+"%").
Delete(&modelUser.User{}).Error
}
func looksLikeProduction(cfg config.Config) bool {
joined := strings.ToLower(strings.Join([]string{cfg.MySQL.Dbname, cfg.Site.Host, cfg.Host}, " "))
if strings.Contains(joined, "prod") || strings.Contains(joined, "production") {
return true
}
if cfg.Debug {
return false
}
if strings.Contains(joined, "test") || strings.Contains(joined, "dev") || strings.Contains(joined, "staging") {
return false
}
return true
}
func must(err error) {
if err != nil {
fatalf("%v", err)
}
}
func fatalf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "FAIL: "+format+"\n", args...)
os.Exit(1)
}

View File

@ -0,0 +1,238 @@
//go:build ignore
package main
import (
"context"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/perfect-panel/server/initialize"
"github.com/perfect-panel/server/internal/config"
authlogic "github.com/perfect-panel/server/internal/logic/auth"
modelAuth "github.com/perfect-panel/server/internal/model/auth"
modelLog "github.com/perfect-panel/server/internal/model/log"
modelNode "github.com/perfect-panel/server/internal/model/node"
modelSubscribe "github.com/perfect-panel/server/internal/model/subscribe"
modelSystem "github.com/perfect-panel/server/internal/model/system"
modelUser "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/conf"
"github.com/perfect-panel/server/pkg/orm"
"github.com/perfect-panel/server/pkg/tool"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
func main() {
var (
configPath = flag.String("config", "etc/ppanel.yaml", "config file path on the test server")
dsn = flag.String("dsn", "", "optional MySQL DSN override")
identifier = flag.String("identifier", "", "optional device identifier; defaults to a unique test identifier")
ip = flag.String("ip", "", "optional request IP; defaults to a reserved test IP")
userAgent = flag.String("user-agent", "CodexDeviceTrialTest/1.0", "device user agent")
write = flag.Bool("write", false, "actually create a test device user by running DeviceLogin")
cleanup = flag.Bool("cleanup", false, "delete the test user/device/subscription/log rows after verification")
)
flag.Parse()
if !*write {
fmt.Println("Refusing to write DB without -write.")
fmt.Println("Example:")
fmt.Printf(" go run scripts/test_device_trial_registration.go -config %s -write\n", *configPath)
os.Exit(2)
}
ctx := context.Background()
cfg := loadConfig(*configPath, *dsn)
env := mustNewDeviceTrialEnv(ctx, cfg)
defer env.close()
initialize.Device(env.svcCtx)
initialize.Register(env.svcCtx)
if *identifier == "" {
*identifier = fmt.Sprintf("codex-device-trial-%d", time.Now().UnixNano())
}
if *ip == "" {
now := time.Now().UnixNano()
*ip = fmt.Sprintf("198.18.%d.%d", now%200+1, now/200%200+1)
}
fmt.Println("== device registration no-trial test ==")
fmt.Printf("mysql: %s/%s\n", env.cfg.MySQL.Addr, env.cfg.MySQL.Dbname)
fmt.Printf("redis: %s db=%d\n", env.cfg.Redis.Host, env.cfg.Redis.DB)
fmt.Printf("device.enable=%v\n", env.svcCtx.Config.Device.Enable)
fmt.Printf("register.enable_trial=%v trial_subscribe=%d trial_time=%d trial_time_unit=%s\n",
env.svcCtx.Config.Register.EnableTrial,
env.svcCtx.Config.Register.TrialSubscribe,
env.svcCtx.Config.Register.TrialTime,
env.svcCtx.Config.Register.TrialTimeUnit,
)
fmt.Printf("identifier=%s ip=%s user_agent=%s\n", *identifier, *ip, *userAgent)
if err := ensureIdentifierUnused(ctx, env.db, *identifier); err != nil {
fail(err)
}
logic := authlogic.NewDeviceLoginLogic(ctx, env.svcCtx)
resp, err := logic.DeviceLogin(&types.DeviceLoginRequest{
Identifier: *identifier,
IP: *ip,
UserAgent: *userAgent,
})
if err != nil {
fail(fmt.Errorf("DeviceLogin failed: %w", err))
}
if resp == nil || strings.TrimSpace(resp.Token) == "" {
fail(fmt.Errorf("DeviceLogin returned empty token"))
}
fmt.Printf("login token: ok len=%d\n", len(resp.Token))
device, err := env.svcCtx.UserModel.FindOneDeviceByIdentifier(ctx, *identifier)
if err != nil {
fail(fmt.Errorf("query created device failed: %w", err))
}
fmt.Printf("device: id=%d sn=%s user_id=%d created_at=%s\n",
device.Id,
tool.DeviceIdToHash(device.Id),
device.UserId,
device.CreatedAt.Format(time.RFC3339),
)
var subs []modelUser.Subscribe
if err = env.db.WithContext(ctx).
Where("user_id = ?", device.UserId).
Order("id ASC").
Find(&subs).Error; err != nil {
fail(fmt.Errorf("query user_subscribe failed: %w", err))
}
for i := range subs {
sub := &subs[i]
fmt.Printf("subscribe: id=%d order_id=%d subscribe_id=%d status=%d start=%s expire=%s token_empty=%v\n",
sub.Id,
sub.OrderId,
sub.SubscribeId,
sub.Status,
sub.StartTime.Format(time.RFC3339),
sub.ExpireTime.Format(time.RFC3339),
sub.Token == "",
)
if sub.OrderId == 0 &&
sub.SubscribeId == env.svcCtx.Config.Register.TrialSubscribe &&
(sub.Status == 0 || sub.Status == 1) &&
sub.ExpireTime.After(time.Now()) {
fail(fmt.Errorf("FAIL: device registration unexpectedly granted trial user_subscribe_id=%d user_id=%d", sub.Id, device.UserId))
}
}
fmt.Printf("PASS: device registration created no active trial subscription for user_id=%d\n", device.UserId)
if *cleanup {
if err = cleanupTestRows(ctx, env.db, device.UserId); err != nil {
fail(fmt.Errorf("cleanup failed: %w", err))
}
fmt.Printf("cleanup: deleted test rows for user_id=%d\n", device.UserId)
}
}
type deviceTrialEnv struct {
db *gorm.DB
rds *redis.Client
cfg config.Config
svcCtx *svc.ServiceContext
}
func mustNewDeviceTrialEnv(ctx context.Context, cfg config.Config) *deviceTrialEnv {
db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL})
must(err)
rds := redis.NewClient(&redis.Options{
Addr: cfg.Redis.Host,
Password: cfg.Redis.Pass,
DB: cfg.Redis.DB,
PoolSize: cfg.Redis.PoolSize,
MinIdleConns: cfg.Redis.MinIdleConns,
})
must(rds.Ping(ctx).Err())
svcCtx := &svc.ServiceContext{
DB: db,
Redis: rds,
Config: cfg,
AuthModel: modelAuth.NewModel(db, rds),
LogModel: modelLog.NewModel(db),
NodeModel: modelNode.NewModel(db, rds),
SystemModel: modelSystem.NewModel(db, rds),
UserModel: modelUser.NewModel(db, rds),
SubscribeModel: modelSubscribe.NewModel(db, rds),
}
return &deviceTrialEnv{db: db, rds: rds, cfg: cfg, svcCtx: svcCtx}
}
func (e *deviceTrialEnv) close() {
if e == nil || e.rds == nil {
return
}
_ = e.rds.Close()
}
func loadConfig(path, dsn string) config.Config {
var cfg config.Config
conf.MustLoad(path, &cfg)
if dsn != "" {
parsed := orm.ParseDSN(dsn)
if parsed == nil {
fail(fmt.Errorf("invalid dsn"))
}
cfg.MySQL = *parsed
}
return cfg
}
func ensureIdentifierUnused(ctx context.Context, db *gorm.DB, identifier string) error {
var count int64
if err := db.WithContext(ctx).
Model(&modelUser.Device{}).
Where("identifier = ?", identifier).
Count(&count).Error; err != nil {
return err
}
if count > 0 {
return fmt.Errorf("identifier already exists: %s", identifier)
}
return nil
}
func cleanupTestRows(ctx context.Context, db *gorm.DB, userID int64) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("object_id = ?", userID).Delete(&modelLog.SystemLog{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Subscribe{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.AuthMethods{}).Error; err != nil {
return err
}
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Device{}).Error; err != nil {
return err
}
return tx.Where("id = ?", userID).Delete(&modelUser.User{}).Error
})
}
func must(err error) {
if err != nil {
fail(err)
}
}
func fail(err error) {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}

View File

@ -0,0 +1,485 @@
package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
)
const inviteGiftMarker = "codex-test-invite-gift-days"
type giftLog struct {
Type uint16 `json:"type"`
OrderNo string `json:"order_no"`
SubscribeId int64 `json:"subscribe_id"`
Amount int64 `json:"amount"`
Balance int64 `json:"balance"`
Remark string `json:"remark,omitempty"`
Timestamp int64 `json:"timestamp"`
}
type commissionLog struct {
Type uint16 `json:"type"`
Amount int64 `json:"amount"`
OrderNo string `json:"order_no"`
Timestamp int64 `json:"timestamp"`
}
type userSubscribe struct {
ID int64
UserID int64
ExpireTime time.Time
}
func main() {
var (
dsn = flag.String("dsn", "", "MySQL DSN, for example root:pass@tcp(host:3306)/ppanel?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai")
writeDB = flag.Bool("write-db", false, "create isolated rows, simulate invite gifts, and clean them up")
keep = flag.Bool("keep", false, "keep rows for manual inspection")
cleanupOnly = flag.Bool("cleanup-only", false, "delete leftover rows created by this script and exit")
giftDays = flag.Int("gift-days", 3, "days to add to both invite users")
commission = flag.Int64("commission-percent", 10, "commission percent for commission-path simulation")
)
flag.Parse()
if *dsn == "" {
exitf("-dsn is required")
}
ctx := context.Background()
db, err := sql.Open("mysql", *dsn)
mustNoErr(err)
defer db.Close()
db.SetMaxIdleConns(1)
db.SetMaxOpenConns(1)
mustNoErr(db.PingContext(ctx))
if *cleanupOnly {
mustNoErr(cleanup(ctx, db))
fmt.Println("cleanup done")
return
}
if !*writeDB {
fmt.Println("dry run only. Add -write-db to create isolated invite rows in the TEST database.")
return
}
if *giftDays <= 0 {
exitf("-gift-days must be positive")
}
mustNoErr(cleanup(ctx, db))
if !*keep {
defer func() {
if err := cleanup(context.Background(), db); err != nil {
fmt.Fprintf(os.Stderr, "cleanup failed: %v\n", err)
}
}()
}
planID := mustCreatePlan(ctx, db)
runSelfInviteScenario(ctx, db, planID, *giftDays)
runFamilyInviteScenario(ctx, db, planID, *giftDays)
runCommissionScenario(ctx, db, planID, *giftDays, *commission)
if *keep {
fmt.Println("rows kept; cleanup with -cleanup-only. inviteGiftMarker:", inviteGiftMarker)
}
}
func runSelfInviteScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int) {
refererID := mustCreateUser(ctx, db, "self-referer", 0)
refereeID := mustCreateUser(ctx, db, "self-referee", refererID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererSubID := mustCreateUserSubscribe(ctx, db, refererID, planID, baseExpire)
refereeSubID := mustCreateUserSubscribe(ctx, db, refereeID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-self-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererID, refereeID, 0, giftDays))
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererID, refereeID, 0, giftDays))
assertExpire(ctx, db, "referer", refererSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referee", refereeSubID, baseExpire, giftDays)
logs := mustGiftLogCount(ctx, db, orderNo)
if logs != 2 {
exitf("gift log count mismatch after duplicate simulation: got=%d want=2", logs)
}
fmt.Printf("PASS self invite: referer=%d referee=%d order=%s gift_days=%d logs=%d\n", refererID, refereeID, orderNo, giftDays, logs)
}
func runFamilyInviteScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int) {
refererOwnerID := mustCreateUser(ctx, db, "family-referer-owner", 0)
refererMemberID := mustCreateUser(ctx, db, "family-referer-member", 0)
refereeOwnerID := mustCreateUser(ctx, db, "family-referee-owner", 0)
refereeMemberID := mustCreateUser(ctx, db, "family-referee-member", refererMemberID)
mustCreateFamily(ctx, db, refererOwnerID, refererMemberID)
mustCreateFamily(ctx, db, refereeOwnerID, refereeMemberID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererOwnerSubID := mustCreateUserSubscribe(ctx, db, refererOwnerID, planID, baseExpire)
refereeOwnerSubID := mustCreateUserSubscribe(ctx, db, refereeOwnerID, planID, baseExpire)
refererMemberSubID := mustCreateUserSubscribe(ctx, db, refererMemberID, planID, baseExpire)
refereeMemberSubID := mustCreateUserSubscribe(ctx, db, refereeMemberID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-family-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererMemberID, refereeMemberID, refereeOwnerID, giftDays))
mustNoErr(simulateInviteGiftBoth(ctx, db, orderNo, refererMemberID, refereeMemberID, refereeOwnerID, giftDays))
assertExpire(ctx, db, "referer owner", refererOwnerSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referee owner", refereeOwnerSubID, baseExpire, giftDays)
assertExpire(ctx, db, "referer member", refererMemberSubID, baseExpire, 0)
assertExpire(ctx, db, "referee member", refereeMemberSubID, baseExpire, 0)
logs := mustGiftLogCount(ctx, db, orderNo)
if logs != 2 {
exitf("family gift log count mismatch after duplicate simulation: got=%d want=2", logs)
}
fmt.Printf("PASS family invite: referer_member=%d->owner=%d referee_member=%d->owner=%d order=%s gift_days=%d logs=%d\n",
refererMemberID, refererOwnerID, refereeMemberID, refereeOwnerID, orderNo, giftDays, logs)
}
func runCommissionScenario(ctx context.Context, db *sql.DB, planID int64, giftDays int, commissionPercent int64) {
if commissionPercent <= 0 {
fmt.Println("SKIP commission invite: commission-percent <= 0")
return
}
const amount int64 = 599
refererID := mustCreateUser(ctx, db, "commission-referer", 0)
refereeID := mustCreateUser(ctx, db, "commission-referee", refererID)
baseExpire := time.Now().Add(10 * 24 * time.Hour).Truncate(time.Second)
refererSubID := mustCreateUserSubscribe(ctx, db, refererID, planID, baseExpire)
refereeSubID := mustCreateUserSubscribe(ctx, db, refereeID, planID, baseExpire)
orderNo := fmt.Sprintf("%s-commission-first-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteCommission(ctx, db, orderNo, refererID, refereeID, 0, giftDays, amount, commissionPercent, true))
mustNoErr(simulateInviteCommission(ctx, db, orderNo, refererID, refereeID, 0, giftDays, amount, commissionPercent, true))
wantCommission := amount * commissionPercent / 100
assertExpire(ctx, db, "commission referer", refererSubID, baseExpire, 0)
assertExpire(ctx, db, "commission referee", refereeSubID, baseExpire, giftDays)
assertCommission(ctx, db, refererID, wantCommission)
assertLogCount(ctx, db, "commission first gift", 34, orderNo, 1)
assertLogCount(ctx, db, "commission first commission", 33, orderNo, 1)
nonFirstRefererID := mustCreateUser(ctx, db, "commission-nonfirst-referer", 0)
nonFirstRefereeID := mustCreateUser(ctx, db, "commission-nonfirst-referee", nonFirstRefererID)
nonFirstRefererSubID := mustCreateUserSubscribe(ctx, db, nonFirstRefererID, planID, baseExpire)
nonFirstRefereeSubID := mustCreateUserSubscribe(ctx, db, nonFirstRefereeID, planID, baseExpire)
nonFirstOrderNo := fmt.Sprintf("%s-commission-nonfirst-order-%d", inviteGiftMarker, time.Now().UnixNano())
mustNoErr(simulateInviteCommission(ctx, db, nonFirstOrderNo, nonFirstRefererID, nonFirstRefereeID, 0, giftDays, amount, commissionPercent, false))
mustNoErr(simulateInviteCommission(ctx, db, nonFirstOrderNo, nonFirstRefererID, nonFirstRefereeID, 0, giftDays, amount, commissionPercent, false))
assertExpire(ctx, db, "commission non-first referer", nonFirstRefererSubID, baseExpire, 0)
assertExpire(ctx, db, "commission non-first referee", nonFirstRefereeSubID, baseExpire, 0)
assertCommission(ctx, db, nonFirstRefererID, wantCommission)
assertLogCount(ctx, db, "commission non-first gift", 34, nonFirstOrderNo, 0)
assertLogCount(ctx, db, "commission non-first commission", 33, nonFirstOrderNo, 1)
fmt.Printf("PASS commission invite: percent=%d first_order_commission=%d non_first_commission=%d\n",
commissionPercent, wantCommission, wantCommission)
}
func assertExpire(ctx context.Context, db *sql.DB, label string, subID int64, before time.Time, addedDays int) {
got := mustExpire(ctx, db, subID)
want := before.Add(time.Duration(addedDays) * 24 * time.Hour)
if !got.Equal(want) {
exitf("%s expire mismatch: got=%s want=%s", label, got, want)
}
fmt.Printf("PASS %s subscribe=%d expire %s -> %s\n", label, subID, before.Format(time.RFC3339), got.Format(time.RFC3339))
}
func simulateInviteGiftBoth(ctx context.Context, db *sql.DB, orderNo string, refererID, refereeID, forcedRefereeOwnerID int64, days int) error {
refereeTargetID, err := resolveGiftTargetUser(ctx, db, refereeID, forcedRefereeOwnerID)
if err != nil {
return fmt.Errorf("resolve referee gift target: %w", err)
}
refererTargetID, err := resolveGiftTargetUser(ctx, db, refererID, 0)
if err != nil {
return fmt.Errorf("resolve referer gift target: %w", err)
}
if err := grantGiftDays(ctx, db, refereeTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant referee gift: %w", err)
}
if err := grantGiftDays(ctx, db, refererTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant referer gift: %w", err)
}
return nil
}
func simulateInviteCommission(ctx context.Context, db *sql.DB, orderNo string, refererID, refereeID, forcedRefereeOwnerID int64, days int, amount int64, commissionPercent int64, isFirstOrder bool) error {
if err := grantCommission(ctx, db, refererID, orderNo, amount, commissionPercent); err != nil {
return fmt.Errorf("grant commission: %w", err)
}
if isFirstOrder {
refereeTargetID, err := resolveGiftTargetUser(ctx, db, refereeID, forcedRefereeOwnerID)
if err != nil {
return fmt.Errorf("resolve referee gift target: %w", err)
}
if err := grantGiftDays(ctx, db, refereeTargetID, orderNo, days); err != nil {
return fmt.Errorf("grant commission-path referee gift: %w", err)
}
}
return nil
}
func resolveGiftTargetUser(ctx context.Context, db *sql.DB, userID int64, forcedOwnerID int64) (int64, error) {
if forcedOwnerID > 0 {
return forcedOwnerID, nil
}
var ownerID int64
err := db.QueryRowContext(ctx, `
SELECT uf.owner_user_id
FROM user_family_member ufm
JOIN user_family uf ON uf.id = ufm.family_id AND uf.deleted_at IS NULL
WHERE ufm.user_id = ?
AND ufm.deleted_at IS NULL
AND ufm.status = 1
AND ufm.role = 2
AND uf.status = 1
ORDER BY ufm.role
LIMIT 1`, userID).Scan(&ownerID)
if err == sql.ErrNoRows {
return userID, nil
}
if err != nil {
return 0, err
}
if ownerID > 0 && ownerID != userID {
return ownerID, nil
}
return userID, nil
}
func grantCommission(ctx context.Context, db *sql.DB, refererID int64, orderNo string, amount int64, commissionPercent int64) error {
var existing int64
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM system_logs WHERE type = 33 AND object_id = ? AND content LIKE ?",
refererID, "%\""+orderNo+"\"%",
).Scan(&existing)
if err != nil {
return err
}
if existing > 0 {
return nil
}
commissionAmount := amount * commissionPercent / 100
if _, err = db.ExecContext(ctx,
"UPDATE `user` SET commission = commission + ?, updated_at = ? WHERE id = ?",
commissionAmount, time.Now(), refererID,
); err != nil {
return err
}
content, err := json.Marshal(commissionLog{
Type: 331,
Amount: commissionAmount,
OrderNo: orderNo,
Timestamp: time.Now().UnixMilli(),
})
if err != nil {
return err
}
_, err = db.ExecContext(ctx,
"INSERT INTO system_logs (`type`, object_id, content, created_at, `date`) VALUES (33, ?, ?, ?, ?)",
refererID, string(content), time.Now(), time.Now().Format("2006-01-02"),
)
return err
}
func grantGiftDays(ctx context.Context, db *sql.DB, userID int64, orderNo string, days int) error {
var existing int64
err := db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM system_logs WHERE type = 34 AND object_id = ? AND content LIKE ?",
userID, "%\""+orderNo+"\"%",
).Scan(&existing)
if err != nil {
return err
}
if existing > 0 {
return nil
}
sub, err := findActiveSubscribe(ctx, db, userID)
if err != nil {
return err
}
nextExpire := sub.ExpireTime
if !sub.ExpireTime.Equal(time.UnixMilli(0)) {
nextExpire = sub.ExpireTime.Add(time.Duration(days) * 24 * time.Hour)
if _, err = db.ExecContext(ctx,
"UPDATE user_subscribe SET expire_time = ?, updated_at = ? WHERE id = ?",
nextExpire, time.Now(), sub.ID,
); err != nil {
return err
}
}
content, err := json.Marshal(giftLog{
Type: 341,
OrderNo: orderNo,
SubscribeId: sub.ID,
Amount: int64(days),
Balance: 0,
Remark: "邀请赠送",
Timestamp: time.Now().UnixMilli(),
})
if err != nil {
return err
}
_, err = db.ExecContext(ctx,
"INSERT INTO system_logs (`type`, object_id, content, created_at, `date`) VALUES (34, ?, ?, ?, ?)",
userID, string(content), time.Now(), time.Now().Format("2006-01-02"),
)
return err
}
func findActiveSubscribe(ctx context.Context, db *sql.DB, userID int64) (*userSubscribe, error) {
var row userSubscribe
err := db.QueryRowContext(ctx, `
SELECT id, user_id, expire_time
FROM user_subscribe
WHERE user_id = ?
AND status IN (0, 1)
AND (expire_time > ? OR expire_time = '1970-01-01 08:00:00')
ORDER BY expire_time DESC, id DESC
LIMIT 1`, userID, time.Now()).Scan(&row.ID, &row.UserID, &row.ExpireTime)
if err != nil {
return nil, err
}
return &row, nil
}
func mustCreatePlan(ctx context.Context, db *sql.DB) int64 {
var sort int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COALESCE(MAX(sort), 0) + 1 FROM subscribe").Scan(&sort))
res, err := db.ExecContext(ctx, `
INSERT INTO subscribe
(name, language, description, unit_price, unit_time, discount, replacement, inventory, traffic, speed_limit, device_limit, quota, new_user_only, nodes, node_tags, node_group_ids, node_group_id, traffic_limit, `+"`show`"+`, sell, sort, deduction_ratio, allow_deduction, reset_cycle, renewal_reset, show_original_price, created_at, updated_at)
VALUES (?, 'en', '', 599, 'Month', '', 0, -1, 1073741824, 0, 0, 0, false, '', '', '[]', 0, '', false, false, ?, 0, true, 0, false, true, ?, ?)`,
inviteGiftMarker+"-plan", sort, time.Now(), time.Now())
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
return id
}
func mustCreateUser(ctx context.Context, db *sql.DB, role string, refererID int64) int64 {
res, err := db.ExecContext(ctx, `
INSERT INTO `+"`user`"+`
(password, algo, avatar, balance, refer_code, referer_id, commission, referral_percentage, only_first_purchase, gift_amount, enable, is_admin, enable_balance_notify, enable_login_notify, enable_subscribe_notify, enable_trade_notify, rules, member_status, remark, created_at, updated_at, salt)
VALUES (?, 'default', '', 0, '', ?, 0, 0, true, 0, true, false, true, true, true, true, '', '', ?, ?, ?, 'default')`,
inviteGiftMarker, refererID, inviteGiftMarker+"-"+role, time.Now(), time.Now())
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
_, err = db.ExecContext(ctx, "UPDATE `user` SET refer_code = ?, updated_at = ? WHERE id = ?", fmt.Sprintf("codex%d", id), time.Now(), id)
mustNoErr(err)
return id
}
func mustCreateFamily(ctx context.Context, db *sql.DB, ownerID, memberID int64) int64 {
res, err := db.ExecContext(ctx, `
INSERT INTO user_family
(owner_user_id, max_members, status, created_at, updated_at)
VALUES (?, 3, 1, ?, ?)`, ownerID, time.Now(), time.Now())
mustNoErr(err)
familyID, err := res.LastInsertId()
mustNoErr(err)
now := time.Now()
_, err = db.ExecContext(ctx, `
INSERT INTO user_family_member
(family_id, user_id, role, status, join_source, joined_at, created_at, updated_at)
VALUES
(?, ?, 1, 1, ?, ?, ?, ?),
(?, ?, 2, 1, ?, ?, ?, ?)`,
familyID, ownerID, inviteGiftMarker, now, now, now,
familyID, memberID, inviteGiftMarker, now, now, now)
mustNoErr(err)
return familyID
}
func mustCreateUserSubscribe(ctx context.Context, db *sql.DB, userID, planID int64, expire time.Time) int64 {
token := fmt.Sprintf("%s-token-%d-%d", inviteGiftMarker, userID, time.Now().UnixNano())
uuid := fmt.Sprintf("%08d-0000-4000-8000-%012d", userID, time.Now().UnixNano()%1_000_000_000_000)
res, err := db.ExecContext(ctx, `
INSERT INTO user_subscribe
(user_id, order_id, subscribe_id, node_group_id, group_locked, traffic, download, upload, expired_download, expired_upload, token, uuid, status, note, created_at, updated_at, start_time, expire_time)
VALUES (?, 0, ?, 0, false, 1073741824, 0, 0, 0, 0, ?, ?, 1, ?, ?, ?, ?, ?)`,
userID, planID, token, uuid, inviteGiftMarker, time.Now(), time.Now(), time.Now().Add(-time.Hour), expire)
mustNoErr(err)
id, err := res.LastInsertId()
mustNoErr(err)
return id
}
func mustExpire(ctx context.Context, db *sql.DB, subID int64) time.Time {
var expire time.Time
mustNoErr(db.QueryRowContext(ctx, "SELECT expire_time FROM user_subscribe WHERE id = ?", subID).Scan(&expire))
return expire
}
func mustGiftLogCount(ctx context.Context, db *sql.DB, orderNo string) int64 {
var count int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_logs WHERE type = 34 AND content LIKE ?", "%"+orderNo+"%").Scan(&count))
return count
}
func assertCommission(ctx context.Context, db *sql.DB, userID int64, want int64) {
var got int64
mustNoErr(db.QueryRowContext(ctx, "SELECT commission FROM `user` WHERE id = ?", userID).Scan(&got))
if got != want {
exitf("commission mismatch: user=%d got=%d want=%d", userID, got, want)
}
fmt.Printf("PASS commission user=%d amount=%d\n", userID, got)
}
func assertLogCount(ctx context.Context, db *sql.DB, label string, logType uint8, orderNo string, want int64) {
var got int64
mustNoErr(db.QueryRowContext(ctx, "SELECT COUNT(*) FROM system_logs WHERE type = ? AND content LIKE ?", logType, "%"+orderNo+"%").Scan(&got))
if got != want {
exitf("%s log count mismatch: got=%d want=%d", label, got, want)
}
fmt.Printf("PASS %s logs=%d\n", label, got)
}
func cleanup(ctx context.Context, db *sql.DB) error {
stmts := []string{
"DELETE FROM user_family_member WHERE join_source = '" + inviteGiftMarker + "'",
"DELETE FROM user_family WHERE owner_user_id IN (SELECT id FROM `user` WHERE remark LIKE '" + inviteGiftMarker + "%')",
"DELETE FROM system_logs WHERE type IN (33, 34) AND content LIKE '%" + inviteGiftMarker + "%'",
"DELETE FROM user_subscribe WHERE note = '" + inviteGiftMarker + "' OR token LIKE '" + inviteGiftMarker + "%'",
"DELETE FROM subscribe WHERE name LIKE '" + inviteGiftMarker + "%'",
"DELETE FROM `user` WHERE remark LIKE '" + inviteGiftMarker + "%'",
}
for _, stmt := range stmts {
if _, err := db.ExecContext(ctx, stmt); err != nil {
return fmt.Errorf("%s: %w", stmt, err)
}
}
return nil
}
func mustNoErr(err error) {
if err != nil {
exitf("%v", err)
}
}
func exitf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
fmt.Fprintln(os.Stderr, "FAIL:", strings.TrimSpace(msg))
os.Exit(1)
}

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@
| 2026-03-12 | 分析并确认 Unknown column 错误 | [x] 已完成 | 确认为 `user_device` 缺少 `short_code` 字段,已提供 SQL | | 2026-03-12 | 分析并确认 Unknown column 错误 | [x] 已完成 | 确认为 `user_device` 缺少 `short_code` 字段,已提供 SQL |
| 2026-03-12 | 提供 SSL 证书替换指令 | [x] 已完成 | 已提供备份与替换证书的组合指令 | | 2026-03-12 | 提供 SSL 证书替换指令 | [x] 已完成 | 已提供备份与替换证书的组合指令 |
| 2026-03-17 | 合并 internal 到 internal/main | [x] 已完成 | 已查验均为fast-forward受限网络/权限需手动push完成合并 | | 2026-03-17 | 合并 internal 到 internal/main | [x] 已完成 | 已查验均为fast-forward受限网络/权限需手动push完成合并 |
| 2026-04-14 | 排查支付成功但订阅未下发问题 | [x] 已完成 | 已提供 Docker 相关的日志排查与数据库核对命令 |
certbot certonly --manual --preferred-challenges dns -d airoport.win -d "*.airoport.win" -d hifastapp.com certbot certonly --manual --preferred-challenges dns -d airoport.win -d "*.airoport.win" -d hifastapp.com