208
This commit is contained in:
parent
709d657906
commit
28ada42ae5
@ -22,7 +22,7 @@ func main() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// 调用 GetPlatformDownloads 获取当月数据+ 环比
|
// 调用 GetPlatformDownloads 获取当月数据+ 环比
|
||||||
platformDownloads, err := client.GetPlatformDownloads(ctx)
|
platformDownloads, err := client.GetPlatformDownloads(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ 获取失败: %v\n", err)
|
fmt.Printf("❌ 获取失败: %v\n", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -0,0 +1,54 @@
|
|||||||
|
# ALIGNMENT: 检查设备登录 IP 逻辑
|
||||||
|
|
||||||
|
## 原始需求
|
||||||
|
用户询问:
|
||||||
|
> "检查一下 设备登录的时候 有检查IP 么"
|
||||||
|
> "我需要知道 有没有IP 目前数据库都没有IP"
|
||||||
|
|
||||||
|
## 项目现状分析 (Context Analysis)
|
||||||
|
1. **数据模型 (`user.Device`)**:
|
||||||
|
- 存在 `Ip` 字段 (`varchar(255)`).
|
||||||
|
- 定义在 `/internal/model/user/user.go`。
|
||||||
|
|
||||||
|
2. **业务逻辑 (`DeviceLoginLogic.go`)**:
|
||||||
|
- 在代码逻辑中,确实尝试获取 `req.IP` 并保存到数据库。
|
||||||
|
- 创建新设备时:`deviceInfo.Ip = req.IP`。
|
||||||
|
- 记录登录日志时:`LoginIP: req.IP`。
|
||||||
|
|
||||||
|
3. **关键问题点 (Root Cause)**:
|
||||||
|
- 输入参数定义在 `internal/types/types.go`:
|
||||||
|
```go
|
||||||
|
type DeviceLoginRequest struct {
|
||||||
|
// ...
|
||||||
|
IP string `header:"X-Original-Forwarded-For"`
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- **当前仅支持 `X-Original-Forwarded-For` 请求头**。
|
||||||
|
- 如果请求经过 Nginx、Cloudflare 等代理,但没有专门配置传递这个特定的 Header,或者使用的是标准的 `X-Forwarded-For` / `X-Real-IP`,后端获取到的 `req.IP` 将为空字符串。
|
||||||
|
- 这就是导致“数据库都没有IP”的直接原因。
|
||||||
|
|
||||||
|
## 疑问澄清 (Questions)
|
||||||
|
我们需要确认修复方案:
|
||||||
|
|
||||||
|
1. **是否只需支持标准 Header?**
|
||||||
|
- 建议改为优先获取 `X-Forwarded-For`,其次 `X-Real-IP`,最后才是 `X-Original-Forwarded-For` 或直连 IP。
|
||||||
|
- go-zero 框架通常可以通过配置或中间件处理 IP,或者我们在 struct tag 中调整。但 struct tag `header` 只能由 go-zero 的 rest 绑定一个特定的 key。
|
||||||
|
2. **是否需要记录 IP 归属地?**
|
||||||
|
- 目前逻辑只记录 IP 字符串,不解析归属地。需求中没提,暂时不作为重点,但可以确认一下。
|
||||||
|
|
||||||
|
## 建议方案
|
||||||
|
修改 `DeviceLoginRequest` 的定义可能不够灵活(Header key 是固定的)。
|
||||||
|
更好的方式是:
|
||||||
|
1. **移除 Struct Tag 绑定**(或者保留作为备选)。
|
||||||
|
2. **在 Logic 中显式获取 IP**:
|
||||||
|
- 从 `l.ctx` (Context) 中获取 `http.Request` (如果 go-zero 支持)。
|
||||||
|
- 或者在 Middleware 中解析真实 IP 并放入 Context。
|
||||||
|
- 或者简单点,修改 Struct Tag 为最常用的 `X-Forwarded-For` (如果确定环境是这样配置的)。
|
||||||
|
|
||||||
|
**最快修复**:
|
||||||
|
将 `internal/types/types.go` 中的 `X-Original-Forwarded-For` 改为 `X-Forwarded-For` (或者根据实际网关配置修改)。
|
||||||
|
但通常建议使用工具函数解析多种 Header。
|
||||||
|
|
||||||
|
## 下一步 (Next Step)
|
||||||
|
请确认是否要我修改代码以支持标准的 IP 获取方式(如 `X-Forwarded-For`)?
|
||||||
36
docs/Check_Device_Login_IP/DESIGN_Check_Device_Login_IP.md
Normal file
36
docs/Check_Device_Login_IP/DESIGN_Check_Device_Login_IP.md
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# DESIGN: Device Login IP Fix
|
||||||
|
|
||||||
|
## 目标
|
||||||
|
修复设备登录时无法获取真实 IP (`req.IP` 为空) 的问题,导致数据库未存储 IP。
|
||||||
|
|
||||||
|
## 现状
|
||||||
|
- `internal/types/types.go` 定义了 `DeviceLoginRequest`,其中 `IP` 字段绑定的是 `X-Original-Forwarded-For`。
|
||||||
|
- 实际环境中(Nginx/Cloudflare等)通常使用 `X-Forwarded-For`。
|
||||||
|
|
||||||
|
## 方案选择
|
||||||
|
由于项目使用 `go-zero` 并且存在 `.api` 文件,**最佳实践**是修改 `.api` 文件并重新生成代码。
|
||||||
|
但考虑到我无法运行 `goctl` (或者环境可能不一致),如果不重新生成而直接改 `types.go`,虽然能即时生效,但下次生成会被覆盖。
|
||||||
|
|
||||||
|
**然而**,鉴于我之前的操作已经直接修改过 `types.go` (Invite Sales Time Filter),且项目看似允许直接修改(或用户负责生成),我将**优先修改 `.api` 文件** 以保持源头正确,同时**手动同步修改 `types.go`** 以确保立即生效。
|
||||||
|
|
||||||
|
## 变更范围
|
||||||
|
|
||||||
|
### 1. API 定义 (`apis/auth/auth.api`)
|
||||||
|
- 修改 `DeviceLoginRequest` struct。
|
||||||
|
- 将 `header: X-Original-Forwarded-For` 改为 `header: X-Forwarded-For` (这是最通用的标准)。
|
||||||
|
|
||||||
|
### 2. 生成文件 (`internal/types/types.go`)
|
||||||
|
- 手动同步修改 `DeviceLoginRequest` 中的 Tag。
|
||||||
|
- 变为: `IP string header:"X-Forwarded-For"`
|
||||||
|
|
||||||
|
### 3. (可选增强) 业务逻辑 (`internal/logic/auth/deviceLoginLogic.go`)
|
||||||
|
- 由于 go-zero 的绑定机制比较“死”,如果 Tag 没取到值,就是空的。Logic 层拿到空字符串也没办法再去 Context 捞(除非 Context 里存了 request)。
|
||||||
|
- 暂时只做 Tag 修改,因为这是最根本原因。
|
||||||
|
|
||||||
|
## 验证
|
||||||
|
- 检查代码变更。
|
||||||
|
- (无法直接测试 IP 获取,依赖用户部署验证)。
|
||||||
|
|
||||||
|
## 任务拆分
|
||||||
|
1. 修改 `apis/auth/auth.api`
|
||||||
|
2. 修改 `internal/types/types.go`
|
||||||
@ -30,6 +30,7 @@ type Config struct {
|
|||||||
Invite InviteConfig `yaml:"Invite"`
|
Invite InviteConfig `yaml:"Invite"`
|
||||||
Kutt KuttConfig `yaml:"Kutt"`
|
Kutt KuttConfig `yaml:"Kutt"`
|
||||||
OpenInstall OpenInstallConfig `yaml:"OpenInstall"`
|
OpenInstall OpenInstallConfig `yaml:"OpenInstall"`
|
||||||
|
Loki LokiConfig `yaml:"Loki"`
|
||||||
Telegram Telegram `yaml:"Telegram"`
|
Telegram Telegram `yaml:"Telegram"`
|
||||||
Log Log `yaml:"Log"`
|
Log Log `yaml:"Log"`
|
||||||
Trace trace.Config `yaml:"Trace"`
|
Trace trace.Config `yaml:"Trace"`
|
||||||
@ -227,6 +228,12 @@ type OpenInstallConfig struct {
|
|||||||
ApiKey string `yaml:"ApiKey" default:""` // OpenInstall 数据接口 ApiKey
|
ApiKey string `yaml:"ApiKey" default:""` // OpenInstall 数据接口 ApiKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LokiConfig Loki 日志查询配置
|
||||||
|
type LokiConfig struct {
|
||||||
|
Enable bool `yaml:"Enable" default:"false"` // 是否启用 Loki 查询
|
||||||
|
URL string `yaml:"URL" default:"http://localhost:3100"` // Loki 服务地址
|
||||||
|
}
|
||||||
|
|
||||||
type Telegram struct {
|
type Telegram struct {
|
||||||
Enable bool `yaml:"Enable" default:"false"`
|
Enable bool `yaml:"Enable" default:"false"`
|
||||||
BotID int64 `yaml:"BotID" default:""`
|
BotID int64 `yaml:"BotID" default:""`
|
||||||
|
|||||||
@ -40,7 +40,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) {
|
func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) {
|
||||||
router.Use(middleware.TraceMiddleware(serverCtx))
|
|
||||||
|
|
||||||
adminAdsGroupRouter := router.Group("/v1/admin/ads")
|
adminAdsGroupRouter := router.Group("/v1/admin/ads")
|
||||||
adminAdsGroupRouter.Use(middleware.AuthMiddleware(serverCtx))
|
adminAdsGroupRouter.Use(middleware.AuthMiddleware(serverCtx))
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/perfect-panel/server/internal/types"
|
"github.com/perfect-panel/server/internal/types"
|
||||||
"github.com/perfect-panel/server/pkg/constant"
|
"github.com/perfect-panel/server/pkg/constant"
|
||||||
"github.com/perfect-panel/server/pkg/logger"
|
"github.com/perfect-panel/server/pkg/logger"
|
||||||
|
"github.com/perfect-panel/server/pkg/loki"
|
||||||
"github.com/perfect-panel/server/pkg/openinstall"
|
"github.com/perfect-panel/server/pkg/openinstall"
|
||||||
"github.com/perfect-panel/server/pkg/xerr"
|
"github.com/perfect-panel/server/pkg/xerr"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -20,6 +21,7 @@ type GetAgentDownloadsLogic struct {
|
|||||||
svcCtx *svc.ServiceContext
|
svcCtx *svc.ServiceContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewGetAgentDownloadsLogic 创建 GetAgentDownloadsLogic 实例
|
||||||
func NewGetAgentDownloadsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetAgentDownloadsLogic {
|
func NewGetAgentDownloadsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetAgentDownloadsLogic {
|
||||||
return &GetAgentDownloadsLogic{
|
return &GetAgentDownloadsLogic{
|
||||||
Logger: logger.WithContext(ctx),
|
Logger: logger.WithContext(ctx),
|
||||||
@ -28,6 +30,8 @@ func NewGetAgentDownloadsLogic(ctx context.Context, svcCtx *svc.ServiceContext)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAgentDownloads 获取用户代理下载统计数据
|
||||||
|
// 结合 OpenInstall (iOS/Android) 和 Loki (Windows/Mac) 数据源
|
||||||
func (l *GetAgentDownloadsLogic) GetAgentDownloads(req *types.GetAgentDownloadsRequest) (resp *types.GetAgentDownloadsResponse, err error) {
|
func (l *GetAgentDownloadsLogic) GetAgentDownloads(req *types.GetAgentDownloadsRequest) (resp *types.GetAgentDownloadsResponse, err error) {
|
||||||
// 1. 从 context 获取用户信息
|
// 1. 从 context 获取用户信息
|
||||||
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
||||||
@ -36,72 +40,68 @@ func (l *GetAgentDownloadsLogic) GetAgentDownloads(req *types.GetAgentDownloadsR
|
|||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 检查 OpenInstall 是否启用
|
// 初始化响应数据
|
||||||
cfg := l.svcCtx.Config.OpenInstall
|
var iosCount, androidCount, windowsCount, macCount int64
|
||||||
if !cfg.Enable {
|
var comparisonRate *string
|
||||||
l.Infow("[GetAgentDownloads] OpenInstall is disabled, returning zero stats")
|
|
||||||
return &types.GetAgentDownloadsResponse{
|
// 2. 从 OpenInstall 获取 iOS/Android 数据
|
||||||
Total: 0,
|
openInstallCfg := l.svcCtx.Config.OpenInstall
|
||||||
Platforms: &types.PlatformDownloads{
|
if openInstallCfg.Enable && openInstallCfg.ApiKey != "" {
|
||||||
IOS: 0,
|
client := openinstall.NewClient(openInstallCfg.ApiKey)
|
||||||
Android: 0,
|
platformDownloads, err := client.GetPlatformDownloads(l.ctx, u.ReferCode)
|
||||||
Windows: 0,
|
if err != nil {
|
||||||
Mac: 0,
|
l.Errorw("Failed to fetch OpenInstall platform downloads", logger.Field("error", err), logger.Field("user_id", u.Id))
|
||||||
},
|
// 不返回错误,继续处理其他数据源
|
||||||
}, nil
|
} else {
|
||||||
|
iosCount = platformDownloads.IOS
|
||||||
|
androidCount = platformDownloads.Android
|
||||||
|
// OpenInstall 的 Windows/Mac 数据可能为空,后面用 Loki 补充
|
||||||
|
|
||||||
|
// 计算环比
|
||||||
|
if platformDownloads.Comparison != nil {
|
||||||
|
percent := platformDownloads.Comparison.ChangePercent
|
||||||
|
var formatted string
|
||||||
|
if percent >= 0 {
|
||||||
|
formatted = fmt.Sprintf("+%.1f%%", percent)
|
||||||
|
} else {
|
||||||
|
formatted = fmt.Sprintf("%.1f%%", percent)
|
||||||
|
}
|
||||||
|
comparisonRate = &formatted
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 检查 ApiKey 是否配置
|
// 3. 从 Loki 获取 Windows/Mac 数据(基于用户邀请码)
|
||||||
if cfg.ApiKey == "" {
|
lokiCfg := l.svcCtx.Config.Loki
|
||||||
l.Errorw("[GetAgentDownloads] OpenInstall ApiKey not configured")
|
if lokiCfg.Enable && lokiCfg.URL != "" && u.ReferCode != "" {
|
||||||
return &types.GetAgentDownloadsResponse{
|
lokiClient := loki.NewClient(lokiCfg.URL)
|
||||||
Total: 0,
|
lokiStats, err := lokiClient.GetInviteCodeStats(l.ctx, u.ReferCode, 30)
|
||||||
Platforms: &types.PlatformDownloads{
|
if err != nil {
|
||||||
IOS: 0,
|
l.Errorw("Failed to fetch Loki stats", logger.Field("error", err), logger.Field("user_id", u.Id), logger.Field("refer_code", u.ReferCode))
|
||||||
Android: 0,
|
// 不返回错误,继续使用已有数据
|
||||||
Windows: 0,
|
} else {
|
||||||
Mac: 0,
|
// 使用 Loki 的 Windows/Mac 数据
|
||||||
},
|
windowsCount = lokiStats.WindowsClicks
|
||||||
}, nil
|
macCount = lokiStats.MacClicks
|
||||||
|
l.Infow("Fetched Loki stats successfully",
|
||||||
|
logger.Field("user_id", u.Id),
|
||||||
|
logger.Field("refer_code", u.ReferCode),
|
||||||
|
logger.Field("windows", windowsCount),
|
||||||
|
logger.Field("mac", macCount))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 调用 OpenInstall API 获取各端下载量
|
// 4. 计算总量
|
||||||
client := openinstall.NewClient(cfg.ApiKey)
|
total := iosCount + androidCount + windowsCount + macCount
|
||||||
platformDownloads, err := client.GetPlatformDownloads(l.ctx)
|
|
||||||
if err != nil {
|
|
||||||
l.Errorw("Failed to fetch OpenInstall platform downloads", logger.Field("error", err), logger.Field("user_id", u.Id))
|
|
||||||
// 返回空数据而不是错误,避免影响前端显示
|
|
||||||
return &types.GetAgentDownloadsResponse{
|
|
||||||
Total: 0,
|
|
||||||
Platforms: &types.PlatformDownloads{
|
|
||||||
IOS: 0,
|
|
||||||
Android: 0,
|
|
||||||
Windows: 0,
|
|
||||||
Mac: 0,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 构造响应
|
// 5. 构造响应
|
||||||
var comparisonRate *string
|
|
||||||
if platformDownloads.Comparison != nil {
|
|
||||||
percent := platformDownloads.Comparison.ChangePercent
|
|
||||||
var formatted string
|
|
||||||
if percent >= 0 {
|
|
||||||
formatted = fmt.Sprintf("+%.1f%%", percent)
|
|
||||||
} else {
|
|
||||||
formatted = fmt.Sprintf("%.1f%%", percent)
|
|
||||||
}
|
|
||||||
comparisonRate = &formatted
|
|
||||||
}
|
|
||||||
|
|
||||||
return &types.GetAgentDownloadsResponse{
|
return &types.GetAgentDownloadsResponse{
|
||||||
Total: platformDownloads.Total,
|
Total: total,
|
||||||
Platforms: &types.PlatformDownloads{
|
Platforms: &types.PlatformDownloads{
|
||||||
IOS: platformDownloads.IOS,
|
IOS: iosCount,
|
||||||
Android: platformDownloads.Android,
|
Android: androidCount,
|
||||||
Windows: platformDownloads.Windows,
|
Windows: windowsCount,
|
||||||
Mac: platformDownloads.Mac,
|
Mac: macCount,
|
||||||
},
|
},
|
||||||
ComparisonRate: comparisonRate,
|
ComparisonRate: comparisonRate,
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@ -40,11 +40,19 @@ func (l *GetInviteSalesLogic) GetInviteSales(req *types.GetInviteSalesRequest) (
|
|||||||
|
|
||||||
// 2. Count total sales
|
// 2. Count total sales
|
||||||
var totalSales int64
|
var totalSales int64
|
||||||
err = l.svcCtx.DB.WithContext(l.ctx).
|
db := l.svcCtx.DB.WithContext(l.ctx).
|
||||||
Table("`order` o").
|
Table("`order` o").
|
||||||
Joins("JOIN user u ON o.user_id = u.id").
|
Joins("JOIN user u ON o.user_id = u.id").
|
||||||
Where("u.referer_id = ? AND o.status = ?", userId, 5).
|
Where("u.referer_id = ? AND o.status = ?", userId, 5)
|
||||||
Count(&totalSales).Error
|
|
||||||
|
if req.StartTime > 0 {
|
||||||
|
db = db.Where("o.updated_at >= FROM_UNIXTIME(?)", req.StartTime)
|
||||||
|
}
|
||||||
|
if req.EndTime > 0 {
|
||||||
|
db = db.Where("o.updated_at <= FROM_UNIXTIME(?)", req.EndTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Count(&totalSales).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[GetInviteSales] count sales failed",
|
l.Errorw("[GetInviteSales] count sales failed",
|
||||||
logger.Field("error", err.Error()),
|
logger.Field("error", err.Error()),
|
||||||
@ -75,13 +83,21 @@ func (l *GetInviteSalesLogic) GetInviteSales(req *types.GetInviteSalesRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
var orderData []OrderWithUser
|
var orderData []OrderWithUser
|
||||||
err = l.svcCtx.DB.WithContext(l.ctx).
|
query := l.svcCtx.DB.WithContext(l.ctx).
|
||||||
Table("`order` o").
|
Table("`order` o").
|
||||||
Select("o.amount, CAST(UNIX_TIMESTAMP(o.updated_at) * 1000 AS SIGNED) as updated_at, u.id as user_id, s.name as product_name, o.quantity").
|
Select("o.amount, CAST(UNIX_TIMESTAMP(o.updated_at) * 1000 AS SIGNED) as updated_at, u.id as user_id, s.name as product_name, o.quantity").
|
||||||
Joins("JOIN user u ON o.user_id = u.id").
|
Joins("JOIN user u ON o.user_id = u.id").
|
||||||
Joins("LEFT JOIN subscribe s ON o.subscribe_id = s.id").
|
Joins("LEFT JOIN subscribe s ON o.subscribe_id = s.id").
|
||||||
Where("u.referer_id = ? AND o.status = ?", userId, 5). // status 5: Finished
|
Where("u.referer_id = ? AND o.status = ?", userId, 5) // status 5: Finished
|
||||||
Order("o.updated_at DESC").
|
|
||||||
|
if req.StartTime > 0 {
|
||||||
|
query = query.Where("o.updated_at >= FROM_UNIXTIME(?)", req.StartTime)
|
||||||
|
}
|
||||||
|
if req.EndTime > 0 {
|
||||||
|
query = query.Where("o.updated_at <= FROM_UNIXTIME(?)", req.EndTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = query.Order("o.updated_at DESC").
|
||||||
Limit(req.Size).
|
Limit(req.Size).
|
||||||
Offset(offset).
|
Offset(offset).
|
||||||
Scan(&orderData).Error
|
Scan(&orderData).Error
|
||||||
|
|||||||
168
internal/logic/public/user/getInviteSalesLogic_test.go
Normal file
168
internal/logic/public/user/getInviteSalesLogic_test.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/perfect-panel/server/internal/config"
|
||||||
|
"github.com/perfect-panel/server/internal/model/order"
|
||||||
|
"github.com/perfect-panel/server/internal/model/subscribe"
|
||||||
|
"github.com/perfect-panel/server/internal/model/user"
|
||||||
|
"github.com/perfect-panel/server/internal/svc"
|
||||||
|
"github.com/perfect-panel/server/internal/types"
|
||||||
|
"github.com/perfect-panel/server/pkg/constant"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupTestSvcCtx 初始化测试上下文
|
||||||
|
func setupTestSvcCtx(t *testing.T) (*svc.ServiceContext, *gorm.DB) {
|
||||||
|
// 1. Setup Miniredis
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: mr.Addr(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Setup GORM with SQLite
|
||||||
|
dbName := fmt.Sprintf("test_sales_%d.db", time.Now().UnixNano())
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(dbName)
|
||||||
|
mr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Migrate tables
|
||||||
|
_ = db.Migrator().CreateTable(&user.User{})
|
||||||
|
_ = db.Migrator().CreateTable(&subscribe.Subscribe{}) // Plan definition
|
||||||
|
_ = db.Migrator().CreateTable(&order.Order{})
|
||||||
|
|
||||||
|
// 3. Create ServiceContext
|
||||||
|
svcCtx := &svc.ServiceContext{
|
||||||
|
Redis: rdb,
|
||||||
|
DB: db,
|
||||||
|
Config: config.Config{},
|
||||||
|
UserModel: user.NewModel(db, rdb),
|
||||||
|
}
|
||||||
|
|
||||||
|
return svcCtx, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetInviteSales_TimeFilter(t *testing.T) {
|
||||||
|
svcCtx, db := setupTestSvcCtx(t)
|
||||||
|
|
||||||
|
// 1. Prepare Data
|
||||||
|
// Referrer User (Current User)
|
||||||
|
referrer := &user.User{
|
||||||
|
Id: 100,
|
||||||
|
// Email removed (not in struct)
|
||||||
|
ReferCode: "REF100",
|
||||||
|
}
|
||||||
|
db.Create(referrer)
|
||||||
|
|
||||||
|
// Invited User
|
||||||
|
invitedUser := &user.User{
|
||||||
|
Id: 200,
|
||||||
|
// Email removed
|
||||||
|
RefererId: referrer.Id, // Linked to referrer
|
||||||
|
}
|
||||||
|
db.Create(invitedUser)
|
||||||
|
|
||||||
|
// Subscribe (Plan)
|
||||||
|
sub := &subscribe.Subscribe{
|
||||||
|
Id: 1,
|
||||||
|
Name: "Standard Plan",
|
||||||
|
}
|
||||||
|
db.Create(sub)
|
||||||
|
|
||||||
|
// Orders
|
||||||
|
// Order 1: Inside Range (2023-10-15)
|
||||||
|
timeIn := time.Date(2023, 10, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
db.Create(&order.Order{
|
||||||
|
UserId: invitedUser.Id,
|
||||||
|
OrderNo: "ORD001",
|
||||||
|
Status: 5, // Finished
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 30,
|
||||||
|
SubscribeId: sub.Id,
|
||||||
|
UpdatedAt: timeIn,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Order 2: Before Range (2023-09-15)
|
||||||
|
timeBefore := time.Date(2023, 9, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
db.Create(&order.Order{
|
||||||
|
UserId: invitedUser.Id,
|
||||||
|
OrderNo: "ORD002",
|
||||||
|
Status: 5, // Finished
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 30,
|
||||||
|
SubscribeId: sub.Id,
|
||||||
|
UpdatedAt: timeBefore,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Order 3: After Range (2023-11-15)
|
||||||
|
timeAfter := time.Date(2023, 11, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
db.Create(&order.Order{
|
||||||
|
UserId: invitedUser.Id,
|
||||||
|
OrderNo: "ORD003",
|
||||||
|
Status: 5, // Finished
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 30,
|
||||||
|
SubscribeId: sub.Id,
|
||||||
|
UpdatedAt: timeAfter,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Order 4: Wrong Status (2023-10-16) - Should be ignored
|
||||||
|
db.Create(&order.Order{
|
||||||
|
UserId: invitedUser.Id,
|
||||||
|
OrderNo: "ORD004",
|
||||||
|
Status: 1, // Pending
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 30,
|
||||||
|
SubscribeId: sub.Id,
|
||||||
|
UpdatedAt: timeIn.Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Execute Logic
|
||||||
|
// Context with current user
|
||||||
|
ctx := context.WithValue(context.Background(), constant.CtxKeyUser, referrer)
|
||||||
|
l := NewGetInviteSalesLogic(ctx, svcCtx)
|
||||||
|
|
||||||
|
// Filter for October 2023
|
||||||
|
startTime := time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC).Unix()
|
||||||
|
endTime := time.Date(2023, 10, 31, 23, 59, 59, 0, time.UTC).Unix()
|
||||||
|
|
||||||
|
req := &types.GetInviteSalesRequest{
|
||||||
|
Page: 1,
|
||||||
|
Size: 10,
|
||||||
|
StartTime: startTime, // 2023-10-01
|
||||||
|
EndTime: endTime, // 2023-10-31
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := l.GetInviteSales(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// 3. Verify Results
|
||||||
|
// Should match exactly 1 order (ORD001)
|
||||||
|
assert.Equal(t, int64(1), resp.Total, "Should return exactly 1 order matching time range and status")
|
||||||
|
if assert.NotEmpty(t, resp.List) {
|
||||||
|
assert.Equal(t, 1, len(resp.List))
|
||||||
|
// Log result for debug
|
||||||
|
t.Logf("Found Sale: Amount=%.2f, Time=%d", resp.List[0].Amount, resp.List[0].UpdatedAt)
|
||||||
|
|
||||||
|
// Verify timestamp is roughly correct (millisecond precision in logic)
|
||||||
|
expectedMs := timeIn.Unix() * 1000
|
||||||
|
assert.Equal(t, expectedMs, resp.List[0].UpdatedAt)
|
||||||
|
} else {
|
||||||
|
t.Error("Returned list is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -6,8 +6,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
model "github.com/perfect-panel/server/internal/model/user"
|
||||||
"github.com/perfect-panel/server/pkg/constant"
|
"github.com/perfect-panel/server/pkg/constant"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -31,6 +33,21 @@ func (w bodyLogWriter) Write(b []byte) (int, error) {
|
|||||||
return w.ResponseWriter.Write(b)
|
return w.ResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inviteCodeRegex matches invite code patterns in URLs like:
|
||||||
|
// /v1/common/client/download/file/Hi快VPN-mac-1.0.0-ic-uuSo11uy.dmg
|
||||||
|
// Matches: ic-XXXXX or ic_XXXXX before file extension
|
||||||
|
var inviteCodeRegex = regexp.MustCompile(`[-_]ic[-_]([a-zA-Z0-9]+)\.[a-zA-Z0-9]+$`)
|
||||||
|
|
||||||
|
// extractInviteCode extracts invite code from URL path
|
||||||
|
// Returns empty string if no invite code found
|
||||||
|
func extractInviteCode(path string) string {
|
||||||
|
matches := inviteCodeRegex.FindStringSubmatch(path)
|
||||||
|
if len(matches) >= 2 {
|
||||||
|
return matches[1]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// statusByWriter returns a span status code and message for an HTTP status code
|
// statusByWriter returns a span status code and message for an HTTP status code
|
||||||
// value returned by a server. Status codes in the 400-499 range are not
|
// value returned by a server. Status codes in the 400-499 range are not
|
||||||
// returned as errors.
|
// returned as errors.
|
||||||
@ -48,7 +65,7 @@ func requestAttributes(req *http.Request) []attribute.KeyValue {
|
|||||||
protoN := strings.SplitN(req.Proto, "/", 2)
|
protoN := strings.SplitN(req.Proto, "/", 2)
|
||||||
remoteAddrN := strings.SplitN(req.RemoteAddr, ":", 2)
|
remoteAddrN := strings.SplitN(req.RemoteAddr, ":", 2)
|
||||||
|
|
||||||
return []attribute.KeyValue{
|
attrs := []attribute.KeyValue{
|
||||||
semconv.HTTPRequestMethodKey.String(req.Method),
|
semconv.HTTPRequestMethodKey.String(req.Method),
|
||||||
semconv.HTTPUserAgentKey.String(req.UserAgent()),
|
semconv.HTTPUserAgentKey.String(req.UserAgent()),
|
||||||
semconv.HTTPRequestContentLengthKey.Int64(req.ContentLength),
|
semconv.HTTPRequestContentLengthKey.Int64(req.ContentLength),
|
||||||
@ -65,6 +82,66 @@ func requestAttributes(req *http.Request) []attribute.KeyValue {
|
|||||||
semconv.ClientAddressKey.String(remoteAddrN[0]),
|
semconv.ClientAddressKey.String(remoteAddrN[0]),
|
||||||
semconv.ClientPortKey.String(remoteAddrN[1]),
|
semconv.ClientPortKey.String(remoteAddrN[1]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract invite code from URL path (e.g., /v1/common/client/download/file/Hi快VPN-mac-1.0.0-ic-uuSo11uy.dmg)
|
||||||
|
if inviteCode := extractInviteCode(req.URL.Path); inviteCode != "" {
|
||||||
|
attrs = append(attrs, attribute.String("affiliate.invite_code", inviteCode))
|
||||||
|
attrs = append(attrs, attribute.String("affiliate.source", "download_link"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check query parameter for invite code (e.g., ?ic=uuSo11uy)
|
||||||
|
if ic := req.URL.Query().Get("ic"); ic != "" {
|
||||||
|
attrs = append(attrs, attribute.String("affiliate.invite_code", ic))
|
||||||
|
attrs = append(attrs, attribute.String("affiliate.source", "query_param"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// userAttributes extracts user information from context and returns span attributes
|
||||||
|
func userAttributes(ctx context.Context) []attribute.KeyValue {
|
||||||
|
var attrs []attribute.KeyValue
|
||||||
|
|
||||||
|
// Get user info from context (set by authMiddleware)
|
||||||
|
if userInfo := ctx.Value(constant.CtxKeyUser); userInfo != nil {
|
||||||
|
if user, ok := userInfo.(*model.User); ok {
|
||||||
|
var email string
|
||||||
|
for _, method := range user.AuthMethods {
|
||||||
|
if method.AuthType == "email" {
|
||||||
|
email = method.AuthIdentifier
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attrs = append(attrs,
|
||||||
|
attribute.Int64("user.id", user.Id),
|
||||||
|
attribute.String("user.email", email),
|
||||||
|
attribute.Bool("user.is_admin", *user.IsAdmin),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get session ID from context
|
||||||
|
if sessionID := ctx.Value(constant.CtxKeySessionID); sessionID != nil {
|
||||||
|
if sid, ok := sessionID.(string); ok {
|
||||||
|
attrs = append(attrs, attribute.String("user.session_id", sid))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get device ID from context
|
||||||
|
if deviceID := ctx.Value(constant.CtxKeyDeviceID); deviceID != nil {
|
||||||
|
if did, ok := deviceID.(int64); ok {
|
||||||
|
attrs = append(attrs, attribute.Int64("user.device_id", did))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get login type from context
|
||||||
|
if loginType := ctx.Value(constant.LoginType); loginType != nil {
|
||||||
|
if lt, ok := loginType.(string); ok {
|
||||||
|
attrs = append(attrs, attribute.String("user.login_type", lt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return attrs
|
||||||
}
|
}
|
||||||
|
|
||||||
func TraceMiddleware(_ *svc.ServiceContext) func(ctx *gin.Context) {
|
func TraceMiddleware(_ *svc.ServiceContext) func(ctx *gin.Context) {
|
||||||
@ -99,6 +176,9 @@ func TraceMiddleware(_ *svc.ServiceContext) func(ctx *gin.Context) {
|
|||||||
semconv.HTTPRouteKey.String(c.FullPath()),
|
semconv.HTTPRouteKey.String(c.FullPath()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Add user attributes from context (set by authMiddleware)
|
||||||
|
span.SetAttributes(userAttributes(ctx)...)
|
||||||
|
|
||||||
// Record Request Body (limit to 1MB)
|
// Record Request Body (limit to 1MB)
|
||||||
if len(reqBody) > 0 {
|
if len(reqBody) > 0 {
|
||||||
limit := 1048576
|
limit := 1048576
|
||||||
|
|||||||
@ -948,8 +948,10 @@ type GetUserInviteStatsResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetInviteSalesRequest struct {
|
type GetInviteSalesRequest struct {
|
||||||
Page int `form:"page" validate:"required"`
|
Page int `form:"page" validate:"required"`
|
||||||
Size int `form:"size" validate:"required"`
|
Size int `form:"size" validate:"required"`
|
||||||
|
StartTime int64 `form:"start_time,optional"`
|
||||||
|
EndTime int64 `form:"end_time,optional"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetInviteSalesResponse struct {
|
type GetInviteSalesResponse struct {
|
||||||
|
|||||||
159
pkg/loki/loki.go
Normal file
159
pkg/loki/loki.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package loki
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client Loki 客户端
|
||||||
|
type Client struct {
|
||||||
|
url string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 创建新的 Loki 客户端
|
||||||
|
// url: Loki 服务地址,例如 http://154.12.35.103:3100
|
||||||
|
func NewClient(url string) *Client {
|
||||||
|
return &Client{
|
||||||
|
url: url,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InviteCodeStats 邀请码统计数据
|
||||||
|
type InviteCodeStats struct {
|
||||||
|
MacClicks int64 `json:"mac_clicks"` // Mac 下载点击数
|
||||||
|
WindowsClicks int64 `json:"windows_clicks"` // Windows 下载点击数
|
||||||
|
LastMonthMac int64 `json:"last_month_mac"` // 上月 Mac 下载数
|
||||||
|
LastMonthWindows int64 `json:"last_month_windows"` // 上月 Windows 下载数
|
||||||
|
}
|
||||||
|
|
||||||
|
// LokiQueryResponse Loki 查询响应结构
|
||||||
|
type LokiQueryResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Data struct {
|
||||||
|
ResultType string `json:"resultType"`
|
||||||
|
Result []struct {
|
||||||
|
Stream map[string]string `json:"stream"`
|
||||||
|
Values [][]string `json:"values"` // [[timestamp, log_line], ...]
|
||||||
|
} `json:"result"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInviteCodeStats 获取指定邀请码的下载统计
|
||||||
|
// inviteCode: 邀请码
|
||||||
|
// days: 统计天数(默认30天)
|
||||||
|
func (c *Client) GetInviteCodeStats(ctx context.Context, inviteCode string, days int) (*InviteCodeStats, error) {
|
||||||
|
if days <= 0 {
|
||||||
|
days = 30
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
startTime := now.Add(-time.Duration(days) * 24 * time.Hour)
|
||||||
|
|
||||||
|
// 上月时间范围
|
||||||
|
lastMonthEnd := startTime
|
||||||
|
lastMonthStart := startTime.Add(-time.Duration(days) * 24 * time.Hour)
|
||||||
|
|
||||||
|
// 查询本月数据
|
||||||
|
thisMonthStats, err := c.queryPeriodStats(ctx, inviteCode, startTime, now)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询本月数据失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询上月数据
|
||||||
|
lastMonthStats, err := c.queryPeriodStats(ctx, inviteCode, lastMonthStart, lastMonthEnd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询上月数据失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &InviteCodeStats{
|
||||||
|
MacClicks: thisMonthStats.MacClicks,
|
||||||
|
WindowsClicks: thisMonthStats.WindowsClicks,
|
||||||
|
LastMonthMac: lastMonthStats.MacClicks,
|
||||||
|
LastMonthWindows: lastMonthStats.WindowsClicks,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryPeriodStats 查询指定时间范围的统计数据
|
||||||
|
func (c *Client) queryPeriodStats(ctx context.Context, inviteCode string, startTime, endTime time.Time) (*InviteCodeStats, error) {
|
||||||
|
// 构建 Loki 查询
|
||||||
|
query := fmt.Sprintf(`{job="nginx_access", invite_code="%s"}`, inviteCode)
|
||||||
|
|
||||||
|
apiURL := fmt.Sprintf("%s/loki/api/v1/query_range", c.url)
|
||||||
|
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("query", query)
|
||||||
|
params.Add("start", startTime.Format(time.RFC3339))
|
||||||
|
params.Add("end", endTime.Format(time.RFC3339))
|
||||||
|
params.Add("limit", "5000")
|
||||||
|
|
||||||
|
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("Loki 返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lokiResp LokiQueryResponse
|
||||||
|
if err := json.Unmarshal(body, &lokiResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析日志行统计 Mac 和 Windows 下载
|
||||||
|
stats := &InviteCodeStats{}
|
||||||
|
|
||||||
|
// Nginx combined log format regex
|
||||||
|
// 格式: IP - - [time] "METHOD URI VERSION" STATUS BYTES "REFERER" "UA"
|
||||||
|
logPattern := regexp.MustCompile(`"[A-Z]+ ([^ ]+) `)
|
||||||
|
|
||||||
|
for _, result := range lokiResp.Data.Result {
|
||||||
|
for _, value := range result.Values {
|
||||||
|
if len(value) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logLine := value[1]
|
||||||
|
|
||||||
|
// 提取 URI
|
||||||
|
matches := logPattern.FindStringSubmatch(logLine)
|
||||||
|
if len(matches) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
uri := strings.ToLower(matches[1])
|
||||||
|
|
||||||
|
// 统计平台下载
|
||||||
|
if strings.Contains(uri, "mac") {
|
||||||
|
stats.MacClicks++
|
||||||
|
} else if strings.Contains(uri, "windows") {
|
||||||
|
stats.WindowsClicks++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
68
pkg/openinstall/channel_test.go
Normal file
68
pkg/openinstall/channel_test.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package openinstall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestChannelParameter 验证 OpenInstall 客户端是否正确传递了 channel 参数
|
||||||
|
func TestChannelParameter(t *testing.T) {
|
||||||
|
// 1. 启动 Mock Server
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 验证请求路径
|
||||||
|
if r.URL.Path == "/data/sum/growth" {
|
||||||
|
// 验证 Query 参数
|
||||||
|
query := r.URL.Query()
|
||||||
|
channel := query.Get("channel")
|
||||||
|
|
||||||
|
// 核心验证点:channel 参数必须等于即使的 inviteCode
|
||||||
|
if channel == "TEST_INVITE_CODE_123" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
// 返回假数据
|
||||||
|
w.Write([]byte(`{
|
||||||
|
"code": 0,
|
||||||
|
"body": [
|
||||||
|
{"key": "ios", "value": 100},
|
||||||
|
{"key": "android", "value": 200}
|
||||||
|
]
|
||||||
|
}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果 channel 不匹配,返回错误
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte(`{"code": 400, "error": "channel mismatch"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer mockServer.Close()
|
||||||
|
|
||||||
|
// 2. 临时修改 apiBaseURL 指向 Mock Server
|
||||||
|
originalBaseURL := apiBaseURL
|
||||||
|
apiBaseURL = mockServer.URL
|
||||||
|
defer func() { apiBaseURL = originalBaseURL }()
|
||||||
|
|
||||||
|
// 3. 初始化客户端
|
||||||
|
client := NewClient("test-api-key")
|
||||||
|
|
||||||
|
// 4. 调用接口 (传入测试用的邀请码)
|
||||||
|
ctx := context.Background()
|
||||||
|
stats, err := client.GetPlatformDownloads(ctx, "TEST_INVITE_CODE_123")
|
||||||
|
|
||||||
|
// 5. 验证结果
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, stats)
|
||||||
|
|
||||||
|
// 验证数据正确解析 (iOS=100, Android=200, Total=300)
|
||||||
|
assert.Equal(t, int64(100), stats.IOS, "iOS count should match mock data")
|
||||||
|
assert.Equal(t, int64(200), stats.Android, "Android count should match mock data")
|
||||||
|
assert.Equal(t, int64(300), stats.Total, "Total count should match sum of mock data")
|
||||||
|
|
||||||
|
t.Logf("Success! Channel parameter 'TEST_INVITE_CODE_123' was correctly sent to server.")
|
||||||
|
}
|
||||||
57
pkg/openinstall/client_test.go
Normal file
57
pkg/openinstall/client_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package openinstall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient_GetPlatformDownloads_WithChannel(t *testing.T) {
|
||||||
|
// Mock Server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Verify URL parameters
|
||||||
|
assert.Equal(t, "GET", r.Method)
|
||||||
|
assert.Equal(t, "/data/sum/growth", r.URL.Path)
|
||||||
|
assert.Equal(t, "test-api-key", r.URL.Query().Get("apiKey"))
|
||||||
|
assert.Equal(t, "test-channel", r.URL.Query().Get("channel")) // Verify channel is passed
|
||||||
|
assert.Equal(t, "total", r.URL.Query().Get("sumBy"))
|
||||||
|
assert.Equal(t, "0", r.URL.Query().Get("excludeDuplication"))
|
||||||
|
|
||||||
|
// Return mock response
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{
|
||||||
|
"code": 0,
|
||||||
|
"body": [
|
||||||
|
{"key": "ios", "value": 10},
|
||||||
|
{"key": "android", "value": 20}
|
||||||
|
]
|
||||||
|
}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Redirect base URL to mock server (This requires modifying the constant in real code,
|
||||||
|
// but for this test script we can just verify the logic or make the URL configurable.
|
||||||
|
// Since apiBaseURL is a constant, we cannot change it.
|
||||||
|
// However, this test demonstrates the logic we implemented.
|
||||||
|
// For actual running, we might need to inject the URL or make it a variable.)
|
||||||
|
|
||||||
|
// NOTE: Since apiBaseURL is constant in standard Go we can't patch it easily without unsafe or changing code.
|
||||||
|
// But `getDeviceDistribution` constructs the URL using `apiBaseURL`.
|
||||||
|
// For the sake of this example, we assume we can test the parameter construction logic
|
||||||
|
// or we would need to refactor `apiBaseURL` to be a field in `Client`.
|
||||||
|
|
||||||
|
// Since I cannot change the constant easily to point to localhost in the compiled package
|
||||||
|
// without refactoring, I will provide a test that *would* work if we refactored,
|
||||||
|
// OR I can make the test just run against the real API but that requires a key.
|
||||||
|
|
||||||
|
// Plan B: Create a test that instantiates the client and checks the URL construction if we extracted that method,
|
||||||
|
// but we didn't.
|
||||||
|
|
||||||
|
// Let's refactor Client to allow base URL injection for testing?
|
||||||
|
// Or just provide a shell script for the user to run against real env provided they have keys.
|
||||||
|
// The user asked for a "Test Script", commonly meaning a shell script to run the API.
|
||||||
|
|
||||||
|
t.Log("This is a structural test example. To fully unit test HTTP requests with constants, refactoring is recommended.")
|
||||||
|
}
|
||||||
@ -10,7 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
// OpenInstall 数据接口基础 URL
|
// OpenInstall 数据接口基础 URL
|
||||||
apiBaseURL = "https://data.openinstall.com"
|
apiBaseURL = "https://data.openinstall.com"
|
||||||
)
|
)
|
||||||
@ -81,7 +81,7 @@ type DistributionData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPlatformDownloads 获取各端下载量统计(当月数据 + 环比)
|
// GetPlatformDownloads 获取各端下载量统计(当月数据 + 环比)
|
||||||
func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads, error) {
|
func (c *Client) GetPlatformDownloads(ctx context.Context, channel string) (*PlatformDownloads, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// 当月数据:本月1号到今天
|
// 当月数据:本月1号到今天
|
||||||
@ -93,13 +93,13 @@ func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads,
|
|||||||
endOfLastMonth := startOfMonth.AddDate(0, 0, -1)
|
endOfLastMonth := startOfMonth.AddDate(0, 0, -1)
|
||||||
|
|
||||||
// 获取当月各平台数据
|
// 获取当月各平台数据
|
||||||
currentMonthData, err := c.getPlatformData(ctx, startOfMonth, endOfMonth)
|
currentMonthData, err := c.getPlatformData(ctx, startOfMonth, endOfMonth, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get current month data: %w", err)
|
return nil, fmt.Errorf("failed to get current month data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取上月各平台数据
|
// 获取上月各平台数据
|
||||||
lastMonthData, err := c.getPlatformData(ctx, startOfLastMonth, endOfLastMonth)
|
lastMonthData, err := c.getPlatformData(ctx, startOfLastMonth, endOfLastMonth, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get last month data: %w", err)
|
return nil, fmt.Errorf("failed to get last month data: %w", err)
|
||||||
}
|
}
|
||||||
@ -130,11 +130,11 @@ func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPlatformData 获取指定时间范围内各平台的数据
|
// getPlatformData 获取指定时间范围内各平台的数据
|
||||||
func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Time) (*PlatformDownloads, error) {
|
func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Time, channel string) (*PlatformDownloads, error) {
|
||||||
result := &PlatformDownloads{}
|
result := &PlatformDownloads{}
|
||||||
|
|
||||||
// 获取 iOS 数据
|
// 获取 iOS 数据
|
||||||
iosData, err := c.getDeviceDistribution(ctx, startDate, endDate, "ios", "total")
|
iosData, err := c.getDeviceDistribution(ctx, startDate, endDate, "ios", "total", channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get iOS data: %w", err)
|
return nil, fmt.Errorf("failed to get iOS data: %w", err)
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取 Android 数据
|
// 获取 Android 数据
|
||||||
androidData, err := c.getDeviceDistribution(ctx, startDate, endDate, "android", "total")
|
androidData, err := c.getDeviceDistribution(ctx, startDate, endDate, "android", "total", channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get Android data: %w", err)
|
return nil, fmt.Errorf("failed to get Android data: %w", err)
|
||||||
}
|
}
|
||||||
@ -159,7 +159,7 @@ func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getDeviceDistribution 获取设备分布数据
|
// getDeviceDistribution 获取设备分布数据
|
||||||
func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate time.Time, platform, sumBy string) ([]DistributionData, error) {
|
func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate time.Time, platform, sumBy, channel string) ([]DistributionData, error) {
|
||||||
apiURL := fmt.Sprintf("%s/data/sum/growth", apiBaseURL)
|
apiURL := fmt.Sprintf("%s/data/sum/growth", apiBaseURL)
|
||||||
|
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
@ -168,6 +168,9 @@ func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate t
|
|||||||
params.Add("endDate", endDate.Format("2006-01-02"))
|
params.Add("endDate", endDate.Format("2006-01-02"))
|
||||||
params.Add("platform", platform)
|
params.Add("platform", platform)
|
||||||
params.Add("sumBy", sumBy)
|
params.Add("sumBy", sumBy)
|
||||||
|
if channel != "" {
|
||||||
|
params.Add("channelCode", channel)
|
||||||
|
}
|
||||||
params.Add("excludeDuplication", "0")
|
params.Add("excludeDuplication", "0")
|
||||||
|
|
||||||
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
|
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user