feat(apple): 添加通过transaction_id附加苹果交易功能
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 6m41s

新增通过transaction_id附加苹果交易的功能,包括:
1. 添加AttachAppleTransactionByIdRequest类型和对应路由
2. 实现AppleIAPConfig配置模型
3. 添加ServerAPI获取交易信息的实现
4. 优化JWS解析逻辑,增加cleanB64函数处理空格
5. 完善苹果通知处理逻辑的日志和注释
This commit is contained in:
shanshanzhong 2025-12-15 22:35:33 -08:00
parent 15fb9a1da5
commit 3c6dd5058b
13 changed files with 381 additions and 92 deletions

View File

@ -1,9 +1,10 @@
package notify package notify
import ( import (
"github.com/gin-gonic/gin"
"io"
"encoding/json" "encoding/json"
"io"
"github.com/gin-gonic/gin"
"github.com/perfect-panel/server/internal/logic/notify" "github.com/perfect-panel/server/internal/logic/notify"
"github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/pkg/result" "github.com/perfect-panel/server/pkg/result"

View File

@ -0,0 +1,23 @@
package apple
import (
"github.com/gin-gonic/gin"
appleLogic "github.com/perfect-panel/server/internal/logic/public/iap/apple"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/result"
)
func AttachAppleTransactionByIdHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) {
return func(c *gin.Context) {
var req types.AttachAppleTransactionByIdRequest
_ = c.ShouldBind(&req)
if err := svcCtx.Validate(&req); err != nil {
result.ParamErrorResult(c, err)
return
}
l := appleLogic.NewAttachTransactionByIdLogic(c.Request.Context(), svcCtx)
resp, err := l.AttachById(&req)
result.HttpResult(c, resp, err)
}
}

View File

@ -725,6 +725,7 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) {
{ {
iapAppleGroupRouter.GET("/status", publicIapApple.GetAppleStatusHandler(serverCtx)) iapAppleGroupRouter.GET("/status", publicIapApple.GetAppleStatusHandler(serverCtx))
iapAppleGroupRouter.POST("/transactions/attach", publicIapApple.AttachAppleTransactionHandler(serverCtx)) iapAppleGroupRouter.POST("/transactions/attach", publicIapApple.AttachAppleTransactionHandler(serverCtx))
iapAppleGroupRouter.POST("/transactions/attach_by_id", publicIapApple.AttachAppleTransactionByIdHandler(serverCtx))
iapAppleGroupRouter.POST("/restore", publicIapApple.RestoreAppleTransactionsHandler(serverCtx)) iapAppleGroupRouter.POST("/restore", publicIapApple.RestoreAppleTransactionsHandler(serverCtx))
} }

View File

@ -131,6 +131,8 @@ func parsePaymentPlatformConfig(ctx context.Context, platform payment.Platform,
return handleConfig("Epay", &paymentModel.EPayConfig{}) return handleConfig("Epay", &paymentModel.EPayConfig{})
case payment.CryptoSaaS: case payment.CryptoSaaS:
return handleConfig("CryptoSaaS", &paymentModel.CryptoSaaSConfig{}) return handleConfig("CryptoSaaS", &paymentModel.CryptoSaaSConfig{})
case payment.AppleIAP:
return handleConfig("AppleIAP", &paymentModel.AppleIAPConfig{})
default: default:
return "" return ""
} }

View File

