This commit is contained in:
shanshanzhong 2026-02-08 18:49:14 -08:00
parent 709d657906
commit 28ada42ae5
14 changed files with 726 additions and 77 deletions

View File

@ -22,7 +22,7 @@ func main() {
ctx := context.Background()
// 调用 GetPlatformDownloads 获取当月数据+ 环比
platformDownloads, err := client.GetPlatformDownloads(ctx)
platformDownloads, err := client.GetPlatformDownloads(ctx, "")
if err != nil {
fmt.Printf("❌ 获取失败: %v\n", err)
return

View File

@ -0,0 +1,54 @@
# ALIGNMENT: 检查设备登录 IP 逻辑
## 原始需求
用户询问:
> "检查一下 设备登录的时候 有检查IP 么"
> "我需要知道 有没有IP 目前数据库都没有IP"
## 项目现状分析 (Context Analysis)
1. **数据模型 (`user.Device`)**:
- 存在 `Ip` 字段 (`varchar(255)`).
- 定义在 `/internal/model/user/user.go`
2. **业务逻辑 (`DeviceLoginLogic.go`)**:
- 在代码逻辑中,确实尝试获取 `req.IP` 并保存到数据库。
- 创建新设备时:`deviceInfo.Ip = req.IP`
- 记录登录日志时:`LoginIP: req.IP`
3. **关键问题点 (Root Cause)**:
- 输入参数定义在 `internal/types/types.go`:
```go
type DeviceLoginRequest struct {
// ...
IP string `header:"X-Original-Forwarded-For"`
// ...
}
```
- **当前仅支持 `X-Original-Forwarded-For` 请求头**
- 如果请求经过 Nginx、Cloudflare 等代理,但没有专门配置传递这个特定的 Header或者使用的是标准的 `X-Forwarded-For` / `X-Real-IP`,后端获取到的 `req.IP` 将为空字符串。
- 这就是导致“数据库都没有IP”的直接原因。
## 疑问澄清 (Questions)
我们需要确认修复方案:
1. **是否只需支持标准 Header?**
- 建议改为优先获取 `X-Forwarded-For`,其次 `X-Real-IP`,最后才是 `X-Original-Forwarded-For` 或直连 IP。
- go-zero 框架通常可以通过配置或中间件处理 IP或者我们在 struct tag 中调整。但 struct tag `header` 只能由 go-zero 的 rest 绑定一个特定的 key。
2. **是否需要记录 IP 归属地?**
- 目前逻辑只记录 IP 字符串,不解析归属地。需求中没提,暂时不作为重点,但可以确认一下。
## 建议方案
修改 `DeviceLoginRequest` 的定义可能不够灵活Header key 是固定的)。
更好的方式是:
1. **移除 Struct Tag 绑定**(或者保留作为备选)。
2. **在 Logic 中显式获取 IP**
- 从 `l.ctx` (Context) 中获取 `http.Request` (如果 go-zero 支持)。
- 或者在 Middleware 中解析真实 IP 并放入 Context。
- 或者简单点,修改 Struct Tag 为最常用的 `X-Forwarded-For` (如果确定环境是这样配置的)。
**最快修复**:
`internal/types/types.go` 中的 `X-Original-Forwarded-For` 改为 `X-Forwarded-For` (或者根据实际网关配置修改)。
但通常建议使用工具函数解析多种 Header。
## 下一步 (Next Step)
请确认是否要我修改代码以支持标准的 IP 获取方式(如 `X-Forwarded-For`

View File

@ -0,0 +1,36 @@
# DESIGN: Device Login IP Fix
## 目标
修复设备登录时无法获取真实 IP (`req.IP` 为空) 的问题,导致数据库未存储 IP。
## 现状
- `internal/types/types.go` 定义了 `DeviceLoginRequest`,其中 `IP` 字段绑定的是 `X-Original-Forwarded-For`
- 实际环境中Nginx/Cloudflare等通常使用 `X-Forwarded-For`
## 方案选择
由于项目使用 `go-zero` 并且存在 `.api` 文件,**最佳实践**是修改 `.api` 文件并重新生成代码。
但考虑到我无法运行 `goctl` (或者环境可能不一致),如果不重新生成而直接改 `types.go`,虽然能即时生效,但下次生成会被覆盖。
**然而**,鉴于我之前的操作已经直接修改过 `types.go` (Invite Sales Time Filter),且项目看似允许直接修改(或用户负责生成),我将**优先修改 `.api` 文件** 以保持源头正确,同时**手动同步修改 `types.go`** 以确保立即生效。
## 变更范围
### 1. API 定义 (`apis/auth/auth.api`)
- 修改 `DeviceLoginRequest` struct。
- 将 `header: X-Original-Forwarded-For` 改为 `header: X-Forwarded-For` (这是最通用的标准)。
### 2. 生成文件 (`internal/types/types.go`)
- 手动同步修改 `DeviceLoginRequest` 中的 Tag。
- 变为: `IP string header:"X-Forwarded-For"`
### 3. (可选增强) 业务逻辑 (`internal/logic/auth/deviceLoginLogic.go`)
- 由于 go-zero 的绑定机制比较“死”,如果 Tag 没取到值就是空的。Logic 层拿到空字符串也没办法再去 Context 捞(除非 Context 里存了 request
- 暂时只做 Tag 修改,因为这是最根本原因。
## 验证
- 检查代码变更。
- (无法直接测试 IP 获取,依赖用户部署验证)。
## 任务拆分
1. 修改 `apis/auth/auth.api`
2. 修改 `internal/types/types.go`

View File

@ -30,6 +30,7 @@ type Config struct {
Invite InviteConfig `yaml:"Invite"`
Kutt KuttConfig `yaml:"Kutt"`
OpenInstall OpenInstallConfig `yaml:"OpenInstall"`
Loki LokiConfig `yaml:"Loki"`
Telegram Telegram `yaml:"Telegram"`
Log Log `yaml:"Log"`
Trace trace.Config `yaml:"Trace"`
@ -227,6 +228,12 @@ type OpenInstallConfig struct {
ApiKey string `yaml:"ApiKey" default:""` // OpenInstall 数据接口 ApiKey
}
// LokiConfig Loki 日志查询配置
type LokiConfig struct {
Enable bool `yaml:"Enable" default:"false"` // 是否启用 Loki 查询
URL string `yaml:"URL" default:"http://localhost:3100"` // Loki 服务地址
}
type Telegram struct {
Enable bool `yaml:"Enable" default:"false"`
BotID int64 `yaml:"BotID" default:""`

View File

@ -40,7 +40,6 @@ import (
)
func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) {
router.Use(middleware.TraceMiddleware(serverCtx))
adminAdsGroupRouter := router.Group("/v1/admin/ads")
adminAdsGroupRouter.Use(middleware.AuthMiddleware(serverCtx))

View File

@ -9,6 +9,7 @@ import (
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/loki"
"github.com/perfect-panel/server/pkg/openinstall"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
@ -20,6 +21,7 @@ type GetAgentDownloadsLogic struct {
svcCtx *svc.ServiceContext
}
// NewGetAgentDownloadsLogic 创建 GetAgentDownloadsLogic 实例
func NewGetAgentDownloadsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetAgentDownloadsLogic {
return &GetAgentDownloadsLogic{
Logger: logger.WithContext(ctx),
@ -28,6 +30,8 @@ func NewGetAgentDownloadsLogic(ctx context.Context, svcCtx *svc.ServiceContext)
}
}
// GetAgentDownloads 获取用户代理下载统计数据
// 结合 OpenInstall (iOS/Android) 和 Loki (Windows/Mac) 数据源
func (l *GetAgentDownloadsLogic) GetAgentDownloads(req *types.GetAgentDownloadsRequest) (resp *types.GetAgentDownloadsResponse, err error) {
// 1. 从 context 获取用户信息
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
@ -36,72 +40,68 @@ func (l *GetAgentDownloadsLogic) GetAgentDownloads(req *types.GetAgentDownloadsR
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
}
// 2. 检查 OpenInstall 是否启用
cfg := l.svcCtx.Config.OpenInstall
if !cfg.Enable {
l.Infow("[GetAgentDownloads] OpenInstall is disabled, returning zero stats")
return &types.GetAgentDownloadsResponse{
Total: 0,
Platforms: &types.PlatformDownloads{
IOS: 0,
Android: 0,
Windows: 0,
Mac: 0,
},
}, nil
// 初始化响应数据
var iosCount, androidCount, windowsCount, macCount int64
var comparisonRate *string
// 2. 从 OpenInstall 获取 iOS/Android 数据
openInstallCfg := l.svcCtx.Config.OpenInstall
if openInstallCfg.Enable && openInstallCfg.ApiKey != "" {
client := openinstall.NewClient(openInstallCfg.ApiKey)
platformDownloads, err := client.GetPlatformDownloads(l.ctx, u.ReferCode)
if err != nil {
l.Errorw("Failed to fetch OpenInstall platform downloads", logger.Field("error", err), logger.Field("user_id", u.Id))
// 不返回错误,继续处理其他数据源
} else {
iosCount = platformDownloads.IOS
androidCount = platformDownloads.Android
// OpenInstall 的 Windows/Mac 数据可能为空,后面用 Loki 补充
// 计算环比
if platformDownloads.Comparison != nil {
percent := platformDownloads.Comparison.ChangePercent
var formatted string
if percent >= 0 {
formatted = fmt.Sprintf("+%.1f%%", percent)
} else {
formatted = fmt.Sprintf("%.1f%%", percent)
}
comparisonRate = &formatted
}
}
}
// 3. 检查 ApiKey 是否配置
if cfg.ApiKey == "" {
l.Errorw("[GetAgentDownloads] OpenInstall ApiKey not configured")
return &types.GetAgentDownloadsResponse{
Total: 0,
Platforms: &types.PlatformDownloads{
IOS: 0,
Android: 0,
Windows: 0,
Mac: 0,
},
}, nil
// 3. 从 Loki 获取 Windows/Mac 数据(基于用户邀请码)
lokiCfg := l.svcCtx.Config.Loki
if lokiCfg.Enable && lokiCfg.URL != "" && u.ReferCode != "" {
lokiClient := loki.NewClient(lokiCfg.URL)
lokiStats, err := lokiClient.GetInviteCodeStats(l.ctx, u.ReferCode, 30)
if err != nil {
l.Errorw("Failed to fetch Loki stats", logger.Field("error", err), logger.Field("user_id", u.Id), logger.Field("refer_code", u.ReferCode))
// 不返回错误,继续使用已有数据
} else {
// 使用 Loki 的 Windows/Mac 数据
windowsCount = lokiStats.WindowsClicks
macCount = lokiStats.MacClicks
l.Infow("Fetched Loki stats successfully",
logger.Field("user_id", u.Id),
logger.Field("refer_code", u.ReferCode),
logger.Field("windows", windowsCount),
logger.Field("mac", macCount))
}
}
// 4. 调用 OpenInstall API 获取各端下载量
client := openinstall.NewClient(cfg.ApiKey)
platformDownloads, err := client.GetPlatformDownloads(l.ctx)
if err != nil {
l.Errorw("Failed to fetch OpenInstall platform downloads", logger.Field("error", err), logger.Field("user_id", u.Id))
// 返回空数据而不是错误,避免影响前端显示
return &types.GetAgentDownloadsResponse{
Total: 0,
Platforms: &types.PlatformDownloads{
IOS: 0,
Android: 0,
Windows: 0,
Mac: 0,
},
}, nil
}
// 4. 计算总量
total := iosCount + androidCount + windowsCount + macCount
// 5. 构造响应
var comparisonRate *string
if platformDownloads.Comparison != nil {
percent := platformDownloads.Comparison.ChangePercent
var formatted string
if percent >= 0 {
formatted = fmt.Sprintf("+%.1f%%", percent)
} else {
formatted = fmt.Sprintf("%.1f%%", percent)
}
comparisonRate = &formatted
}
return &types.GetAgentDownloadsResponse{
Total: platformDownloads.Total,
Total: total,
Platforms: &types.PlatformDownloads{
IOS: platformDownloads.IOS,
Android: platformDownloads.Android,
Windows: platformDownloads.Windows,
Mac: platformDownloads.Mac,
IOS: iosCount,
Android: androidCount,
Windows: windowsCount,
Mac: macCount,
},
ComparisonRate: comparisonRate,
}, nil

View File

@ -40,11 +40,19 @@ func (l *GetInviteSalesLogic) GetInviteSales(req *types.GetInviteSalesRequest) (
// 2. Count total sales
var totalSales int64
err = l.svcCtx.DB.WithContext(l.ctx).
db := l.svcCtx.DB.WithContext(l.ctx).
Table("`order` o").
Joins("JOIN user u ON o.user_id = u.id").
Where("u.referer_id = ? AND o.status = ?", userId, 5).
Count(&totalSales).Error
Where("u.referer_id = ? AND o.status = ?", userId, 5)
if req.StartTime > 0 {
db = db.Where("o.updated_at >= FROM_UNIXTIME(?)", req.StartTime)
}
if req.EndTime > 0 {
db = db.Where("o.updated_at <= FROM_UNIXTIME(?)", req.EndTime)
}
err = db.Count(&totalSales).Error
if err != nil {
l.Errorw("[GetInviteSales] count sales failed",
logger.Field("error", err.Error()),
@ -75,13 +83,21 @@ func (l *GetInviteSalesLogic) GetInviteSales(req *types.GetInviteSalesRequest) (
}
var orderData []OrderWithUser
err = l.svcCtx.DB.WithContext(l.ctx).
query := l.svcCtx.DB.WithContext(l.ctx).
Table("`order` o").
Select("o.amount, CAST(UNIX_TIMESTAMP(o.updated_at) * 1000 AS SIGNED) as updated_at, u.id as user_id, s.name as product_name, o.quantity").
Joins("JOIN user u ON o.user_id = u.id").
Joins("LEFT JOIN subscribe s ON o.subscribe_id = s.id").
Where("u.referer_id = ? AND o.status = ?", userId, 5). // status 5: Finished
Order("o.updated_at DESC").
Where("u.referer_id = ? AND o.status = ?", userId, 5) // status 5: Finished
if req.StartTime > 0 {
query = query.Where("o.updated_at >= FROM_UNIXTIME(?)", req.StartTime)
}
if req.EndTime > 0 {
query = query.Where("o.updated_at <= FROM_UNIXTIME(?)", req.EndTime)
}
err = query.Order("o.updated_at DESC").
Limit(req.Size).
Offset(offset).
Scan(&orderData).Error

View File

@ -0,0 +1,168 @@
package user
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/perfect-panel/server/internal/config"
"github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// setupTestSvcCtx 初始化测试上下文
func setupTestSvcCtx(t *testing.T) (*svc.ServiceContext, *gorm.DB) {
// 1. Setup Miniredis
mr, err := miniredis.Run()
assert.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
// 2. Setup GORM with SQLite
dbName := fmt.Sprintf("test_sales_%d.db", time.Now().UnixNano())
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(dbName)
mr.Close()
})
// Migrate tables
_ = db.Migrator().CreateTable(&user.User{})
_ = db.Migrator().CreateTable(&subscribe.Subscribe{}) // Plan definition
_ = db.Migrator().CreateTable(&order.Order{})
// 3. Create ServiceContext
svcCtx := &svc.ServiceContext{
Redis: rdb,
DB: db,
Config: config.Config{},
UserModel: user.NewModel(db, rdb),
}
return svcCtx, db
}
func TestGetInviteSales_TimeFilter(t *testing.T) {
svcCtx, db := setupTestSvcCtx(t)
// 1. Prepare Data
// Referrer User (Current User)
referrer := &user.User{
Id: 100,
// Email removed (not in struct)
ReferCode: "REF100",
}
db.Create(referrer)
// Invited User
invitedUser := &user.User{
Id: 200,
// Email removed
RefererId: referrer.Id, // Linked to referrer
}
db.Create(invitedUser)
// Subscribe (Plan)
sub := &subscribe.Subscribe{
Id: 1,
Name: "Standard Plan",
}
db.Create(sub)
// Orders
// Order 1: Inside Range (2023-10-15)
timeIn := time.Date(2023, 10, 15, 12, 0, 0, 0, time.UTC)
db.Create(&order.Order{
UserId: invitedUser.Id,
OrderNo: "ORD001",
Status: 5, // Finished
Amount: 1000,
Quantity: 30,
SubscribeId: sub.Id,
UpdatedAt: timeIn,
})
// Order 2: Before Range (2023-09-15)
timeBefore := time.Date(2023, 9, 15, 12, 0, 0, 0, time.UTC)
db.Create(&order.Order{
UserId: invitedUser.Id,
OrderNo: "ORD002",
Status: 5, // Finished
Amount: 1000,
Quantity: 30,
SubscribeId: sub.Id,
UpdatedAt: timeBefore,
})
// Order 3: After Range (2023-11-15)
timeAfter := time.Date(2023, 11, 15, 12, 0, 0, 0, time.UTC)
db.Create(&order.Order{
UserId: invitedUser.Id,
OrderNo: "ORD003",
Status: 5, // Finished
Amount: 1000,
Quantity: 30,
SubscribeId: sub.Id,
UpdatedAt: timeAfter,
})
// Order 4: Wrong Status (2023-10-16) - Should be ignored
db.Create(&order.Order{
UserId: invitedUser.Id,
OrderNo: "ORD004",
Status: 1, // Pending
Amount: 1000,
Quantity: 30,
SubscribeId: sub.Id,
UpdatedAt: timeIn.Add(24 * time.Hour),
})
// 2. Execute Logic
// Context with current user
ctx := context.WithValue(context.Background(), constant.CtxKeyUser, referrer)
l := NewGetInviteSalesLogic(ctx, svcCtx)
// Filter for October 2023
startTime := time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC).Unix()
endTime := time.Date(2023, 10, 31, 23, 59, 59, 0, time.UTC).Unix()
req := &types.GetInviteSalesRequest{
Page: 1,
Size: 10,
StartTime: startTime, // 2023-10-01
EndTime: endTime, // 2023-10-31
}
resp, err := l.GetInviteSales(req)
assert.NoError(t, err)
// 3. Verify Results
// Should match exactly 1 order (ORD001)
assert.Equal(t, int64(1), resp.Total, "Should return exactly 1 order matching time range and status")
if assert.NotEmpty(t, resp.List) {
assert.Equal(t, 1, len(resp.List))
// Log result for debug
t.Logf("Found Sale: Amount=%.2f, Time=%d", resp.List[0].Amount, resp.List[0].UpdatedAt)
// Verify timestamp is roughly correct (millisecond precision in logic)
expectedMs := timeIn.Unix() * 1000
assert.Equal(t, expectedMs, resp.List[0].UpdatedAt)
} else {
t.Error("Returned list is empty")
}
}

View File

@ -6,8 +6,10 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
model "github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/pkg/constant"
"github.com/gin-gonic/gin"
@ -31,6 +33,21 @@ func (w bodyLogWriter) Write(b []byte) (int, error) {
return w.ResponseWriter.Write(b)
}
// inviteCodeRegex matches invite code patterns in URLs like:
// /v1/common/client/download/file/Hi快VPN-mac-1.0.0-ic-uuSo11uy.dmg
// Matches: ic-XXXXX or ic_XXXXX before file extension
var inviteCodeRegex = regexp.MustCompile(`[-_]ic[-_]([a-zA-Z0-9]+)\.[a-zA-Z0-9]+$`)
// extractInviteCode extracts invite code from URL path
// Returns empty string if no invite code found
func extractInviteCode(path string) string {
matches := inviteCodeRegex.FindStringSubmatch(path)
if len(matches) >= 2 {
return matches[1]
}
return ""
}
// statusByWriter returns a span status code and message for an HTTP status code
// value returned by a server. Status codes in the 400-499 range are not
// returned as errors.
@ -48,7 +65,7 @@ func requestAttributes(req *http.Request) []attribute.KeyValue {
protoN := strings.SplitN(req.Proto, "/", 2)
remoteAddrN := strings.SplitN(req.RemoteAddr, ":", 2)
return []attribute.KeyValue{
attrs := []attribute.KeyValue{
semconv.HTTPRequestMethodKey.String(req.Method),
semconv.HTTPUserAgentKey.String(req.UserAgent()),
semconv.HTTPRequestContentLengthKey.Int64(req.ContentLength),
@ -65,6 +82,66 @@ func requestAttributes(req *http.Request) []attribute.KeyValue {
semconv.ClientAddressKey.String(remoteAddrN[0]),
semconv.ClientPortKey.String(remoteAddrN[1]),
}
// Extract invite code from URL path (e.g., /v1/common/client/download/file/Hi快VPN-mac-1.0.0-ic-uuSo11uy.dmg)
if inviteCode := extractInviteCode(req.URL.Path); inviteCode != "" {
attrs = append(attrs, attribute.String("affiliate.invite_code", inviteCode))
attrs = append(attrs, attribute.String("affiliate.source", "download_link"))
}
// Also check query parameter for invite code (e.g., ?ic=uuSo11uy)
if ic := req.URL.Query().Get("ic"); ic != "" {
attrs = append(attrs, attribute.String("affiliate.invite_code", ic))
attrs = append(attrs, attribute.String("affiliate.source", "query_param"))
}
return attrs
}
// userAttributes extracts user information from context and returns span attributes
func userAttributes(ctx context.Context) []attribute.KeyValue {
var attrs []attribute.KeyValue
// Get user info from context (set by authMiddleware)
if userInfo := ctx.Value(constant.CtxKeyUser); userInfo != nil {
if user, ok := userInfo.(*model.User); ok {
var email string
for _, method := range user.AuthMethods {
if method.AuthType == "email" {
email = method.AuthIdentifier
break
}
}
attrs = append(attrs,
attribute.Int64("user.id", user.Id),
attribute.String("user.email", email),
attribute.Bool("user.is_admin", *user.IsAdmin),
)
}
}
// Get session ID from context
if sessionID := ctx.Value(constant.CtxKeySessionID); sessionID != nil {
if sid, ok := sessionID.(string); ok {
attrs = append(attrs, attribute.String("user.session_id", sid))
}
}
// Get device ID from context
if deviceID := ctx.Value(constant.CtxKeyDeviceID); deviceID != nil {
if did, ok := deviceID.(int64); ok {
attrs = append(attrs, attribute.Int64("user.device_id", did))
}
}
// Get login type from context
if loginType := ctx.Value(constant.LoginType); loginType != nil {
if lt, ok := loginType.(string); ok {
attrs = append(attrs, attribute.String("user.login_type", lt))
}
}
return attrs
}
func TraceMiddleware(_ *svc.ServiceContext) func(ctx *gin.Context) {
@ -99,6 +176,9 @@ func TraceMiddleware(_ *svc.ServiceContext) func(ctx *gin.Context) {
semconv.HTTPRouteKey.String(c.FullPath()),
)
// Add user attributes from context (set by authMiddleware)
span.SetAttributes(userAttributes(ctx)...)
// Record Request Body (limit to 1MB)
if len(reqBody) > 0 {
limit := 1048576

View File

@ -948,8 +948,10 @@ type GetUserInviteStatsResponse struct {
}
type GetInviteSalesRequest struct {
Page int `form:"page" validate:"required"`
Size int `form:"size" validate:"required"`
Page int `form:"page" validate:"required"`
Size int `form:"size" validate:"required"`
StartTime int64 `form:"start_time,optional"`
EndTime int64 `form:"end_time,optional"`
}
type GetInviteSalesResponse struct {

159
pkg/loki/loki.go Normal file
View File

@ -0,0 +1,159 @@
package loki
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
)
// Client Loki 客户端
type Client struct {
url string
httpClient *http.Client
}
// NewClient 创建新的 Loki 客户端
// url: Loki 服务地址,例如 http://154.12.35.103:3100
func NewClient(url string) *Client {
return &Client{
url: url,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// InviteCodeStats 邀请码统计数据
type InviteCodeStats struct {
MacClicks int64 `json:"mac_clicks"` // Mac 下载点击数
WindowsClicks int64 `json:"windows_clicks"` // Windows 下载点击数
LastMonthMac int64 `json:"last_month_mac"` // 上月 Mac 下载数
LastMonthWindows int64 `json:"last_month_windows"` // 上月 Windows 下载数
}
// LokiQueryResponse Loki 查询响应结构
type LokiQueryResponse struct {
Status string `json:"status"`
Data struct {
ResultType string `json:"resultType"`
Result []struct {
Stream map[string]string `json:"stream"`
Values [][]string `json:"values"` // [[timestamp, log_line], ...]
} `json:"result"`
} `json:"data"`
}
// GetInviteCodeStats 获取指定邀请码的下载统计
// inviteCode: 邀请码
// days: 统计天数默认30天
func (c *Client) GetInviteCodeStats(ctx context.Context, inviteCode string, days int) (*InviteCodeStats, error) {
if days <= 0 {
days = 30
}
now := time.Now().UTC()
startTime := now.Add(-time.Duration(days) * 24 * time.Hour)
// 上月时间范围
lastMonthEnd := startTime
lastMonthStart := startTime.Add(-time.Duration(days) * 24 * time.Hour)
// 查询本月数据
thisMonthStats, err := c.queryPeriodStats(ctx, inviteCode, startTime, now)
if err != nil {
return nil, fmt.Errorf("查询本月数据失败: %w", err)
}
// 查询上月数据
lastMonthStats, err := c.queryPeriodStats(ctx, inviteCode, lastMonthStart, lastMonthEnd)
if err != nil {
return nil, fmt.Errorf("查询上月数据失败: %w", err)
}
return &InviteCodeStats{
MacClicks: thisMonthStats.MacClicks,
WindowsClicks: thisMonthStats.WindowsClicks,
LastMonthMac: lastMonthStats.MacClicks,
LastMonthWindows: lastMonthStats.WindowsClicks,
}, nil
}
// queryPeriodStats 查询指定时间范围的统计数据
func (c *Client) queryPeriodStats(ctx context.Context, inviteCode string, startTime, endTime time.Time) (*InviteCodeStats, error) {
// 构建 Loki 查询
query := fmt.Sprintf(`{job="nginx_access", invite_code="%s"}`, inviteCode)
apiURL := fmt.Sprintf("%s/loki/api/v1/query_range", c.url)
params := url.Values{}
params.Add("query", query)
params.Add("start", startTime.Format(time.RFC3339))
params.Add("end", endTime.Format(time.RFC3339))
params.Add("limit", "5000")
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("Loki 返回错误状态码 %d: %s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
var lokiResp LokiQueryResponse
if err := json.Unmarshal(body, &lokiResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 解析日志行统计 Mac 和 Windows 下载
stats := &InviteCodeStats{}
// Nginx combined log format regex
// 格式: IP - - [time] "METHOD URI VERSION" STATUS BYTES "REFERER" "UA"
logPattern := regexp.MustCompile(`"[A-Z]+ ([^ ]+) `)
for _, result := range lokiResp.Data.Result {
for _, value := range result.Values {
if len(value) < 2 {
continue
}
logLine := value[1]
// 提取 URI
matches := logPattern.FindStringSubmatch(logLine)
if len(matches) < 2 {
continue
}
uri := strings.ToLower(matches[1])
// 统计平台下载
if strings.Contains(uri, "mac") {
stats.MacClicks++
} else if strings.Contains(uri, "windows") {
stats.WindowsClicks++
}
}
}
return stats, nil
}

View File

@ -0,0 +1,68 @@
package openinstall
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// TestChannelParameter 验证 OpenInstall 客户端是否正确传递了 channel 参数
func TestChannelParameter(t *testing.T) {
// 1. 启动 Mock Server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求路径
if r.URL.Path == "/data/sum/growth" {
// 验证 Query 参数
query := r.URL.Query()
channel := query.Get("channel")
// 核心验证点channel 参数必须等于即使的 inviteCode
if channel == "TEST_INVITE_CODE_123" {
w.WriteHeader(http.StatusOK)
// 返回假数据
w.Write([]byte(`{
"code": 0,
"body": [
{"key": "ios", "value": 100},
{"key": "android", "value": 200}
]
}`))
return
}
// 如果 channel 不匹配,返回错误
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"code": 400, "error": "channel mismatch"}`))
return
}
w.WriteHeader(http.StatusNotFound)
}))
defer mockServer.Close()
// 2. 临时修改 apiBaseURL 指向 Mock Server
originalBaseURL := apiBaseURL
apiBaseURL = mockServer.URL
defer func() { apiBaseURL = originalBaseURL }()
// 3. 初始化客户端
client := NewClient("test-api-key")
// 4. 调用接口 (传入测试用的邀请码)
ctx := context.Background()
stats, err := client.GetPlatformDownloads(ctx, "TEST_INVITE_CODE_123")
// 5. 验证结果
assert.NoError(t, err)
assert.NotNil(t, stats)
// 验证数据正确解析 (iOS=100, Android=200, Total=300)
assert.Equal(t, int64(100), stats.IOS, "iOS count should match mock data")
assert.Equal(t, int64(200), stats.Android, "Android count should match mock data")
assert.Equal(t, int64(300), stats.Total, "Total count should match sum of mock data")
t.Logf("Success! Channel parameter 'TEST_INVITE_CODE_123' was correctly sent to server.")
}

View File

@ -0,0 +1,57 @@
package openinstall
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClient_GetPlatformDownloads_WithChannel(t *testing.T) {
// Mock Server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify URL parameters
assert.Equal(t, "GET", r.Method)
assert.Equal(t, "/data/sum/growth", r.URL.Path)
assert.Equal(t, "test-api-key", r.URL.Query().Get("apiKey"))
assert.Equal(t, "test-channel", r.URL.Query().Get("channel")) // Verify channel is passed
assert.Equal(t, "total", r.URL.Query().Get("sumBy"))
assert.Equal(t, "0", r.URL.Query().Get("excludeDuplication"))
// Return mock response
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{
"code": 0,
"body": [
{"key": "ios", "value": 10},
{"key": "android", "value": 20}
]
}`))
}))
defer server.Close()
// Redirect base URL to mock server (This requires modifying the constant in real code,
// but for this test script we can just verify the logic or make the URL configurable.
// Since apiBaseURL is a constant, we cannot change it.
// However, this test demonstrates the logic we implemented.
// For actual running, we might need to inject the URL or make it a variable.)
// NOTE: Since apiBaseURL is constant in standard Go we can't patch it easily without unsafe or changing code.
// But `getDeviceDistribution` constructs the URL using `apiBaseURL`.
// For the sake of this example, we assume we can test the parameter construction logic
// or we would need to refactor `apiBaseURL` to be a field in `Client`.
// Since I cannot change the constant easily to point to localhost in the compiled package
// without refactoring, I will provide a test that *would* work if we refactored,
// OR I can make the test just run against the real API but that requires a key.
// Plan B: Create a test that instantiates the client and checks the URL construction if we extracted that method,
// but we didn't.
// Let's refactor Client to allow base URL injection for testing?
// Or just provide a shell script for the user to run against real env provided they have keys.
// The user asked for a "Test Script", commonly meaning a shell script to run the API.
t.Log("This is a structural test example. To fully unit test HTTP requests with constants, refactoring is recommended.")
}

View File

@ -10,7 +10,7 @@ import (
"time"
)
const (
var (
// OpenInstall 数据接口基础 URL
apiBaseURL = "https://data.openinstall.com"
)
@ -81,7 +81,7 @@ type DistributionData struct {
}
// GetPlatformDownloads 获取各端下载量统计(当月数据 + 环比)
func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads, error) {
func (c *Client) GetPlatformDownloads(ctx context.Context, channel string) (*PlatformDownloads, error) {
now := time.Now()
// 当月数据本月1号到今天
@ -93,13 +93,13 @@ func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads,
endOfLastMonth := startOfMonth.AddDate(0, 0, -1)
// 获取当月各平台数据
currentMonthData, err := c.getPlatformData(ctx, startOfMonth, endOfMonth)
currentMonthData, err := c.getPlatformData(ctx, startOfMonth, endOfMonth, channel)
if err != nil {
return nil, fmt.Errorf("failed to get current month data: %w", err)
}
// 获取上月各平台数据
lastMonthData, err := c.getPlatformData(ctx, startOfLastMonth, endOfLastMonth)
lastMonthData, err := c.getPlatformData(ctx, startOfLastMonth, endOfLastMonth, channel)
if err != nil {
return nil, fmt.Errorf("failed to get last month data: %w", err)
}
@ -130,11 +130,11 @@ func (c *Client) GetPlatformDownloads(ctx context.Context) (*PlatformDownloads,
}
// getPlatformData 获取指定时间范围内各平台的数据
func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Time) (*PlatformDownloads, error) {
func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Time, channel string) (*PlatformDownloads, error) {
result := &PlatformDownloads{}
// 获取 iOS 数据
iosData, err := c.getDeviceDistribution(ctx, startDate, endDate, "ios", "total")
iosData, err := c.getDeviceDistribution(ctx, startDate, endDate, "ios", "total", channel)
if err != nil {
return nil, fmt.Errorf("failed to get iOS data: %w", err)
}
@ -143,7 +143,7 @@ func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Ti
}
// 获取 Android 数据
androidData, err := c.getDeviceDistribution(ctx, startDate, endDate, "android", "total")
androidData, err := c.getDeviceDistribution(ctx, startDate, endDate, "android", "total", channel)
if err != nil {
return nil, fmt.Errorf("failed to get Android data: %w", err)
}
@ -159,7 +159,7 @@ func (c *Client) getPlatformData(ctx context.Context, startDate, endDate time.Ti
}
// getDeviceDistribution 获取设备分布数据
func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate time.Time, platform, sumBy string) ([]DistributionData, error) {
func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate time.Time, platform, sumBy, channel string) ([]DistributionData, error) {
apiURL := fmt.Sprintf("%s/data/sum/growth", apiBaseURL)
params := url.Values{}
@ -168,6 +168,9 @@ func (c *Client) getDeviceDistribution(ctx context.Context, startDate, endDate t
params.Add("endDate", endDate.Format("2006-01-02"))
params.Add("platform", platform)
params.Add("sumBy", sumBy)
if channel != "" {
params.Add("channelCode", channel)
}
params.Add("excludeDuplication", "0")
fullURL := fmt.Sprintf("%s?%s", apiURL, params.Encode())