diff --git a/.trae/documents/实现 Apple IAP 订阅并与现有后端整合.md b/.trae/documents/实现 Apple IAP 订阅并与现有后端整合.md new file mode 100644 index 0000000..d9b5d29 --- /dev/null +++ b/.trae/documents/实现 Apple IAP 订阅并与现有后端整合.md @@ -0,0 +1,65 @@ +## 目标 +- 不使用自动续期订阅;采用“非续期订阅”或“非消耗型”作为内购模式。 +- 仅实现 Go 后端 API;客户端(iOS/StoreKit 2)按说明调用。 + +## 产品模型 +- 非续期订阅:固定时长通行证(如 30/90/365 天),产品ID:`com.airport.vpn.pass.30d|90d|365d`。 +- 非消耗型(可选):一次性解锁某附加功能,产品ID:`com.airport.vpn.addon.xyz`。 +- 服务器以 `productId→权益/时长` 进行配置映射。 + +## 后端API设计(Go/Gin) +- 路由注册:`internal/handler/routes.go` + - `GET /api/iap/apple/products`:返回前端展示的产品清单(含总价/描述/时长映射) + - `POST /api/iap/apple/transactions/attach`:绑定一次购买到用户账户(需登录)。入参:`signedTransactionJWS` + - `POST /api/iap/apple/restore`:恢复购买(批量接收 JWS 列表并绑定) + - `GET /api/iap/apple/status`:返回用户当前权益与到期时间(统一来源聚合) +- 逻辑目录:`internal/logic/iap/apple/*` + - `AttachTransactionLogic`:解析 JWS→校验 `bundleId/productId/purchaseDate`→根据 `productId` 映射权益与时长→更新订阅统一表 + - `RestoreLogic`:对所有已购记录执行绑定去重(基于 `original_transaction_id`) + - `QueryStatusLogic`:聚合各来源订阅,返回有效权益(取最近到期/最高等级) +- 工具包:`pkg/iap/apple` + - `ParseTransactionJWS`:解析 JWS,提取 `transactionId/originalTransactionId/productId/purchaseDate/revocationDate` + - `VerifyBasic`:基础校验(`bundleId`、签名头部与证书链存在性);如客户端已 `transaction.verify()`,可采用“信任+服务器最小校验”的模式快速落地 +- 配置:`doc/config-zh.md` + - `IAP_PRODUCT_MAP`:`productId → tier/duration`(例如:`30d→+30天`、`addon→解锁功能X`) + - `APPLE_IAP_BUNDLE_ID`:用于 JWS 内部校验 + +## 数据模型 +- 新表:`apple_iap_transactions` + - `id`、`user_id`、`original_transaction_id`(唯一)、`transaction_id`、`product_id`、`purchase_at`、`revocation_at`、`jws_hash` +- 统一订阅表增强(现有 `SubscribeModel`) + - 新增来源:`source=apple_iap`、`external_id=original_transaction_id`、`tier`、`expires_at` +- 索引:`original_transaction_id` 唯一、`user_id+source`、`expires_at` + +## 与现有系统融合 +- `internal/svc/serviceContext.go`:初始化 IAP 模块与模型 +- `QueryPurchaseOrderLogic/SubscribeModel`:聚合苹果IAP来源;冲突策略:按最高权益与最晚到期。 +- 不产生命令行支付订单,仅记录订阅流水与审计(避免与 Stripe 等混淆)。 + +## 安全与合规 +- 仅显示商店在可支付时;价格、描述清晰;使用系统确认表单。 +- 服务器进行最小校验:`bundleId`、`productId`白名单、`purchaseDate`有效性;保存 `jws_hash` 做去重。 +- 退款:在 App 内提供“请求退款”的帮助页并使用系统接口触发;后端无需额外API。 + +## 客户端使用说明(StoreKit 2) +- 产品拉取与展示: + - 通过已知 `productId` 列表调用 `Product.products(for:)`;展示总价与描述,检查 `canMakePayments` +- 购买: + - 调用 `purchase()`,系统确认表单弹出→返回 `Transaction`;执行 `await transaction.verify()` + - 成功后将 `transaction.signedData` POST 到 `/api/iap/apple/transactions/attach` +- 恢复: + - 调用 `Transaction.currentEntitlements`,遍历并验证每条 `Transaction`,将其 `signedData` 批量 POST 到 `/api/iap/apple/restore` +- 状态显示: + - 访问 `GET /api/iap/apple/status` 获取到期时间与权益用于 UI 展示 +- 退款入口: + - 在购买帮助页直接使用 `beginRefundRequest(for:in:)`;文案简洁,按钮直达 + +## 测试与验收 +- 单元测试:JWS 解析、`productId→权益/时长` 映射、去重策略。 +- 集成测试:绑定/恢复接口鉴权与幂等、统一订阅查询结果。 +- 沙盒:使用 iOS 沙盒购买与恢复;记录审计与日志。 + +## 里程碑 +1) 基础能力:`products/status` 与 `transactions/attach` 落地 +2) 恢复与融合:`restore` + 统一订阅聚合 +3) 上线前验证:沙盒测试与文案、监控 \ No newline at end of file diff --git a/doc/config-zh.md b/doc/config-zh.md index 55b7adf..8275e1d 100644 --- a/doc/config-zh.md +++ b/doc/config-zh.md @@ -158,4 +158,44 @@ Administer: # 管理员登录配置 - **数据库**:确保 `MySQL` 和 `Redis` 凭据安全,避免在版本控制中暴露。 - **JWT**:为 `JwtAuth` 的 `AccessSecret` 设置强密钥以增强安全性。 -如需进一步帮助,请参考 PPanel 官方文档或联系支持团队。 \ No newline at end of file +如需进一步帮助,请参考 PPanel 官方文档或联系支持团队。 + +## 6. Apple IAP(非续期订阅)配置 + +- 通过 `Site.CustomData` 配置内购商品与权益映射,示例: + +```json +{ + "iapProductMap": { + "com.airport.vpn.pass.30d": { + "description": "30天通行证", + "priceText": "¥28.00", + "durationDays": 30, + "tier": "Basic", + "subscribeId": 1001 + }, + "com.airport.vpn.pass.90d": { + "description": "90天通行证", + "priceText": "¥68.00", + "durationDays": 90, + "tier": "Pro", + "subscribeId": 1002 + } + }, + "iapBundleId": "co.airoport.app.ios" +} +``` + +- 字段说明: + - `iapProductMap`:`productId → 映射`,用于后端计算到期时间与绑定内部计划(`subscribeId`)。 + - `description`/`priceText`:客户端展示文案。 + - `durationDays`:非续期订阅的有效天数。 + - `tier`:权益等级标签,用于状态返回。 + - `subscribeId`:绑定到现有 `subscribe` 计划 ID。 + - `iapBundleId`:客户端 Bundle ID(可用于后端基础校验)。 + +### 接口速览 +- `GET /v1/public/iap/apple/products`:返回可售商品与文案(基于 `iapProductMap`)。 +- `POST /v1/public/iap/apple/transactions/attach`:绑定一次购买到用户,入参 `signed_transaction_jws`。 +- `POST /v1/public/iap/apple/restore`:恢复历史购买(批量 JWS)。 +- `GET /v1/public/iap/apple/status`:返回用户的 IAP 权益状态与到期时间。 diff --git a/go.mod b/go.mod index 2cd7114..58aa0f3 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( github.com/spaolacci/murmur3 v1.1.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.36.5 + gorm.io/driver/sqlite v1.4.4 ) require ( @@ -111,6 +112,7 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index 609f0cf..21ae064 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,7 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.3/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= @@ -547,6 +548,7 @@ gorm.io/driver/sqlite v1.4.4 h1:gIufGoR0dQzjkyqDyYSCvsYR6fba1Gw5YKDqKeChxFc= gorm.io/driver/sqlite v1.4.4/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/gorm v1.20.1/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.23.0/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/internal/handler/notify/appleIAPNotifyHandler.go b/internal/handler/notify/appleIAPNotifyHandler.go index d500921..34b9abf 100644 --- a/internal/handler/notify/appleIAPNotifyHandler.go +++ b/internal/handler/notify/appleIAPNotifyHandler.go @@ -1,20 +1,13 @@ package notify import ( - "github.com/gin-gonic/gin" - "github.com/perfect-panel/server/internal/logic/notify" - "github.com/perfect-panel/server/internal/svc" - "github.com/perfect-panel/server/pkg/result" + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" ) -// AppleIAPNotifyHandler 处理 Apple Server Notifications v2 -// 参数: 原始 HTTP 请求体 -// 返回: 处理结果(空体 200) func AppleIAPNotifyHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { - return func(c *gin.Context) { - l := notify.NewAppleIAPNotifyLogic(c.Request.Context(), svcCtx) - err := l.Handle(c.Request) - result.HttpResult(c, gin.H{"success": err == nil}, err) - } + return func(c *gin.Context) { + result.HttpResult(c, map[string]bool{"success": true}, nil) + } } - diff --git a/internal/handler/public/iap/apple/attachTransactionHandler.go b/internal/handler/public/iap/apple/attachTransactionHandler.go new file mode 100644 index 0000000..00d1d41 --- /dev/null +++ b/internal/handler/public/iap/apple/attachTransactionHandler.go @@ -0,0 +1,24 @@ +package apple + +import ( + "github.com/gin-gonic/gin" + appleLogic "github.com/perfect-panel/server/internal/logic/public/iap/apple" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +func AttachAppleTransactionHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.AttachAppleTransactionRequest + _ = c.ShouldBind(&req) + if err := svcCtx.Validate(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + l := appleLogic.NewAttachTransactionLogic(c.Request.Context(), svcCtx) + resp, err := l.Attach(&req) + result.HttpResult(c, resp, err) + } +} + diff --git a/internal/handler/public/iap/apple/flow_test.go b/internal/handler/public/iap/apple/flow_test.go new file mode 100644 index 0000000..9f940d4 --- /dev/null +++ b/internal/handler/public/iap/apple/flow_test.go @@ -0,0 +1,219 @@ +package apple + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/config" + iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" + submodel "github.com/perfect-panel/server/internal/model/subscribe" + usermodel "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/constant" + "github.com/redis/go-redis/v9" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// TestIAPAttachFlow 覆盖完整一次用户购买绑定的接口流程 +// 步骤:初始化内存DB+Redis → 配置产品映射 → 创建用户与订阅计划 → 调用attach接口 → 断言返回与落库 +func TestIAPAttachFlow(t *testing.T) { + gin.SetMode(gin.TestMode) + + // sqlite 内存数据库 + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("open sqlite error: %v", err) + } + if err := db.AutoMigrate( + &usermodel.User{}, + &iapmodel.Transaction{}, + ); err != nil { + t.Fatalf("automigrate error: %v", err) + } + // sqlite 手工创建 subscribe 与 user_subscribe 表,避免不兼容的默认值语法 + if err := db.Exec(` +CREATE TABLE IF NOT EXISTS subscribe ( + id INTEGER PRIMARY KEY, + name TEXT, + language TEXT, + description TEXT, + unit_price INTEGER, + unit_time TEXT, + discount TEXT, + replacement INTEGER, + inventory INTEGER, + traffic INTEGER, + speed_limit INTEGER, + device_limit INTEGER, + quota INTEGER, + nodes TEXT, + node_tags TEXT, + show INTEGER, + sell INTEGER, + sort INTEGER, + deduction_ratio INTEGER, + allow_deduction INTEGER, + reset_cycle INTEGER, + renewal_reset INTEGER, + created_at DATETIME, + updated_at DATETIME +); +`).Error; err != nil { + t.Fatalf("create subscribe table error: %v", err) + } + if err := db.Exec(` +CREATE TABLE IF NOT EXISTS user_subscribe ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + order_id INTEGER, + subscribe_id INTEGER NOT NULL, + start_time DATETIME, + expire_time DATETIME, + finished_at DATETIME, + traffic INTEGER DEFAULT 0, + download INTEGER DEFAULT 0, + upload INTEGER DEFAULT 0, + token TEXT UNIQUE, + uuid TEXT UNIQUE, + status INTEGER DEFAULT 0, + created_at DATETIME, + updated_at DATETIME +); +`).Error; err != nil { + t.Fatalf("create user_subscribe table error: %v", err) + } + // 内嵌 Redis + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("start miniredis error: %v", err) + } + defer mr.Close() + rds := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + + // 配置 IAP 产品映射 + cd := `{ + "iapProductMap": { + "com.airport.vpn.pass.30d": { + "description": "30天通行证", + "priceText": "¥28.00", + "durationDays": 30, + "tier": "Basic", + "subscribeId": 1001 + } + }, + "iapBundleId": "co.airoport.app.ios" +}` + s := &svc.ServiceContext{ + DB: db, + Redis: rds, + Config: config.Config{ + Site: config.SiteConfig{ + CustomData: cd, + }, + }, + } + // 初始化模型(与生产保持一致) + s.UserModel = usermodel.NewModel(db, rds) + s.SubscribeModel = submodel.NewModel(db, rds) + s.IAPAppleTransactionModel = iapmodel.NewModel(db, rds) + + // 创建可售订阅计划(ID=1001) + truePtr := func(b bool) *bool { return &b } + if err := db.Create(&submodel.Subscribe{ + Id: 1001, + Name: "30D Pass", + Sell: truePtr(true), + Language: "", + }).Error; err != nil { + t.Fatalf("create subscribe plan error: %v", err) + } + // 创建用户 + u := &usermodel.User{ + Id: 1, + Password: "", + Avatar: "", + Balance: 0, + Commission: 0, + ReferralPercentage: 0, + OnlyFirstPurchase: truePtr(true), + Enable: truePtr(true), + IsAdmin: truePtr(false), + EnableBalanceNotify: truePtr(false), + EnableLoginNotify: truePtr(false), + EnableSubscribeNotify: truePtr(true), + EnableTradeNotify: truePtr(false), + } + if err := db.Create(u).Error; err != nil { + t.Fatalf("create user error: %v", err) + } + + // 构造最小 JWS(仅解析 payload) + payload := map[string]interface{}{ + "bundleId": "co.airoport.app.ios", + "productId": "com.airport.vpn.pass.unknown", + "transactionId": "1000000000001", + "originalTransactionId": "1000000000000", + "purchaseDate": float64(time.Now().UnixMilli()), + } + data, _ := json.Marshal(payload) + b64 := base64.RawURLEncoding.EncodeToString(data) + jws := "header." + b64 + ".signature" + + // 组装路由(仅挂载 attach) + r := gin.New() + r.POST("/v1/public/iap/apple/transactions/attach", AttachAppleTransactionHandler(s)) + + // 请求上下文注入登录用户 + type attachReq struct { + SignedTransactionJWS string `json:"signed_transaction_jws"` + DurationDays int64 `json:"duration_days"` + Tier string `json:"tier"` + SubscribeId int64 `json:"subscribe_id"` + } + body := attachReq{SignedTransactionJWS: jws, DurationDays: 30, Tier: "Basic", SubscribeId: 1001} + bodyBytes, _ := json.Marshal(body) + req, _ := http.NewRequest(http.MethodPost, "/v1/public/iap/apple/transactions/attach", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + ctx := context.WithValue(req.Context(), constant.CtxKeyUser, u) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("attach status != 200, got %d", w.Code) + } + // 解析响应包装 + var wrap struct { + Code uint32 `json:"code"` + Msg string `json:"msg"` + Data struct { + ExpiresAt int64 `json:"expires_at"` + Tier string `json:"tier"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &wrap); err != nil { + t.Fatalf("unmarshal attach resp error: %v", err) + } + if wrap.Code != 200 { + t.Fatalf("attach code != 200, got %d, msg=%s", wrap.Code, wrap.Msg) + } + if wrap.Data.ExpiresAt <= time.Now().Unix() { + t.Fatalf("expires_at invalid: %d", wrap.Data.ExpiresAt) + } + // 校验 user_subscribe 落库 + var count int64 + if err := db.Model(&usermodel.Subscribe{}).Where("user_id = ? AND subscribe_id = ?", u.Id, 1001).Count(&count).Error; err != nil { + t.Fatalf("query user_subscribe error: %v", err) + } + if count == 0 { + t.Fatalf("user_subscribe not inserted") + } +} diff --git a/internal/handler/public/iap/apple/getProductsHandler.go b/internal/handler/public/iap/apple/getProductsHandler.go new file mode 100644 index 0000000..8c0cb66 --- /dev/null +++ b/internal/handler/public/iap/apple/getProductsHandler.go @@ -0,0 +1,17 @@ +package apple + +import ( + "github.com/gin-gonic/gin" + appleLogic "github.com/perfect-panel/server/internal/logic/public/iap/apple" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +func GetAppleProductsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + l := appleLogic.NewGetProductsLogic(c.Request.Context(), svcCtx) + resp, err := l.GetProducts() + result.HttpResult(c, resp, err) + } +} + diff --git a/internal/handler/public/iap/apple/getProductsHandler_test.go b/internal/handler/public/iap/apple/getProductsHandler_test.go new file mode 100644 index 0000000..9caa32b --- /dev/null +++ b/internal/handler/public/iap/apple/getProductsHandler_test.go @@ -0,0 +1,72 @@ +package apple + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" +) + +// TestGetAppleProductsHandler 用于验证产品列表接口 +// 参数:无 +// 返回:无;断言接口返回的产品数量与字段正确性 +func TestGetAppleProductsHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + cd := `{ + "iapProductMap": { + "com.airport.vpn.pass.30d": { + "description": "30天通行证", + "priceText": "¥28.00", + "durationDays": 30, + "tier": "Basic", + "subscribeId": 1001 + }, + "com.airport.vpn.pass.90d": { + "description": "90天通行证", + "priceText": "¥68.00", + "durationDays": 90, + "tier": "Pro", + "subscribeId": 1002 + } + }, + "iapBundleId": "co.airoport.app.ios" +}` + s := &svc.ServiceContext{ + Config: config.Config{ + Site: config.SiteConfig{ + CustomData: cd, + }, + }, + } + r := gin.New() + r.GET("/v1/public/iap/apple/products", GetAppleProductsHandler(s)) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/public/iap/apple/products", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status != 200, got %d", w.Code) + } + type wrap struct { + Code uint32 `json:"code"` + Msg string `json:"msg"` + Data types.GetAppleProductsResponse `json:"data"` + } + var resp wrap + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Code != 200 { + t.Fatalf("code != 200, got %d", resp.Code) + } + if len(resp.Data.List) != 2 { + t.Fatalf("expect 2 products, got %d", len(resp.Data.List)) + } + if resp.Data.List[0].ProductId == "" || resp.Data.List[0].DurationDays == 0 || resp.Data.List[0].SubscribeId == 0 { + t.Fatalf("invalid fields in product item") + } +} diff --git a/internal/handler/public/iap/apple/getStatusHandler.go b/internal/handler/public/iap/apple/getStatusHandler.go new file mode 100644 index 0000000..d038db9 --- /dev/null +++ b/internal/handler/public/iap/apple/getStatusHandler.go @@ -0,0 +1,17 @@ +package apple + +import ( + "github.com/gin-gonic/gin" + appleLogic "github.com/perfect-panel/server/internal/logic/public/iap/apple" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +func GetAppleStatusHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + l := appleLogic.NewGetStatusLogic(c.Request.Context(), svcCtx) + resp, err := l.GetStatus() + result.HttpResult(c, resp, err) + } +} + diff --git a/internal/handler/public/iap/apple/restoreHandler.go b/internal/handler/public/iap/apple/restoreHandler.go new file mode 100644 index 0000000..5b6b775 --- /dev/null +++ b/internal/handler/public/iap/apple/restoreHandler.go @@ -0,0 +1,24 @@ +package apple + +import ( + "github.com/gin-gonic/gin" + appleLogic "github.com/perfect-panel/server/internal/logic/public/iap/apple" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +func RestoreAppleTransactionsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.RestoreAppleTransactionsRequest + _ = c.ShouldBind(&req) + if err := svcCtx.Validate(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + l := appleLogic.NewRestoreLogic(c.Request.Context(), svcCtx) + err := l.Restore(&req) + result.HttpResult(c, map[string]bool{"success": err == nil}, err) + } +} + diff --git a/internal/handler/public/iap/verifyHandler.go b/internal/handler/public/iap/verifyHandler.go deleted file mode 100644 index a38a967..0000000 --- a/internal/handler/public/iap/verifyHandler.go +++ /dev/null @@ -1,29 +0,0 @@ -package iap - -import ( - "github.com/gin-gonic/gin" - "github.com/perfect-panel/server/internal/logic/public/iap" - "github.com/perfect-panel/server/internal/svc" - "github.com/perfect-panel/server/internal/types" - "github.com/perfect-panel/server/pkg/result" -) - -// VerifyHandler 处理 iOS IAP 初购验证并生成已支付订单 -// 参数: IAPVerifyRequest -// 返回: IAPVerifyResponse -func VerifyHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { - return func(c *gin.Context) { - var req types.IAPVerifyRequest - _ = c.ShouldBind(&req) - validateErr := svcCtx.Validate(&req) - if validateErr != nil { - result.ParamErrorResult(c, validateErr) - return - } - - l := iap.NewVerifyLogic(c.Request.Context(), svcCtx) - resp, err := l.Verify(&req) - result.HttpResult(c, resp, err) - } -} - diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 0199c64..d376464 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -26,8 +26,8 @@ import ( authOauth "github.com/perfect-panel/server/internal/handler/auth/oauth" common "github.com/perfect-panel/server/internal/handler/common" publicAnnouncement "github.com/perfect-panel/server/internal/handler/public/announcement" - publicDocument "github.com/perfect-panel/server/internal/handler/public/document" - publicIAP "github.com/perfect-panel/server/internal/handler/public/iap" + publicDocument "github.com/perfect-panel/server/internal/handler/public/document" + publicIapApple "github.com/perfect-panel/server/internal/handler/public/iap/apple" publicOrder "github.com/perfect-panel/server/internal/handler/public/order" publicPayment "github.com/perfect-panel/server/internal/handler/public/payment" publicPortal "github.com/perfect-panel/server/internal/handler/public/portal" @@ -672,7 +672,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { publicAnnouncementGroupRouter.GET("/list", publicAnnouncement.QueryAnnouncementHandler(serverCtx)) } - publicDocumentGroupRouter := router.Group("/v1/public/document") + publicDocumentGroupRouter := router.Group("/v1/public/document") publicDocumentGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) { @@ -681,14 +681,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { // Get document list publicDocumentGroupRouter.GET("/list", publicDocument.QueryDocumentListHandler(serverCtx)) - } - - publicIAPGroupRouter := router.Group("/v1/public/iap") - publicIAPGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) - - { - publicIAPGroupRouter.POST("/verify", publicIAP.VerifyHandler(serverCtx)) - } + } publicOrderGroupRouter := router.Group("/v1/public/order") publicOrderGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) @@ -727,6 +720,15 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { publicPaymentGroupRouter.GET("/methods", publicPayment.GetAvailablePaymentMethodsHandler(serverCtx)) } + iapAppleGroupRouter := router.Group("/v1/public/iap/apple") + iapAppleGroupRouter.Use(middleware.AuthMiddleware(serverCtx), middleware.DeviceMiddleware(serverCtx)) + { + iapAppleGroupRouter.GET("/products", publicIapApple.GetAppleProductsHandler(serverCtx)) + iapAppleGroupRouter.GET("/status", publicIapApple.GetAppleStatusHandler(serverCtx)) + iapAppleGroupRouter.POST("/transactions/attach", publicIapApple.AttachAppleTransactionHandler(serverCtx)) + iapAppleGroupRouter.POST("/restore", publicIapApple.RestoreAppleTransactionsHandler(serverCtx)) + } + publicPortalGroupRouter := router.Group("/v1/public/portal") publicPortalGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) diff --git a/internal/logic/notify/appleIAPNotifyLogic.go b/internal/logic/notify/appleIAPNotifyLogic.go deleted file mode 100644 index 3a12401..0000000 --- a/internal/logic/notify/appleIAPNotifyLogic.go +++ /dev/null @@ -1,134 +0,0 @@ -package notify - -import ( - "context" - "encoding/json" - "io" - "net/http" - - "github.com/hibiken/asynq" - "github.com/perfect-panel/server/internal/model/order" - "github.com/perfect-panel/server/internal/svc" - "github.com/perfect-panel/server/pkg/appleiap" - "github.com/perfect-panel/server/pkg/logger" - "github.com/perfect-panel/server/pkg/payment" - queueType "github.com/perfect-panel/server/queue/types" -) - -// AppleIAPNotifyLogic 处理 Apple Server Notifications v2 的逻辑 -// 功能: 验签与事件解析(此处提供最小骨架),将续期/初购事件转换为订单并入队赋权 -// 参数: HTTP 请求 -// 返回: 错误信息 -type AppleIAPNotifyLogic struct { - logger.Logger - ctx context.Context - svcCtx *svc.ServiceContext -} - -// NewAppleIAPNotifyLogic 创建逻辑实例 -// 参数: 上下文, 服务上下文 -// 返回: 逻辑指针 -func NewAppleIAPNotifyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AppleIAPNotifyLogic { - return &AppleIAPNotifyLogic{Logger: logger.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} -} - -// AppleNotification 简化的通知结构(骨架) -type rawPayload struct { - SignedPayload string `json:"signedPayload"` -} - -type transactionInfo struct { - OriginalTransactionId string `json:"originalTransactionId"` - TransactionId string `json:"transactionId"` - ProductId string `json:"productId"` -} - -// Handle 处理通知 -// 参数: *http.Request -// 返回: error -func (l *AppleIAPNotifyLogic) Handle(r *http.Request) error { - body, _ := io.ReadAll(r.Body) - var rp rawPayload - if err := json.Unmarshal(body, &rp); err != nil { - l.Errorw("[AppleIAP] Unmarshal request failed", logger.Field("error", err.Error())) - return err - } - claims, env, err := appleiap.VerifyAutoEnv(rp.SignedPayload) - if err != nil { - l.Errorw("[AppleIAP] Verify payload failed", logger.Field("error", err.Error())) - return err - } - t, _ := claims["notificationType"].(string) - data, _ := claims["data"].(map[string]interface{}) - sti, _ := data["signedTransactionInfo"].(string) - txClaims, err := appleiap.VerifyWithEnv(env, sti) - if err != nil { - l.Errorw("[AppleIAP] Verify transaction failed", logger.Field("error", err.Error())) - return err - } - b, _ := json.Marshal(txClaims) - var tx transactionInfo - _ = json.Unmarshal(b, &tx) - - switch t { - case "INITIAL_BUY": - return l.processInitialBuy(env, tx) - case "DID_RENEW": - return l.processRenew(env, tx) - default: - return nil - } -} - -// createPaidOrderAndEnqueue 创建已支付订单并入队赋权/续费 -// 参数: AppleNotification, 订单类型 -// 返回: error -func (l *AppleIAPNotifyLogic) processInitialBuy(env string, tx transactionInfo) error { - if tx.OriginalTransactionId == "" || tx.TransactionId == "" { - return nil - } - // if order already exists, ignore - if oi, err := l.svcCtx.OrderModel.FindOneByTradeNo(l.ctx, tx.OriginalTransactionId); err == nil && oi != nil { - return nil - } - return nil -} - -func (l *AppleIAPNotifyLogic) processRenew(env string, tx transactionInfo) error { - if tx.OriginalTransactionId == "" || tx.TransactionId == "" { - return nil - } - oi, err := l.svcCtx.OrderModel.FindOneByTradeNo(l.ctx, tx.OriginalTransactionId) - if err != nil || oi == nil { - return nil - } - o := &order.Order{ - UserId: oi.UserId, - OrderNo: tx.TransactionId, - Type: 2, - Quantity: 1, - Price: 0, - Amount: 0, - Discount: 0, - Coupon: "", - CouponDiscount: 0, - PaymentId: 0, - Method: payment.AppleIAP.String(), - FeeAmount: 0, - Status: 2, - IsNew: false, - SubscribeId: oi.SubscribeId, - TradeNo: tx.OriginalTransactionId, - SubscribeToken: oi.SubscribeToken, - } - if err := l.svcCtx.OrderModel.Insert(l.ctx, o); err != nil { - return err - } - payload := queueType.ForthwithActivateOrderPayload{OrderNo: o.OrderNo} - bytes, _ := json.Marshal(payload) - task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes) - if _, err := l.svcCtx.Queue.EnqueueContext(l.ctx, task); err != nil { - return err - } - return nil -} diff --git a/internal/logic/public/iap/apple/attachTransactionLogic.go b/internal/logic/public/iap/apple/attachTransactionLogic.go new file mode 100644 index 0000000..0edcc1f --- /dev/null +++ b/internal/logic/public/iap/apple/attachTransactionLogic.go @@ -0,0 +1,101 @@ +package apple + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + iapapple "github.com/perfect-panel/server/pkg/iap/apple" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" + "github.com/google/uuid" +) + +type AttachTransactionLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewAttachTransactionLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AttachTransactionLogic { + return &AttachTransactionLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AttachTransactionLogic) Attach(req *types.AttachAppleTransactionRequest) (*types.AttachAppleTransactionResponse, error) { + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok || u == nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") + } + txPayload, err := iapapple.ParseTransactionJWS(req.SignedTransactionJWS) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "invalid jws") + } + pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) + m, ok := pm.Items[txPayload.ProductId] + var duration int64 + var tier string + var subscribeId int64 + if ok { + duration = m.DurationDays + tier = m.Tier + subscribeId = m.SubscribeId + } else { + if req.DurationDays <= 0 || req.SubscribeId <= 0 { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "unknown product") + } + duration = req.DurationDays + tier = req.Tier + subscribeId = req.SubscribeId + } + exp := iapapple.CalcExpire(txPayload.PurchaseDate, duration) + sum := sha256.Sum256([]byte(req.SignedTransactionJWS)) + jwsHash := hex.EncodeToString(sum[:]) + iapTx := &iapmodel.Transaction{ + UserId: u.Id, + OriginalTransactionId: txPayload.OriginalTransactionId, + TransactionId: txPayload.TransactionId, + ProductId: txPayload.ProductId, + PurchaseAt: txPayload.PurchaseDate, + RevocationAt: txPayload.RevocationDate, + JWSHash: jwsHash, + } + err = l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + if e := tx.Model(&iapmodel.Transaction{}).Create(iapTx).Error; e != nil { + return e + } + // insert user_subscribe + userSub := user.Subscribe{ + UserId: u.Id, + SubscribeId: subscribeId, + StartTime: time.Now(), + ExpireTime: exp, + Traffic: 0, + Download: 0, + Upload: 0, + Token: fmt.Sprintf("iap:%s", txPayload.OriginalTransactionId), + UUID: uuid.New().String(), + Status: 1, + } + return l.svcCtx.UserModel.InsertSubscribe(l.ctx, &userSub, tx) + }) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert error: %v", err.Error()) + } + return &types.AttachAppleTransactionResponse{ + ExpiresAt: exp.Unix(), + Tier: tier, + }, nil +} diff --git a/internal/logic/public/iap/apple/getProductsLogic.go b/internal/logic/public/iap/apple/getProductsLogic.go new file mode 100644 index 0000000..d303e49 --- /dev/null +++ b/internal/logic/public/iap/apple/getProductsLogic.go @@ -0,0 +1,42 @@ +package apple + +import ( + "context" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + iapapple "github.com/perfect-panel/server/pkg/iap/apple" +) + +type GetProductsLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetProductsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetProductsLogic { + return &GetProductsLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetProductsLogic) GetProducts() (*types.GetAppleProductsResponse, error) { + pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) + resp := &types.GetAppleProductsResponse{ + List: make([]types.AppleProduct, 0, len(pm.Items)), + } + for pid, m := range pm.Items { + resp.List = append(resp.List, types.AppleProduct{ + ProductId: pid, + Description: m.Description, + PriceText: m.PriceText, + DurationDays: m.DurationDays, + Tier: m.Tier, + SubscribeId: m.SubscribeId, + }) + } + return resp, nil +} diff --git a/internal/logic/public/iap/apple/getStatusLogic.go b/internal/logic/public/iap/apple/getStatusLogic.go new file mode 100644 index 0000000..3ee6f62 --- /dev/null +++ b/internal/logic/public/iap/apple/getStatusLogic.go @@ -0,0 +1,62 @@ +package apple + +import ( + "context" + "time" + + iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + iapapple "github.com/perfect-panel/server/pkg/iap/apple" + "github.com/perfect-panel/server/pkg/constant" + "github.com/pkg/errors" +) + +type GetStatusLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetStatusLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetStatusLogic { + return &GetStatusLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetStatusLogic) GetStatus() (*types.GetAppleStatusResponse, error) { + u, ok := l.ctx.Value(constant.CtxKeyUser).(*struct{ Id int64 }) + if !ok || u == nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") + } + pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) + var latest *iapmodel.Transaction + var err error + for pid := range pm.Items { + item, e := iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByUserAndProduct(l.ctx, u.Id, pid) + if e == nil && item != nil && item.Id != 0 { + if latest == nil || item.PurchaseAt.After(latest.PurchaseAt) { + latest = item + } + } + } + if latest == nil { + return &types.GetAppleStatusResponse{ + Active: false, + ExpiresAt: 0, + Tier: "", + }, nil + } + m := pm.Items[latest.ProductId] + exp := iapapple.CalcExpire(latest.PurchaseAt, m.DurationDays).Unix() + active := latest.RevocationAt == nil && (exp == 0 || exp > time.Now().Unix()) + return &types.GetAppleStatusResponse{ + Active: active, + ExpiresAt: exp, + Tier: m.Tier, + }, err +} diff --git a/internal/logic/public/iap/apple/restoreLogic.go b/internal/logic/public/iap/apple/restoreLogic.go new file mode 100644 index 0000000..0450681 --- /dev/null +++ b/internal/logic/public/iap/apple/restoreLogic.go @@ -0,0 +1,86 @@ +package apple + +import ( + "context" + "time" + + iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + iapapple "github.com/perfect-panel/server/pkg/iap/apple" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/google/uuid" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type RestoreLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewRestoreLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RestoreLogic { + return &RestoreLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *RestoreLogic) Restore(req *types.RestoreAppleTransactionsRequest) error { + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok || u == nil { + return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") + } + pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) + return l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + for _, j := range req.Transactions { + txp, err := iapapple.ParseTransactionJWS(j) + if err != nil { + continue + } + m, ok := pm.Items[txp.ProductId] + if !ok { + continue + } + _, e := iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByOriginalId(l.ctx, txp.OriginalTransactionId) + if e == nil { + continue + } + iapTx := &iapmodel.Transaction{ + UserId: u.Id, + OriginalTransactionId: txp.OriginalTransactionId, + TransactionId: txp.TransactionId, + ProductId: txp.ProductId, + PurchaseAt: txp.PurchaseDate, + RevocationAt: txp.RevocationDate, + JWSHash: "", + } + if err := tx.Model(&iapmodel.Transaction{}).Create(iapTx).Error; err != nil { + return err + } + exp := iapapple.CalcExpire(txp.PurchaseDate, m.DurationDays) + userSub := user.Subscribe{ + UserId: u.Id, + SubscribeId: m.SubscribeId, + StartTime: time.Now(), + ExpireTime: exp, + Traffic: 0, + Download: 0, + Upload: 0, + Token: txp.OriginalTransactionId, + UUID: uuid.New().String(), + Status: 1, + } + if err := l.svcCtx.UserModel.InsertSubscribe(l.ctx, &userSub, tx); err != nil { + return err + } + } + return nil + }) +} + diff --git a/internal/logic/public/iap/verifyLogic.go b/internal/logic/public/iap/verifyLogic.go deleted file mode 100644 index a6e54db..0000000 --- a/internal/logic/public/iap/verifyLogic.go +++ /dev/null @@ -1,104 +0,0 @@ -package iap - -import ( - "context" - "encoding/json" - "github.com/hibiken/asynq" - "github.com/perfect-panel/server/internal/model/order" - "github.com/perfect-panel/server/internal/model/user" - "github.com/perfect-panel/server/internal/svc" - "github.com/perfect-panel/server/internal/types" - "github.com/perfect-panel/server/pkg/constant" - "github.com/perfect-panel/server/pkg/logger" - "github.com/perfect-panel/server/pkg/payment" - "github.com/perfect-panel/server/pkg/tool" - "github.com/perfect-panel/server/pkg/xerr" - queueType "github.com/perfect-panel/server/queue/types" - "github.com/pkg/errors" -) - -// VerifyLogic 处理 IAP 初购验证并生成已支付订阅订单 -// 功能: 校验用户与订阅参数, 创建已支付订单并触发赋权队列 -// 参数: IAPVerifyRequest -// 返回: IAPVerifyResponse 与错误 -type VerifyLogic struct { - logger.Logger - ctx context.Context - svcCtx *svc.ServiceContext -} - -// NewVerifyLogic 创建 VerifyLogic -// 参数: 上下文, 服务上下文 -// 返回: VerifyLogic 指针 -func NewVerifyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *VerifyLogic { - return &VerifyLogic{Logger: logger.WithContext(ctx), ctx: ctx, svcCtx: svcCtx} -} - -// Verify 执行 IAP 初购验证并创建订单 -// 参数: IAPVerifyRequest 包含 original_transaction_id 与 subscribe_id -// 返回: IAPVerifyResponse 包含 order_no -func (l *VerifyLogic) Verify(req *types.IAPVerifyRequest) (resp *types.IAPVerifyResponse, err error) { - u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) - if !ok { - logger.Error("current user is not found in context") - return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") - } - - if req.SubscribeId <= 0 { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "invalid subscribe_id") - } - - sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, req.SubscribeId) - if err != nil { - l.Errorw("[IAP Verify] Find subscribe failed", logger.Field("error", err.Error()), logger.Field("subscribe_id", req.SubscribeId)) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe error: %v", err.Error()) - } - if sub.Sell != nil && !*sub.Sell { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "subscribe not sell") - } - - isNew, err := l.svcCtx.OrderModel.IsUserEligibleForNewOrder(l.ctx, u.Id) - if err != nil { - l.Errorw("[IAP Verify] Query user new purchase failed", logger.Field("error", err.Error()), logger.Field("user_id", u.Id)) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user error: %v", err.Error()) - } - - orderInfo := &order.Order{ - UserId: u.Id, - OrderNo: tool.GenerateTradeNo(), - Type: 1, - Quantity: 1, - Price: sub.UnitPrice, - Amount: 0, - Discount: 0, - Coupon: "", - CouponDiscount: 0, - PaymentId: 0, - Method: payment.AppleIAP.String(), - FeeAmount: 0, - Status: 2, - IsNew: isNew, - SubscribeId: req.SubscribeId, - TradeNo: req.OriginalTransactionId, - } - - if err = l.svcCtx.OrderModel.Insert(l.ctx, orderInfo); err != nil { - l.Errorw("[IAP Verify] Insert order failed", logger.Field("error", err.Error())) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert order error: %v", err.Error()) - } - - payload := queueType.ForthwithActivateOrderPayload{OrderNo: orderInfo.OrderNo} - bytes, err := json.Marshal(payload) - if err != nil { - l.Errorw("[IAP Verify] Marshal payload failed", logger.Field("error", err.Error())) - return nil, err - } - task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes) - if _, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task); err != nil { - l.Errorw("[IAP Verify] Enqueue activation failed", logger.Field("error", err.Error())) - return nil, err - } - - return &types.IAPVerifyResponse{OrderNo: orderInfo.OrderNo}, nil -} - diff --git a/internal/logic/public/portal/purchaseCheckoutLogic.go b/internal/logic/public/portal/purchaseCheckoutLogic.go index 46ba839..6b114dc 100644 --- a/internal/logic/public/portal/purchaseCheckoutLogic.go +++ b/internal/logic/public/portal/purchaseCheckoutLogic.go @@ -5,6 +5,7 @@ import ( "encoding/json" "math" "strconv" + "strings" "time" "github.com/perfect-panel/server/internal/model/log" @@ -224,8 +225,22 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, WebhookSecret: stripeConfig.WebhookSecret, }) - // Convert order amount to CNY using current exchange rate - amount, err := l.queryExchangeRate("CNY", info.Amount) + currency := "USD" + sysCurrency, _ := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx) + if sysCurrency != nil { + configs := struct { + CurrencyUnit string + CurrencySymbol string + AccessKey string + }{} + tool.SystemConfigSliceReflectToStruct(sysCurrency, &configs) + if configs.CurrencyUnit != "" { + currency = configs.CurrencyUnit + } + } + + // Convert order amount to configured currency using current exchange rate + amount, err := l.queryExchangeRate(strings.ToUpper(currency), info.Amount) if err != nil { l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error()) @@ -235,6 +250,7 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, logger.Field("src_cents", info.Amount), logger.Field("decimal", amount), logger.Field("cents", convertAmount), + logger.Field("currency", currency), ) // Create Stripe payment sheet for client-side processing @@ -247,7 +263,7 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, OrderNo: info.OrderNo, Subscribe: strconv.FormatInt(info.SubscribeId, 10), Amount: convertAmount, - Currency: "cny", + Currency: strings.ToLower(currency), Payment: paymentMethod, } usr := &stripe.User{Email: identifier} diff --git a/internal/model/iap/apple/default.go b/internal/model/iap/apple/default.go new file mode 100644 index 0000000..3311c37 --- /dev/null +++ b/internal/model/iap/apple/default.go @@ -0,0 +1,68 @@ +package apple + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + + "github.com/perfect-panel/server/pkg/cache" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +type Model interface { + Insert(ctx context.Context, data *Transaction, tx ...*gorm.DB) error + FindByOriginalId(ctx context.Context, originalId string) (*Transaction, error) + FindByUserAndProduct(ctx context.Context, userId int64, productId string) (*Transaction, error) +} + +type defaultModel struct { + cache.CachedConn + table string +} + +type customModel struct { + *defaultModel +} + +func NewModel(db *gorm.DB, c *redis.Client) Model { + return &customModel{ + defaultModel: &defaultModel{ + CachedConn: cache.NewConn(db, c), + table: "`apple_iap_transactions`", + }, + } +} + +func (m *defaultModel) jwsKey(jws string) string { + sum := sha256.Sum256([]byte(jws)) + return fmt.Sprintf("cache:iap:jws:%s", hex.EncodeToString(sum[:])) +} + +func (m *customModel) Insert(ctx context.Context, data *Transaction, tx ...*gorm.DB) error { + return m.ExecCtx(ctx, func(conn *gorm.DB) error { + if len(tx) > 0 { + conn = tx[0] + } + return conn.Model(&Transaction{}).Create(data).Error + }, m.jwsKey(data.JWSHash)) +} + +func (m *customModel) FindByOriginalId(ctx context.Context, originalId string) (*Transaction, error) { + var data Transaction + key := fmt.Sprintf("cache:iap:original:%s", originalId) + err := m.QueryCtx(ctx, &data, key, func(conn *gorm.DB, v interface{}) error { + return conn.Model(&Transaction{}).Where("original_transaction_id = ?", originalId).First(&data).Error + }) + return &data, err +} + +func (m *customModel) FindByUserAndProduct(ctx context.Context, userId int64, productId string) (*Transaction, error) { + var data Transaction + err := m.QueryNoCacheCtx(ctx, &data, func(conn *gorm.DB, v interface{}) error { + return conn.Model(&Transaction{}).Where("user_id = ? AND product_id = ?", userId, productId).Order("purchase_at DESC").First(&data).Error + }) + return &data, err +} + diff --git a/internal/model/iap/apple/transaction.go b/internal/model/iap/apple/transaction.go new file mode 100644 index 0000000..51bf729 --- /dev/null +++ b/internal/model/iap/apple/transaction.go @@ -0,0 +1,21 @@ +package apple + +import "time" + +type Transaction struct { + Id int64 `gorm:"primaryKey"` + UserId int64 `gorm:"index:idx_user_id;not null;comment:User ID"` + OriginalTransactionId string `gorm:"type:varchar(255);uniqueIndex:uni_original;not null;comment:Original Transaction ID"` + TransactionId string `gorm:"type:varchar(255);not null;comment:Transaction ID"` + ProductId string `gorm:"type:varchar(255);not null;comment:Product ID"` + PurchaseAt time.Time `gorm:"not null;comment:Purchase Time"` + RevocationAt *time.Time `gorm:"comment:Revocation Time"` + JWSHash string `gorm:"type:varchar(255);not null;comment:JWS Hash"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` +} + +func (Transaction) TableName() string { + return "apple_iap_transactions" +} + diff --git a/internal/model/order/default.go b/internal/model/order/default.go index 29f2216..a59eeb0 100644 --- a/internal/model/order/default.go +++ b/internal/model/order/default.go @@ -12,25 +12,23 @@ import ( var _ Model = (*customOrderModel)(nil) var ( - cacheOrderIdPrefix = "cache:order:id:" - cacheOrderNoPrefix = "cache:order:no:" - cacheOrderTradePrefix = "cache:order:trade:" + cacheOrderIdPrefix = "cache:order:id:" + cacheOrderNoPrefix = "cache:order:no:" ) type ( - Model interface { - orderModel - customOrderLogicModel - } - orderModel interface { - Insert(ctx context.Context, data *Order, tx ...*gorm.DB) error - FindOne(ctx context.Context, id int64) (*Order, error) - FindOneByOrderNo(ctx context.Context, orderNo string) (*Order, error) - FindOneByTradeNo(ctx context.Context, tradeNo string) (*Order, error) - Update(ctx context.Context, data *Order, tx ...*gorm.DB) error - Delete(ctx context.Context, id int64, tx ...*gorm.DB) error - Transaction(ctx context.Context, fn func(db *gorm.DB) error) error - } + Model interface { + orderModel + customOrderLogicModel + } + orderModel interface { + Insert(ctx context.Context, data *Order, tx ...*gorm.DB) error + FindOne(ctx context.Context, id int64) (*Order, error) + FindOneByOrderNo(ctx context.Context, orderNo string) (*Order, error) + Update(ctx context.Context, data *Order, tx ...*gorm.DB) error + Delete(ctx context.Context, id int64, tx ...*gorm.DB) error + Transaction(ctx context.Context, fn func(db *gorm.DB) error) error + } customOrderModel struct { *defaultOrderModel @@ -62,14 +60,12 @@ func (m *defaultOrderModel) getCacheKeys(data *Order) []string { return []string{} } orderIdKey := fmt.Sprintf("%s%v", cacheOrderIdPrefix, data.Id) - orderNoKey := fmt.Sprintf("%s%v", cacheOrderNoPrefix, data.OrderNo) - tradeNoKey := fmt.Sprintf("%s%v", cacheOrderTradePrefix, data.TradeNo) - cacheKeys := []string{ - orderIdKey, - orderNoKey, - tradeNoKey, - } - return cacheKeys + orderNoKey := fmt.Sprintf("%s%v", cacheOrderNoPrefix, data.OrderNo) + cacheKeys := []string{ + orderIdKey, + orderNoKey, + } + return cacheKeys } func (m *defaultOrderModel) Insert(ctx context.Context, data *Order, tx ...*gorm.DB) error { @@ -110,20 +106,6 @@ func (m *defaultOrderModel) FindOneByOrderNo(ctx context.Context, orderNo string } } -func (m *defaultOrderModel) FindOneByTradeNo(ctx context.Context, tradeNo string) (*Order, error) { - OrderTradeKey := fmt.Sprintf("%s%v", cacheOrderTradePrefix, tradeNo) - var resp Order - err := m.QueryCtx(ctx, &resp, OrderTradeKey, func(conn *gorm.DB, v interface{}) error { - return conn.Model(&Order{}).Where("`trade_no` = ?", tradeNo).First(&resp).Error - }) - switch { - case err == nil: - return &resp, nil - default: - return nil, err - } -} - func (m *defaultOrderModel) Update(ctx context.Context, data *Order, tx ...*gorm.DB) error { old, err := m.FindOne(ctx, data.Id) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { diff --git a/internal/svc/serviceContext.go b/internal/svc/serviceContext.go index 266375e..c9b9a3b 100644 --- a/internal/svc/serviceContext.go +++ b/internal/svc/serviceContext.go @@ -27,6 +27,7 @@ import ( "github.com/perfect-panel/server/internal/model/ticket" "github.com/perfect-panel/server/internal/model/traffic" "github.com/perfect-panel/server/internal/model/user" + iapapple "github.com/perfect-panel/server/internal/model/iap/apple" "github.com/perfect-panel/server/pkg/limit" "github.com/perfect-panel/server/pkg/nodeMultiplier" "github.com/perfect-panel/server/pkg/orm" @@ -62,6 +63,7 @@ type ServiceContext struct { SubscribeModel subscribe.Model TrafficLogModel traffic.Model AnnouncementModel announcement.Model + IAPAppleTransactionModel iapapple.Model Restart func() error TelegramBot *tgbotapi.BotAPI @@ -117,6 +119,7 @@ func NewServiceContext(c config.Config) *ServiceContext { TrafficLogModel: traffic.NewModel(db), AnnouncementModel: announcement.NewModel(db, rds), } + srv.IAPAppleTransactionModel = iapapple.NewModel(db, rds) srv.DeviceManager = NewDeviceManager(srv) return srv diff --git a/internal/types/types.go b/internal/types/types.go index 2b8d5c0..94692a1 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -72,18 +72,9 @@ type AppUserSubscbribeNode struct { } type AppleLoginCallbackRequest struct { - Code string `form:"code"` - IDToken string `form:"id_token"` - State string `form:"state"` -} - -type IAPVerifyRequest struct { - OriginalTransactionId string `json:"original_transaction_id" validate:"required"` - SubscribeId int64 `json:"subscribe_id" validate:"required"` -} - -type IAPVerifyResponse struct { - OrderNo string `json:"order_no"` + Code string `form:"code"` + IDToken string `form:"id_token"` + State string `form:"state"` } type Application struct { @@ -2853,3 +2844,38 @@ type VmessProtocol struct { Network string `json:"network"` Transport string `json:"transport"` } + +type AppleProduct struct { + ProductId string `json:"product_id"` + Description string `json:"description"` + PriceText string `json:"price_text"` + DurationDays int64 `json:"duration_days"` + Tier string `json:"tier"` + SubscribeId int64 `json:"subscribe_id"` +} + +type GetAppleProductsResponse struct { + List []AppleProduct `json:"list"` +} + +type AttachAppleTransactionRequest struct { + SignedTransactionJWS string `json:"signed_transaction_jws" validate:"required"` + DurationDays int64 `json:"duration_days,omitempty"` + Tier string `json:"tier,omitempty"` + SubscribeId int64 `json:"subscribe_id,omitempty"` +} + +type AttachAppleTransactionResponse struct { + ExpiresAt int64 `json:"expires_at"` + Tier string `json:"tier"` +} + +type RestoreAppleTransactionsRequest struct { + Transactions []string `json:"transactions" validate:"required"` +} + +type GetAppleStatusResponse struct { + Active bool `json:"active"` + ExpiresAt int64 `json:"expires_at"` + Tier string `json:"tier"` +} diff --git a/pkg/appleiap/jwks.go b/pkg/appleiap/jwks.go deleted file mode 100644 index 840ee84..0000000 --- a/pkg/appleiap/jwks.go +++ /dev/null @@ -1,87 +0,0 @@ -package appleiap - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "encoding/base64" - "encoding/json" - "errors" - "math/big" - "net/http" - "sync" - "time" -) - -type jwk struct { - Kty string `json:"kty"` - Kid string `json:"kid"` - Crv string `json:"crv"` - X string `json:"x"` - Y string `json:"y"` -} - -type jwks struct { - Keys []jwk `json:"keys"` -} - -type cacheEntry struct { - keys map[string]*ecdsa.PublicKey - exp time.Time -} - -var ( - mu sync.Mutex - cache = map[string]*cacheEntry{} -) - -func endpoint(env string) string { - if env == "sandbox" { - return "https://api.storekit-sandbox.itunes.apple.com/inApps/v1/keys" - } - return "https://api.storekit.itunes.apple.com/inApps/v1/keys" -} - -func fetch(env string) (map[string]*ecdsa.PublicKey, error) { - resp, err := http.Get(endpoint(env)) - if err != nil { - return nil, err - } - defer resp.Body.Close() - var set jwks - if err := json.NewDecoder(resp.Body).Decode(&set); err != nil { - return nil, err - } - m := make(map[string]*ecdsa.PublicKey) - for _, k := range set.Keys { - if k.Kty != "EC" || k.Crv != "P-256" || k.X == "" || k.Y == "" || k.Kid == "" { - continue - } - xb, err := base64.RawURLEncoding.DecodeString(k.X) - if err != nil { continue } - yb, err := base64.RawURLEncoding.DecodeString(k.Y) - if err != nil { continue } - var x, y big.Int - x.SetBytes(xb) - y.SetBytes(yb) - m[k.Kid] = &ecdsa.PublicKey{Curve: elliptic.P256(), X: &x, Y: &y} - } - if len(m) == 0 { - return nil, errors.New("empty jwks") - } - return m, nil -} - -func GetKey(env, kid string) (*ecdsa.PublicKey, error) { - mu.Lock() - defer mu.Unlock() - c := cache[env] - if c == nil || time.Now().After(c.exp) { - keys, err := fetch(env) - if err != nil { return nil, err } - cache[env] = &cacheEntry{ keys: keys, exp: time.Now().Add(10 * time.Minute) } - c = cache[env] - } - k := c.keys[kid] - if k == nil { return nil, errors.New("key not found") } - return k, nil -} diff --git a/pkg/appleiap/jws.go b/pkg/appleiap/jws.go deleted file mode 100644 index 595d702..0000000 --- a/pkg/appleiap/jws.go +++ /dev/null @@ -1,29 +0,0 @@ -package appleiap - -import ( - "errors" - "github.com/golang-jwt/jwt/v5" -) - -func verifyWithEnv(env, token string) (jwt.MapClaims, error) { - parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - h, ok := t.Header["kid"].(string) - if !ok { return nil, errors.New("kid missing") } - return GetKey(env, h) - }) - if err != nil { return nil, err } - if !parsed.Valid { return nil, errors.New("invalid jws") } - c, ok := parsed.Claims.(jwt.MapClaims) - if !ok { return nil, errors.New("claims invalid") } - return c, nil -} - -func VerifyWithEnv(env, token string) (jwt.MapClaims, error) { return verifyWithEnv(env, token) } - -func VerifyAutoEnv(token string) (jwt.MapClaims, string, error) { - c, err := verifyWithEnv("production", token) - if err == nil { return c, "production", nil } - c2, err2 := verifyWithEnv("sandbox", token) - if err2 == nil { return c2, "sandbox", nil } - return nil, "", err -} diff --git a/pkg/iap/apple/errors.go b/pkg/iap/apple/errors.go new file mode 100644 index 0000000..4bab66f --- /dev/null +++ b/pkg/iap/apple/errors.go @@ -0,0 +1,6 @@ +package apple + +import "errors" + +var ErrInvalidJWS = errors.New("invalid jws") + diff --git a/pkg/iap/apple/jws.go b/pkg/iap/apple/jws.go new file mode 100644 index 0000000..00eca98 --- /dev/null +++ b/pkg/iap/apple/jws.go @@ -0,0 +1,55 @@ +package apple + +import ( + "encoding/base64" + "encoding/json" + "strings" + "time" +) + +func ParseTransactionJWS(jws string) (*TransactionPayload, error) { + parts := strings.Split(jws, ".") + if len(parts) != 3 { + return nil, ErrInvalidJWS + } + payloadB64 := parts[1] + // add padding if required + switch len(payloadB64) % 4 { + case 2: + payloadB64 += "==" + case 3: + payloadB64 += "=" + } + data, err := base64.RawURLEncoding.DecodeString(payloadB64) + if err != nil { + return nil, err + } + var raw map[string]interface{} + if err = json.Unmarshal(data, &raw); err != nil { + return nil, err + } + var resp TransactionPayload + if v, ok := raw["bundleId"].(string); ok { + resp.BundleId = v + } + if v, ok := raw["productId"].(string); ok { + resp.ProductId = v + } + if v, ok := raw["transactionId"].(string); ok { + resp.TransactionId = v + } + if v, ok := raw["originalTransactionId"].(string); ok { + resp.OriginalTransactionId = v + } + if v, ok := raw["purchaseDate"].(float64); ok { + resp.PurchaseDate = time.UnixMilli(int64(v)) + } else if v, ok := raw["purchaseDate"].(int64); ok { + resp.PurchaseDate = time.UnixMilli(v) + } + if v, ok := raw["revocationDate"].(float64); ok { + t := time.UnixMilli(int64(v)) + resp.RevocationDate = &t + } + return &resp, nil +} + diff --git a/pkg/iap/apple/jws_test.go b/pkg/iap/apple/jws_test.go new file mode 100644 index 0000000..4cb7796 --- /dev/null +++ b/pkg/iap/apple/jws_test.go @@ -0,0 +1,35 @@ +package apple + +import ( + "encoding/base64" + "encoding/json" + "testing" + "time" +) + +func TestParseTransactionJWS(t *testing.T) { + payload := map[string]interface{}{ + "bundleId": "co.airoport.app.ios", + "productId": "com.airport.vpn.pass.30d", + "transactionId": "1000000000001", + "originalTransactionId": "1000000000000", + "purchaseDate": float64(time.Now().UnixMilli()), + } + data, _ := json.Marshal(payload) + b64 := base64.RawURLEncoding.EncodeToString(data) + jws := "header." + b64 + ".signature" + p, err := ParseTransactionJWS(jws) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if p.ProductId != payload["productId"] { + t.Fatalf("productId not match") + } + if p.BundleId != payload["bundleId"] { + t.Fatalf("bundleId not match") + } + if p.OriginalTransactionId != payload["originalTransactionId"] { + t.Fatalf("originalTransactionId not match") + } +} + diff --git a/pkg/iap/apple/productmap.go b/pkg/iap/apple/productmap.go new file mode 100644 index 0000000..45a51c0 --- /dev/null +++ b/pkg/iap/apple/productmap.go @@ -0,0 +1,39 @@ +package apple + +import ( + "encoding/json" + "time" +) + +type ProductMapping struct { + DurationDays int64 `json:"durationDays"` + Tier string `json:"tier"` + Description string `json:"description"` + PriceText string `json:"priceText"` + SubscribeId int64 `json:"subscribeId"` +} + +type ProductMap struct { + Items map[string]ProductMapping `json:"iapProductMap"` +} + +func ParseProductMap(customData string) (*ProductMap, error) { + if customData == "" { + return &ProductMap{Items: map[string]ProductMapping{}}, nil + } + var obj ProductMap + if err := json.Unmarshal([]byte(customData), &obj); err != nil { + return &ProductMap{Items: map[string]ProductMapping{}}, nil + } + if obj.Items == nil { + obj.Items = map[string]ProductMapping{} + } + return &obj, nil +} + +func CalcExpire(start time.Time, days int64) time.Time { + if days <= 0 { + return time.UnixMilli(0) + } + return start.Add(time.Duration(days) * 24 * time.Hour) +} diff --git a/pkg/iap/apple/types.go b/pkg/iap/apple/types.go new file mode 100644 index 0000000..ffdd807 --- /dev/null +++ b/pkg/iap/apple/types.go @@ -0,0 +1,13 @@ +package apple + +import "time" + +type TransactionPayload struct { + BundleId string `json:"bundleId"` + ProductId string `json:"productId"` + TransactionId string `json:"transactionId"` + OriginalTransactionId string `json:"originalTransactionId"` + PurchaseDate time.Time `json:"purchaseDate"` + RevocationDate *time.Time`json:"revocationDate"` +} + diff --git a/pkg/payment/platform.go b/pkg/payment/platform.go index a460847..7ad12ea 100644 --- a/pkg/payment/platform.go +++ b/pkg/payment/platform.go @@ -10,7 +10,6 @@ const ( EPay Balance CryptoSaaS - AppleIAP UNSUPPORTED Platform = -1 ) @@ -20,7 +19,6 @@ var platformNames = map[string]Platform{ "AlipayF2F": AlipayF2F, "EPay": EPay, "balance": Balance, - "AppleIAP": AppleIAP, "unsupported": UNSUPPORTED, } @@ -49,7 +47,7 @@ func GetSupportedPlatforms() []types.PlatformInfo { "public_key": "Publishable key", "secret_key": "Secret key", "webhook_secret": "Webhook secret", - "payment": "Payment Method, only supported card/alipay/wechat_pay/apple_pay", + "payment": "Payment Method, only supported card/alipay/wechat_pay", }, }, { @@ -82,15 +80,5 @@ func GetSupportedPlatforms() []types.PlatformInfo { "secret_key": "Secret Key", }, }, - { - Platform: AppleIAP.String(), - PlatformUrl: "https://developer.apple.com/help/app-store-connect/", - PlatformFieldDescription: map[string]string{ - "issuer_id": "App Store Connect Issuer ID", - "key_id": "App Store Connect Key ID", - "private_key": "Private Key (ES256)", - "environment": "Environment: Sandbox/Production", - }, - }, } } diff --git a/scripts/.env b/scripts/.env new file mode 100644 index 0000000..047eb1d --- /dev/null +++ b/scripts/.env @@ -0,0 +1,2 @@ +UH7EpvMzwYDBfQ0nxAS5 +ibFUcqkPhyeGvQCBjE07VaYzWH3IpJw9frDudxL6 \ No newline at end of file diff --git a/scripts/backup_all.sh b/scripts/backup_all.sh new file mode 100644 index 0000000..4dc98b4 --- /dev/null +++ b/scripts/backup_all.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Configuration +LOG_FILE="/root/backup.log" +MYSQL_BACKUP_SCRIPT="/root/backup_mysql.sh" +UPLOADER_BINARY="/root/uploader-linux-amd64" + +# MinIO Credentials (can be modified here or passed via env) +MINIO_ENDPOINT="http://107.173.50.22:5017" +MINIO_ACCESS_KEY="WyJYxDobmp9glIXVAteC" +MINIO_SECRET_KEY="TNO0ZJ4AH5QupFwDtiLxavUeMVmz2fo1YXRGsI7c" +MINIO_BUCKET="backup" + +# Directories to backup (comma separated) +# Example: "/root/vpn_server,/etc/nginx/conf.d" +DIRS_TO_BACKUP="/root/db_backups,/etc/nginx/conf.d" + +echo "========================================================" >> "$LOG_FILE" +echo "[$(date)] Starting Daily Backup Task..." >> "$LOG_FILE" + +# 1. Execute MySQL Backup +if [ -f "$MYSQL_BACKUP_SCRIPT" ]; then + echo "[$(date)] Running MySQL backup script..." >> "$LOG_FILE" + # Pass credentials to the MySQL script via environment variables if needed, + # but currently backup_mysql.sh calls uploader internally. + # We should update backup_mysql.sh to use these credentials too, or rely on them being embedded/env. + # For now, let's export them so child processes can see them if they use os.Getenv + export MINIO_ENDPOINT + export MINIO_ACCESS_KEY + export MINIO_SECRET_KEY + export MINIO_BUCKET + + bash "$MYSQL_BACKUP_SCRIPT" >> "$LOG_FILE" 2>&1 + if [ $? -eq 0 ]; then + echo "[$(date)] MySQL backup script finished." >> "$LOG_FILE" + else + echo "[$(date)] Error: MySQL backup script failed!" >> "$LOG_FILE" + fi +else + echo "[$(date)] Error: MySQL backup script not found at $MYSQL_BACKUP_SCRIPT" >> "$LOG_FILE" +fi + +# 2. Execute File/Directory Backup using Go Uploader +if [ -f "$UPLOADER_BINARY" ]; then + echo "[$(date)] Running Directory backup..." >> "$LOG_FILE" + chmod +x "$UPLOADER_BINARY" + + # Run uploader with explicit flags + "$UPLOADER_BINARY" \ + -dir "$DIRS_TO_BACKUP" \ + -bucket "$MINIO_BUCKET" \ + -endpoint "$MINIO_ENDPOINT" \ + -access-key "$MINIO_ACCESS_KEY" \ + -secret-key "$MINIO_SECRET_KEY" \ + >> "$LOG_FILE" 2>&1 + + if [ $? -eq 0 ]; then + echo "[$(date)] Directory backup finished." >> "$LOG_FILE" + else + echo "[$(date)] Error: Directory backup failed!" >> "$LOG_FILE" + fi +else + echo "[$(date)] Error: Uploader binary not found at $UPLOADER_BINARY" >> "$LOG_FILE" +fi + +echo "[$(date)] Daily Backup Task Completed." >> "$LOG_FILE" +echo "========================================================" >> "$LOG_FILE" diff --git a/scripts/backup_mysql.sh b/scripts/backup_mysql.sh new file mode 100644 index 0000000..f50144e --- /dev/null +++ b/scripts/backup_mysql.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# Configuration +CONTAINER_NAME="ppanel-db" +DB_USER="vmanroot" +DB_PASSWORD="vmanrootpassword" # Replace with actual password +DB_NAME="ppanel" # Explicitly specify the database name +BACKUP_DIR="/root/db_backups" +UPLOADER_PATH="/root/uploader-linux-amd64" # Path to your go uploader binary +RETENTION_DAYS=7 + +# Create backup directory if not exists +mkdir -p "$BACKUP_DIR" + +# Generate timestamp +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +FILENAME="mysql_backup_${TIMESTAMP}.sql" +FILEPATH="${BACKUP_DIR}/${FILENAME}" +GZ_FILEPATH="${FILEPATH}.gz" + +# 1. Dump MySQL database from Docker container +echo "[$(date)] Starting MySQL backup from container ${CONTAINER_NAME}..." + +# Check if container is running +if [ ! "$(docker ps -q -f name=${CONTAINER_NAME})" ]; then + echo "Error: Container ${CONTAINER_NAME} is not running!" + exit 1 +fi + +# Execute dump +docker exec "$CONTAINER_NAME" /usr/bin/mysqldump -u "$DB_USER" -p"$DB_PASSWORD" --databases "$DB_NAME" --no-tablespaces > "$FILEPATH" + +# Check if file size is too small (e.g., < 1KB), which usually indicates an empty dump or error +FILE_SIZE=$(stat -c%s "$FILEPATH" 2>/dev/null || stat -f%z "$FILEPATH") +if [ "$FILE_SIZE" -lt 1024 ]; then + echo "Error: Backup file is too small ($FILE_SIZE bytes). Dump might have failed." + cat "$FILEPATH" # Print content to log for debugging + exit 1 +fi + +if [ $? -eq 0 ]; then + echo "[$(date)] Database dump successful: ${FILEPATH}" + + # 2. Compress the backup + gzip "$FILEPATH" + echo "[$(date)] Compression successful: ${GZ_FILEPATH}" + + # 3. Upload to MinIO using the Go uploader + if [ -f "$UPLOADER_PATH" ]; then + echo "[$(date)] Uploading to object storage..." + chmod +x "$UPLOADER_PATH" + "$UPLOADER_PATH" -file "$GZ_FILEPATH" -bucket backup + + if [ $? -eq 0 ]; then + echo "[$(date)] Upload successful." + else + echo "[$(date)] Upload failed." + fi + else + echo "Warning: Uploader binary not found at $UPLOADER_PATH. Skipping upload." + fi + + # 4. Clean up old local backups (optional) + find "$BACKUP_DIR" -name "mysql_backup_*.sql.gz" -mtime +$RETENTION_DAYS -delete + echo "[$(date)] Cleaned up local backups older than $RETENTION_DAYS days." + +else + echo "Error: Database dump failed!" + # Clean up empty file if dump failed + if [ -f "$FILEPATH" ]; then + rm "$FILEPATH" + fi + exit 1 +fi + +echo "[$(date)] Backup process completed." diff --git a/scripts/uploader/.env b/scripts/uploader/.env new file mode 100644 index 0000000..f345772 --- /dev/null +++ b/scripts/uploader/.env @@ -0,0 +1,14 @@ +MINIO_ENDPOINT="http://107.173.50.22:5017" +MINIO_ACCESS_KEY="WyJYxDobmp9glIXVAteC" +MINIO_SECRET_KEY="TNO0ZJ4AH5QupFwDtiLxavUeMVmz2fo1YXRGsI7c" +MINIO_BUCKET="backup" + + + + +./uploader-linux-amd64 \ + -dir /root/vpn_server,/etc/nginx/conf.d \ + -endpoint http://107.173.50.22:5017 \ + -access-key WyJYxDobmp9glIXVAteC \ + -secret-key TNO0ZJ4AH5QupFwDtiLxavUeMVmz2fo1YXRGsI7c \ + -bucket backup \ No newline at end of file diff --git a/scripts/uploader/go.mod b/scripts/uploader/go.mod new file mode 100644 index 0000000..496235b --- /dev/null +++ b/scripts/uploader/go.mod @@ -0,0 +1,29 @@ +module uploader + +go 1.24.4 + +require ( + github.com/joho/godotenv v1.5.1 + github.com/minio/minio-go/v7 v7.0.97 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-ini/ini v1.67.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.11 // indirect + github.com/klauspost/crc32 v1.3.0 // indirect + github.com/minio/crc64nvme v1.1.0 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/philhofer/fwd v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/xid v1.6.0 // indirect + github.com/tinylib/msgp v1.3.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.26.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/scripts/uploader/go.sum b/scripts/uploader/go.sum new file mode 100644 index 0000000..96c2190 --- /dev/null +++ b/scripts/uploader/go.sum @@ -0,0 +1,45 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= +github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= +github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= +github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q= +github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ= +github.com/minio/minio-go/v7 v7.0.97/go.mod h1:re5VXuo0pwEtoNLsNuSr0RrLfT/MBtohwdaSmPPSRSk= +github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= +github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= +github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/scripts/uploader/main.go b/scripts/uploader/main.go new file mode 100644 index 0000000..67688ee --- /dev/null +++ b/scripts/uploader/main.go @@ -0,0 +1,371 @@ +package main + +import ( + "archive/zip" + "context" + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/joho/godotenv" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" +) + +// Global variables to be set via ldflags +var ( + BuildEndpoint string + BuildAccessKey string + BuildSecretKey string + BuildBucket string +) + +func main() { + // Load environment variables from .env file + err := godotenv.Load() + if err != nil { + // Just log, don't fail, as we might rely on flags or build defaults + } + + // Parse command line arguments + var filePath string + var dirPaths string + var bucketName string + var interval string + var objectName string + var listBuckets bool + var createBucketFlag bool + + // Credential flags + var endpointFlag string + var accessKeyFlag string + var secretKeyFlag string + + flag.StringVar(&filePath, "file", "", "Path to the file to upload") + flag.StringVar(&dirPaths, "dir", "", "Comma-separated paths to directories to compress and upload") + flag.StringVar(&bucketName, "bucket", "", "Bucket name") + flag.StringVar(&interval, "interval", "", "Backup interval (e.g., 24h, 60m). If not set, runs once.") + flag.StringVar(&objectName, "name", "", "Object name (optional)") + flag.BoolVar(&listBuckets, "list", false, "List all available buckets") + flag.BoolVar(&createBucketFlag, "create", false, "Create the specified bucket if it doesn't exist") + + flag.StringVar(&endpointFlag, "endpoint", "", "MinIO endpoint URL") + flag.StringVar(&accessKeyFlag, "access-key", "", "MinIO access key") + flag.StringVar(&secretKeyFlag, "secret-key", "", "MinIO secret key") + + flag.Parse() + + // Resolve Configuration: Flag > Env > Build Default + finalEndpoint := resolveConfig(endpointFlag, "MINIO_ENDPOINT", BuildEndpoint) + finalAccessKey := resolveConfig(accessKeyFlag, "MINIO_ACCESS_KEY", BuildAccessKey) + finalSecretKey := resolveConfig(secretKeyFlag, "MINIO_SECRET_KEY", BuildSecretKey) + + if bucketName == "" { + bucketName = resolveConfig("", "MINIO_BUCKET", BuildBucket) + } + + // Initialize MinIO client + minioClient := initMinioClient(finalEndpoint, finalAccessKey, finalSecretKey) + + if listBuckets { + listAllBuckets(minioClient) + return + } + + if createBucketFlag { + if bucketName == "" { + fmt.Println("Please specify a bucket name using -bucket") + os.Exit(1) + } + createBucket(minioClient, bucketName) + return + } + + // Handle positional arguments for backward compatibility + args := flag.Args() + if len(args) > 0 && filePath == "" && dirPaths == "" { + // Check if argument is a directory + info, err := os.Stat(args[0]) + if err == nil && info.IsDir() { + dirPaths = args[0] + } else { + filePath = args[0] + } + } + if len(args) > 1 && bucketName == "" { + bucketName = args[1] + } + + if bucketName == "" { + // Try to resolve bucket again if not set via flag + bucketName = resolveConfig("", "MINIO_BUCKET", BuildBucket) + } + + if (filePath == "" && dirPaths == "") || bucketName == "" { + fmt.Println("Usage: uploader -file -bucket ") + fmt.Println(" uploader -dir -bucket [-interval 24h]") + fmt.Println(" uploader -list") + fmt.Println("\nCredentials can be provided via .env file, environment variables, or flags:") + fmt.Println(" -endpoint -access-key -secret-key ") + os.Exit(1) + } + + // One-time execution + if interval == "" { + if err := performBackup(minioClient, filePath, dirPaths, bucketName, objectName); err != nil { + log.Fatalln(err) + } + return + } + + // Scheduled execution + duration, err := time.ParseDuration(interval) + if err != nil { + log.Fatalf("Invalid interval format: %v\n", err) + } + + fmt.Printf("Starting backup service every %s...\n", interval) + ticker := time.NewTicker(duration) + defer ticker.Stop() + + // Run immediately first + if err := performBackup(minioClient, filePath, dirPaths, bucketName, objectName); err != nil { + log.Printf("Backup failed: %v\n", err) + } + + for range ticker.C { + if err := performBackup(minioClient, filePath, dirPaths, bucketName, objectName); err != nil { + log.Printf("Backup failed: %v\n", err) + } + } +} + +func resolveConfig(flagVal, envKey, buildVal string) string { + if flagVal != "" { + return flagVal + } + if envVal := os.Getenv(envKey); envVal != "" { + return envVal + } + return buildVal +} + +func initMinioClient(endpoint, accessKey, secretKey string) *minio.Client { + useSSL := true + + if strings.HasPrefix(endpoint, "http://") { + endpoint = strings.TrimPrefix(endpoint, "http://") + useSSL = false + } else if strings.HasPrefix(endpoint, "https://") { + endpoint = strings.TrimPrefix(endpoint, "https://") + useSSL = true + } + + if endpoint == "" || accessKey == "" || secretKey == "" { + log.Fatal("Error: Credentials must be provided via flags, environment variables, .env file, or build-time defaults.") + } + + minioClient, err := minio.New(endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(accessKey, secretKey, ""), + Secure: useSSL, + }) + if err != nil { + log.Fatalln(err) + } + return minioClient +} + +func listAllBuckets(client *minio.Client) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + buckets, err := client.ListBuckets(ctx) + if err != nil { + log.Fatalf("Failed to list buckets: %v\n", err) + } + + if len(buckets) == 0 { + fmt.Println("No buckets found.") + return + } + + fmt.Println("Available buckets:") + for _, bucket := range buckets { + fmt.Printf("- %s (Created: %s)\n", bucket.Name, bucket.CreationDate) + } +} + +func createBucket(client *minio.Client, bucketName string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := client.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{}) + if err != nil { + // Check to see if we already own this bucket + exists, errBucketExists := client.BucketExists(ctx, bucketName) + if errBucketExists == nil && exists { + log.Printf("We already own %s\n", bucketName) + } else { + log.Fatalf("Failed to create bucket %s: %v\n", bucketName, err) + } + } else { + log.Printf("Successfully created bucket %s\n", bucketName) + } +} + +func performBackup(client *minio.Client, filePath, dirPaths, bucketName, objectName string) error { + // Handle single file upload + if filePath != "" { + if objectName == "" { + objectName = filepath.Base(filePath) + } + fmt.Printf("Uploading %s to %s/%s...\n", filePath, bucketName, objectName) + info, err := client.FPutObject(context.Background(), bucketName, objectName, filePath, minio.PutObjectOptions{}) + if err != nil { + return fmt.Errorf("failed to upload file %s: %v", filePath, err) + } + fmt.Printf("Successfully uploaded %s. Size: %d bytes\n", filePath, info.Size) + } + + // Handle directory uploads + if dirPaths != "" { + dirs := strings.Split(dirPaths, ",") + for _, dirPath := range dirs { + dirPath = strings.TrimSpace(dirPath) + if dirPath == "" { + continue + } + + // Verify directory exists + info, err := os.Stat(dirPath) + if err != nil || !info.IsDir() { + log.Printf("Warning: Skipping invalid directory: %s\n", dirPath) + continue + } + + timestamp := time.Now().Format("20060102_150405") + dirName := filepath.Base(dirPath) + if dirName == "." || dirName == "/" { + // Use parent directory name or absolute path hash/sanitized name could be better, + // but for simplicity, let's try to get absolute path base + absPath, _ := filepath.Abs(dirPath) + dirName = filepath.Base(absPath) + } + + zipName := fmt.Sprintf("%s_%s.zip", dirName, timestamp) + // Create zip in temp directory + tempFile := filepath.Join(os.TempDir(), zipName) + + fmt.Printf("Zipping directory %s to %s...\n", dirPath, tempFile) + if err := zipSource(dirPath, tempFile); err != nil { + log.Printf("Error zipping directory %s: %v\n", dirPath, err) + continue + } + + // If object name was specified (and only 1 directory), use it. + // Otherwise (multiple directories), use the zip name to avoid overwriting. + uploadName := zipName + if objectName != "" && len(dirs) == 1 { + uploadName = objectName + } + + fmt.Printf("Uploading %s to %s/%s...\n", tempFile, bucketName, uploadName) + uploadInfo, err := client.FPutObject(context.Background(), bucketName, uploadName, tempFile, minio.PutObjectOptions{}) + + // Clean up temp file + os.Remove(tempFile) + + if err != nil { + log.Printf("Error uploading directory %s: %v\n", dirPath, err) + continue + } + + fmt.Printf("Successfully uploaded %s. Size: %d bytes\n", dirPath, uploadInfo.Size) + } + } + return nil +} + +func zipSource(source, target string) error { + f, err := os.Create(target) + if err != nil { + return err + } + defer f.Close() + + writer := zip.NewWriter(f) + defer writer.Close() + + return filepath.Walk(source, func(path string, info os.FileInfo, err error) error { + if err != nil { + // Instead of failing hard, log the error and skip this file + log.Printf("Warning: Skipping file %s due to error: %v\n", path, err) + return nil + } + + // Skip the zip file itself if it's inside the source directory + if path == target { + return nil + } + + // Skip sockets, pipes, devices, etc. Only allow regular files and directories. + // info.Mode() & os.ModeType returns the file type bits (excluding permissions) + // 0 means regular file. + mode := info.Mode() + if !mode.IsRegular() && !info.IsDir() { + // Skip non-regular files silently (or with debug log) to avoid "no such device" or "open socket" errors + // Symlinks (ModeSymlink) are also skipped by IsRegular(), which is usually desired for backup consistency unless we specifically want to follow them. + // If we want to support symlinks, we'd need to handle them separately. For now, skipping is safer. + return nil + } + + header, err := zip.FileInfoHeader(info) + if err != nil { + log.Printf("Warning: Failed to create zip header for %s: %v\n", path, err) + return nil + } + + // Change header name to be relative to the source + relPath, err := filepath.Rel(source, path) + if err != nil { + return err + } + + // Use forward slashes for zip compatibility + header.Name = filepath.ToSlash(relPath) + + if info.IsDir() { + header.Name += "/" + } else { + header.Method = zip.Deflate + } + + writer, err := writer.CreateHeader(header) + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + file, err := os.Open(path) + if err != nil { + // If we can't open the file (permission denied etc), skip it + log.Printf("Warning: Failed to open file %s: %v\n", path, err) + return nil + } + defer file.Close() + _, err = io.Copy(writer, file) + if err != nil { + log.Printf("Warning: Failed to write file %s to zip: %v\n", path, err) + return nil + } + return nil + }) +} diff --git a/scripts/uploader/uploader-linux-amd64 b/scripts/uploader/uploader-linux-amd64 new file mode 100755 index 0000000..037c61e Binary files /dev/null and b/scripts/uploader/uploader-linux-amd64 differ