@ -1,18 +1,18 @@
package user package user
import ( import (
"context" "context"
"fmt" "fmt"
"time" "time"
"github.com/perfect-panel/server/pkg/xerr" "github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/config"
"github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/logger"
"gorm.io/gorm" "gorm.io/gorm"
) )
type DeleteUserDeviceLogic struct { type DeleteUserDeviceLogic struct {
@ -31,32 +31,32 @@ func NewDeleteUserDeviceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *
} }
func (l *DeleteUserDeviceLogic) DeleteUserDevice(req *types.DeleteUserDeivceRequest) error { func (l *DeleteUserDeviceLogic) DeleteUserDevice(req *types.DeleteUserDeivceRequest) error {
device, findErr := l.svcCtx.UserModel.FindOneDevice(l.ctx, req.Id) device, findErr := l.svcCtx.UserModel.FindOneDevice(l.ctx, req.Id)
if findErr != nil { if findErr != nil {
if errors.Is(findErr, gorm.ErrRecordNotFound) { if errors.Is(findErr, gorm.ErrRecordNotFound) {
return nil return nil
} }
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get Device error: %v", findErr.Error()) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get Device error: %v", findErr.Error())
} }
// 尝试踢下线在线设备 // 尝试踢下线在线设备
l.svcCtx.DeviceManager.KickDevice(device.UserId, device.Identifier) l.svcCtx.DeviceManager.KickDevice(device.UserId, device.Identifier)
// 清理与设备相关的缓存会话 // 清理与设备相关的缓存会话
ctx, cancel := context.WithTimeout(l.ctx, 2*time.Second) ctx, cancel := context.WithTimeout(l.ctx, 2*time.Second)
defer cancel() defer cancel()
deviceCacheKey := fmt.Sprintf("%v:%v", config.DeviceCacheKeyKey, device.Identifier) deviceCacheKey := fmt.Sprintf("%v:%v", config.DeviceCacheKeyKey, device.Identifier)
if sessionId, rerr := l.svcCtx.Redis.Get(ctx, deviceCacheKey).Result(); rerr == nil && sessionId != "" { if sessionId, rerr := l.svcCtx.Redis.Get(ctx, deviceCacheKey).Result(); rerr == nil && sessionId != "" {
_ = l.svcCtx.Redis.Del(ctx, deviceCacheKey).Err() _ = l.svcCtx.Redis.Del(ctx, deviceCacheKey).Err()
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
_ = l.svcCtx.Redis.Del(ctx, sessionIdCacheKey).Err() _ = l.svcCtx.Redis.Del(ctx, sessionIdCacheKey).Err()
sessionsKey := fmt.Sprintf("%s%v", config.UserSessionsKeyPrefix, device.UserId) sessionsKey := fmt.Sprintf("%s%v", config.UserSessionsKeyPrefix, device.UserId)
_ = l.svcCtx.Redis.ZRem(ctx, sessionsKey, sessionId).Err() _ = l.svcCtx.Redis.ZRem(ctx, sessionsKey, sessionId).Err()
} }
// 最后删除数据库记录 // 最后删除数据库记录
if err := l.svcCtx.UserModel.DeleteDevice(l.ctx, req.Id); err != nil { if err := l.svcCtx.UserModel.DeleteDevice(l.ctx, req.Id); err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseDeletedError), "delete user error: %v", err.Error()) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseDeletedError), "delete user error: %v", err.Error())
} }
return nil return nil
} }

View File

@ -88,26 +88,26 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
logger.Field("user_id", userId), logger.Field("user_id", userId),
) )
// enforce device bind limit before creating // enforce device bind limit before creating
if limit := l.svcCtx.SessionLimit(); limit > 0 { if limit := l.svcCtx.SessionLimit(); limit > 0 {
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, userId); err == nil { if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, userId); err == nil {
if count >= limit { if count >= limit {
l.Infow("device bind blocked by limit", l.Infow("device bind blocked by limit",
logger.Field("user_id", userId), logger.Field("user_id", userId),
logger.Field("identifier", identifier), logger.Field("identifier", identifier),
logger.Field("count", count), logger.Field("count", count),
logger.Field("limit", limit)) logger.Field("limit", limit))
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。") return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
} else { } else {
l.Infow("device bind limit check", l.Infow("device bind limit check",
logger.Field("user_id", userId), logger.Field("user_id", userId),
logger.Field("identifier", identifier), logger.Field("identifier", identifier),
logger.Field("count", count), logger.Field("count", count),
logger.Field("limit", limit)) logger.Field("limit", limit))
} }
} }
} }
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Create device auth method // Create device auth method
authMethod := &user.AuthMethods{ authMethod := &user.AuthMethods{
UserId: userId, UserId: userId,
@ -142,8 +142,8 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device failed: %v", err) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create device failed: %v", err)
} }
return nil return nil
}) })
if err != nil { if err != nil {
l.Errorw("device creation failed", l.Errorw("device creation failed",
@ -163,28 +163,28 @@ func (l *BindDeviceLogic) createDeviceForUser(identifier, ip, userAgent string,
} }
func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error { func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, userAgent string, newUserId int64) error {
oldUserId := deviceInfo.UserId oldUserId := deviceInfo.UserId
// enforce device bind limit before rebind // enforce device bind limit before rebind
if limit := l.svcCtx.SessionLimit(); limit > 0 { if limit := l.svcCtx.SessionLimit(); limit > 0 {
if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, newUserId); err == nil { if _, count, err := l.svcCtx.UserModel.QueryDeviceList(l.ctx, newUserId); err == nil {
if count >= limit { if count >= limit {
l.Infow("device rebind blocked by limit", l.Infow("device rebind blocked by limit",
logger.Field("new_user_id", newUserId), logger.Field("new_user_id", newUserId),
logger.Field("identifier", deviceInfo.Identifier), logger.Field("identifier", deviceInfo.Identifier),
logger.Field("count", count), logger.Field("count", count),
logger.Field("limit", limit)) logger.Field("limit", limit))
return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。") return xerr.NewErrCodeMsg(xerr.DeviceBindLimitExceeded, "账户绑定设备数已达上限,请移除其他设备后再登录,您也可以再注册一个新账户使用,点击帮助中心查看更多详情。")
} else { } else {
l.Infow("device rebind limit check", l.Infow("device rebind limit check",
logger.Field("new_user_id", newUserId), logger.Field("new_user_id", newUserId),
logger.Field("identifier", deviceInfo.Identifier), logger.Field("identifier", deviceInfo.Identifier),
logger.Field("count", count), logger.Field("count", count),
logger.Field("limit", limit)) logger.Field("limit", limit))
} }
} }
} }
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Check if old user has other auth methods besides device // Check if old user has other auth methods besides device
var authMethods []user.AuthMethods var authMethods []user.AuthMethods
if err := db.Where("user_id = ?", oldUserId).Find(&authMethods).Error; err != nil { if err := db.Where("user_id = ?", oldUserId).Find(&authMethods).Error; err != nil {
@ -249,8 +249,8 @@ func (l *BindDeviceLogic) rebindDeviceToNewUser(deviceInfo *user.Device, ip, use
return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update device failed: %v", err)
} }
return nil return nil
}) })
if err != nil { if err != nil {
l.Errorw("device rebinding failed", l.Errorw("device rebinding failed",

View File

@ -10,12 +10,20 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// AppleIAPNotifyLogic 用于处理 App Store Server Notifications V2 的苹果内购通知
// 负责JWS 验签、事务记录写入/撤销更新、订阅生命周期同步(续期/撤销等)
type AppleIAPNotifyLogic struct { type AppleIAPNotifyLogic struct {
logger.Logger logger.Logger
ctx context.Context ctx context.Context
svcCtx *svc.ServiceContext svcCtx *svc.ServiceContext
} }
// NewAppleIAPNotifyLogic 创建通知处理逻辑实例
// 参数:
// - ctx: 请求上下文
// - svcCtx: 服务上下文,包含 DB/Redis/配置 等
// 返回:
// - *AppleIAPNotifyLogic: 通知处理逻辑对象
func NewAppleIAPNotifyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AppleIAPNotifyLogic { func NewAppleIAPNotifyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AppleIAPNotifyLogic {
return &AppleIAPNotifyLogic{ return &AppleIAPNotifyLogic{
Logger: logger.WithContext(ctx), Logger: logger.WithContext(ctx),
@ -24,15 +32,30 @@ func NewAppleIAPNotifyLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Ap
} }
} }
// Handle 处理苹果内购通知
// 流程:
// 1. 验签通知信封,解析得到交易 JWS 并再次验签;
// 2. 写入或更新事务记录(幂等按 OriginalTransactionId
// 3. 依据产品映射更新订阅到期时间或撤销状态;
// 4. 全流程关键节点输出详细中文日志,便于定位问题。
// 参数:
// - signedPayload: 通知信封的 JWS包含 data.signedTransactionInfo
// 返回:
// - error: 处理失败错误,成功返回 nil
func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error { func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error {
txPayload, _, err := iapapple.VerifyNotificationSignedPayload(signedPayload) txPayload, ntype, err := iapapple.VerifyNotificationSignedPayload(signedPayload)
if err != nil { if err != nil {
// 验签失败,记录错误以便排查(通常为 JWS 格式/证书链问题)
l.Errorw("iap notify verify failed", logger.Field("error", err.Error()))
return err return err
} }
// 验签通过,记录通知类型与关键交易标识
l.Infow("iap notify verified", logger.Field("type", ntype), logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId))
return l.svcCtx.DB.Transaction(func(db *gorm.DB) error { return l.svcCtx.DB.Transaction(func(db *gorm.DB) error {
var existing *iapmodel.Transaction var existing *iapmodel.Transaction
existing, _ = iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByOriginalId(l.ctx, txPayload.OriginalTransactionId) existing, _ = iapmodel.NewModel(l.svcCtx.DB, l.svcCtx.Redis).FindByOriginalId(l.ctx, txPayload.OriginalTransactionId)
if existing == nil || existing.Id == 0 { if existing == nil || existing.Id == 0 {
// 首次出现该事务,写入记录
rec := &iapmodel.Transaction{ rec := &iapmodel.Transaction{
UserId: 0, UserId: 0,
OriginalTransactionId: txPayload.OriginalTransactionId, OriginalTransactionId: txPayload.OriginalTransactionId,
@ -43,35 +66,50 @@ func (l *AppleIAPNotifyLogic) Handle(signedPayload string) error {
JWSHash: "", JWSHash: "",
} }
if e := db.Model(&iapmodel.Transaction{}).Create(rec).Error; e != nil { if e := db.Model(&iapmodel.Transaction{}).Create(rec).Error; e != nil {
// 事务写入失败(唯一约束/字段问题),输出详细日志
l.Errorw("iap notify insert transaction error", logger.Field("error", e.Error()), logger.Field("productId", txPayload.ProductId), logger.Field("originalTransactionId", txPayload.OriginalTransactionId))
return e return e
} }
} else { } else {
if txPayload.RevocationDate != nil { if txPayload.RevocationDate != nil {
// 撤销场景:更新 revocation_at
if e := db.Model(&iapmodel.Transaction{}). if e := db.Model(&iapmodel.Transaction{}).
Where("original_transaction_id = ?", txPayload.OriginalTransactionId). Where("original_transaction_id = ?", txPayload.OriginalTransactionId).
Update("revocation_at", txPayload.RevocationDate).Error; e != nil { Update("revocation_at", txPayload.RevocationDate).Error; e != nil {
// 撤销更新失败,记录日志
l.Errorw("iap notify update revocation error", logger.Field("error", e.Error()), logger.Field("originalTransactionId", txPayload.OriginalTransactionId))
return e return e
} }
} }
} }
pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData)
m := pm.Items[txPayload.ProductId] m := pm.Items[txPayload.ProductId]
// 若产品映射缺失,记录警告日志(不影响事务入库)
if m.DurationDays == 0 {
l.Errorw("iap notify product mapping missing", logger.Field("productId", txPayload.ProductId))
}
token := "iap:" + txPayload.OriginalTransactionId token := "iap:" + txPayload.OriginalTransactionId
sub, e := l.svcCtx.UserModel.FindOneSubscribeByToken(l.ctx, token) sub, e := l.svcCtx.UserModel.FindOneSubscribeByToken(l.ctx, token)
if e == nil && sub != nil && sub.Id != 0 { if e == nil && sub != nil && sub.Id != 0 {
if txPayload.RevocationDate != nil { if txPayload.RevocationDate != nil {
// 撤销:订阅置为过期并记录完成时间
sub.Status = 3 sub.Status = 3
t := *txPayload.RevocationDate t := *txPayload.RevocationDate
sub.FinishedAt = &t sub.FinishedAt = &t
sub.ExpireTime = t sub.ExpireTime = t
} else if m.DurationDays > 0 { } else if m.DurationDays > 0 {
// 正常:根据映射天数续期
exp := iapapple.CalcExpire(txPayload.PurchaseDate, m.DurationDays) exp := iapapple.CalcExpire(txPayload.PurchaseDate, m.DurationDays)
sub.ExpireTime = exp sub.ExpireTime = exp
sub.Status = 1 sub.Status = 1
} }
if e := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, sub, db); e != nil { if e := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, sub, db); e != nil {
// 订阅更新失败,记录日志
l.Errorw("iap notify update subscribe error", logger.Field("error", e.Error()), logger.Field("userSubscribeId", sub.Id))
return e return e
} }
// 更新成功,输出订阅状态
l.Infow("iap notify updated subscribe", logger.Field("userSubscribeId", sub.Id), logger.Field("status", sub.Status))
} }
return nil return nil
}) })

View File

@ -0,0 +1,72 @@
package apple
import (
"context"
"github.com/perfect-panel/server/internal/model/payment"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
iapapple "github.com/perfect-panel/server/pkg/iap/apple"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
)
type AttachTransactionByIdLogic struct {
logger.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewAttachTransactionByIdLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AttachTransactionByIdLogic {
return &AttachTransactionByIdLogic{
Logger: logger.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactionByIdRequest) (*types.AttachAppleTransactionResponse, error) {
_, ok := l.ctx.Value(constant.CtxKeyUser).(*struct{ Id int64 })
if !ok {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access")
}
ord, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist")
}
pay, err := l.svcCtx.PaymentModel.FindOne(l.ctx, ord.PaymentId)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.PaymentMethodNotFound), "payment not found")
}
var cfg payment.AppleIAPConfig
if err := cfg.Unmarshal([]byte(pay.Config)); err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "iap config error")
}
apiCfg := iapapple.ServerAPIConfig{
KeyID: cfg.KeyID,
IssuerID: cfg.IssuerID,
PrivateKey: cfg.PrivateKey,
Sandbox: cfg.Sandbox,
}
if req.Sandbox != nil {
apiCfg.Sandbox = *req.Sandbox
}
if apiCfg.KeyID == "" || apiCfg.IssuerID == "" || apiCfg.PrivateKey == "" {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "apple server api credential missing")
}
jws, err := iapapple.GetTransactionInfo(apiCfg, req.TransactionId)
if err != nil {
l.Errorw("fetch transaction info error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "fetch transaction info error")
}
// reuse existing attach logic with JWS
attach := NewAttachTransactionLogic(l.ctx, l.svcCtx)
return attach.Attach(&types.AttachAppleTransactionRequest{
SignedTransactionJWS: jws,
SubscribeId: 0,
DurationDays: 0,
Tier: "",
OrderNo: req.OrderNo,
})
}

View File

@ -127,3 +127,26 @@ func (l *CryptoSaaSConfig) Unmarshal(data []byte) error {
aux := (*Alias)(l) aux := (*Alias)(l)
return json.Unmarshal(data, &aux) return json.Unmarshal(data, &aux)
} }
type AppleIAPConfig struct {
ProductIds []string `json:"product_ids"`
KeyID string `json:"key_id"`
IssuerID string `json:"issuer_id"`
PrivateKey string `json:"private_key"`
Sandbox bool `json:"sandbox"`
}
func (l *AppleIAPConfig) Marshal() ([]byte, error) {
type Alias AppleIAPConfig
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(l),
})
}
func (l *AppleIAPConfig) Unmarshal(data []byte) error {
type Alias AppleIAPConfig
aux := (*Alias)(l)
return json.Unmarshal(data, &aux)
}

View File

@ -2873,6 +2873,11 @@ type AttachAppleTransactionResponse struct {
Tier string `json:"tier"` Tier string `json:"tier"`
} }
type AttachAppleTransactionByIdRequest struct {
TransactionId string `json:"transaction_id" validate:"required"`
OrderNo string `json:"order_no" validate:"required"`
Sandbox *bool `json:"sandbox,omitempty"`
}
type RestoreAppleTransactionsRequest struct { type RestoreAppleTransactionsRequest struct {
Transactions []string `json:"transactions" validate:"required"` Transactions []string `json:"transactions" validate:"required"`
} }

View File

@ -8,14 +8,25 @@ import (
"encoding/json" "encoding/json"
"strings" "strings"
"time" "time"
"unicode"
) )
func cleanB64(s string) string {
trimmed := strings.TrimSpace(s)
return strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return -1
}
return r
}, trimmed)
}
func ParseTransactionJWS(jws string) (*TransactionPayload, error) { func ParseTransactionJWS(jws string) (*TransactionPayload, error) {
parts := strings.Split(jws, ".") parts := strings.Split(strings.TrimSpace(jws), ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
payloadB64 := parts[1] payloadB64 := cleanB64(parts[1])
// add padding if required // add padding if required
switch len(payloadB64) % 4 { switch len(payloadB64) % 4 {
case 2: case 2:
@ -63,11 +74,11 @@ type jwsHeader struct {
} }
func VerifyTransactionJWS(jws string) (*TransactionPayload, error) { func VerifyTransactionJWS(jws string) (*TransactionPayload, error) {
parts := strings.Split(jws, ".") parts := strings.Split(strings.TrimSpace(jws), ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
hdrB64 := parts[0] hdrB64 := cleanB64(parts[0])
switch len(hdrB64) % 4 { switch len(hdrB64) % 4 {
case 2: case 2:
hdrB64 += "==" hdrB64 += "=="
@ -97,13 +108,14 @@ func VerifyTransactionJWS(jws string) (*TransactionPayload, error) {
if !ok { if !ok {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
signingInput := parts[0] + "." + parts[1] signingInput := cleanB64(parts[0]) + "." + cleanB64(parts[1])
sig, err := base64.RawURLEncoding.DecodeString(parts[2]) sig := cleanB64(parts[2])
sigBytes, err := base64.RawURLEncoding.DecodeString(sig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := sha256.Sum256([]byte(signingInput)) d := sha256.Sum256([]byte(signingInput))
if !ecdsa.VerifyASN1(pub, d[:], sig) { if !ecdsa.VerifyASN1(pub, d[:], sigBytes) {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
return ParseTransactionJWS(jws) return ParseTransactionJWS(jws)

View File

@ -12,11 +12,11 @@ type NotificationEnvelope struct {
} }
func ParseNotificationSignedPayload(jws string) (*NotificationEnvelope, error) { func ParseNotificationSignedPayload(jws string) (*NotificationEnvelope, error) {
parts := strings.Split(jws, ".") parts := strings.Split(strings.TrimSpace(jws), ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, ErrInvalidJWS return nil, ErrInvalidJWS
} }
payloadB64 := parts[1] payloadB64 := cleanB64(parts[1])
switch len(payloadB64) % 4 { switch len(payloadB64) % 4 {
case 2: case 2:
payloadB64 += "==" payloadB64 += "=="

112
pkg/iap/apple/serverapi.go Normal file
View File

@ -0,0 +1,112 @@
package apple
import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"time"
)
type ServerAPIConfig struct {
KeyID string
IssuerID string
PrivateKey string
Sandbox bool
}
func buildAPIToken(cfg ServerAPIConfig) (string, error) {
header := map[string]string{
"alg": "ES256",
"kid": cfg.KeyID,
"typ": "JWT",
}
now := time.Now().Unix()
payload := map[string]interface{}{
"iss": cfg.IssuerID,
"iat": now,
"exp": now + 1800,
"aud": "appstoreconnect-v1",
}
hb, _ := json.Marshal(header)
pb, _ := json.Marshal(payload)
enc := func(b []byte) string {
return base64.RawURLEncoding.EncodeToString(b)
}
unsigned := fmt.Sprintf("%s.%s", enc(hb), enc(pb))
block, _ := pem.Decode([]byte(cfg.PrivateKey))
if block == nil {
return "", fmt.Errorf("invalid private key")
}
keyAny, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
priv, ok := keyAny.(*ecdsa.PrivateKey)
if !ok {
return "", fmt.Errorf("private key is not ECDSA")
}
hash := unsigned // ES256 signs SHA-256 of input; jwt libs do hashing, we implement manually
digest := sha256Sum([]byte(hash))
sig, err := ecdsa.SignASN1(rand.Reader, priv, digest)
if err != nil {
return "", err
}
return unsigned + "." + base64.RawURLEncoding.EncodeToString(sig), nil
}
func sha256Sum(b []byte) []byte {
h := sha256.New()
h.Write(b)
return h.Sum(nil)
}
func GetTransactionInfo(cfg ServerAPIConfig, transactionId string) (string, error) {
token, err := buildAPIToken(cfg)
if err != nil {
return "", err
}
try := func(host string) (string, int, string, error) {
url := fmt.Sprintf("%s/inApps/v1/transactions/%s", host, transactionId)
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", 0, "", err
}
defer resp.Body.Close()
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
if resp.StatusCode != 200 {
return "", resp.StatusCode, buf.String(), fmt.Errorf("apple api error: %d", resp.StatusCode)
}
var body struct {
SignedTransactionInfo string `json:"signedTransactionInfo"`
}
if err := json.Unmarshal(buf.Bytes(), &body); err != nil {
return "", resp.StatusCode, buf.String(), err
}
return body.SignedTransactionInfo, resp.StatusCode, buf.String(), nil
}
primary := "https://api.storekit.itunes.apple.com"
secondary := "https://api.storekit-sandbox.itunes.apple.com"
if cfg.Sandbox {
primary, secondary = secondary, primary
}
jws, code, body, err := try(primary)
if err == nil && jws != "" {
return jws, nil
}
// Fallback to the other environment if primary failed (common when env mismatches)
jws2, code2, body2, err2 := try(secondary)
if err2 == nil && jws2 != "" {
return jws2, nil
}
return "", fmt.Errorf("apple api error, primary[%d:%s], secondary[%d:%s]", code, body, code2, body2)
}