fix(iap): 修复JWS验证逻辑,支持原始R||S签名格式
All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 6m38s

fix(middleware): 增加设备中间件的日志记录
fix(auth): 优化认证中间件的错误日志记录
feat(iap): 添加苹果交易附加逻辑的详细日志
This commit is contained in:
shanshanzhong 2025-12-15 23:44:55 -08:00
parent 3c6dd5058b
commit e11ed2338d
4 changed files with 75 additions and 31 deletions

View File

@ -2,7 +2,10 @@ package apple
import (
"context"
"strings"
"github.com/perfect-panel/server/internal/model/payment"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
@ -27,20 +30,34 @@ func NewAttachTransactionByIdLogic(ctx context.Context, svcCtx *svc.ServiceConte
}
func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactionByIdRequest) (*types.AttachAppleTransactionResponse, error) {
_, ok := l.ctx.Value(constant.CtxKeyUser).(*struct{ Id int64 })
if !ok {
l.Infow("attach by transaction id start", logger.Field("orderNo", req.OrderNo), logger.Field("transactionId", req.TransactionId))
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
if !ok || u == nil {
l.Errorw("attach by id invalid access")
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access")
}
ord, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
if err != nil {
l.Errorw("attach by id order not exist", logger.Field("orderNo", req.OrderNo))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist")
}
pay, err := l.svcCtx.PaymentModel.FindOne(l.ctx, ord.PaymentId)
if err != nil {
l.Errorw("attach by id payment not found", logger.Field("paymentId", ord.PaymentId))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.PaymentMethodNotFound), "payment not found")
}
hasKey := false
if pay.Config != "" && (strings.Contains(pay.Config, "-----BEGIN PRIVATE KEY-----") || strings.Contains(pay.Config, "BEGIN PRIVATE KEY")) {
hasKey = true
}
l.Infow("attach by id payment config meta", logger.Field("paymentId", pay.Id), logger.Field("platform", pay.Platform), logger.Field("config_len", len(pay.Config)), logger.Field("has_private_key", hasKey))
if pay.Config == "" {
l.Errorw("attach by id iap config empty", logger.Field("paymentId", pay.Id))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "iap config is empty")
}
var cfg payment.AppleIAPConfig
if err := cfg.Unmarshal([]byte(pay.Config)); err != nil {
l.Errorw("attach by id iap config error", logger.Field("error", err.Error()), logger.Field("paymentId", pay.Id), logger.Field("platform", pay.Platform), logger.Field("config_len", len(pay.Config)), logger.Field("has_private_key", hasKey))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "iap config error")
}
apiCfg := iapapple.ServerAPIConfig{
@ -53,6 +70,7 @@ func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactio
apiCfg.Sandbox = *req.Sandbox
}
if apiCfg.KeyID == "" || apiCfg.IssuerID == "" || apiCfg.PrivateKey == "" {
l.Errorw("attach by id credential missing")
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "apple server api credential missing")
}
jws, err := iapapple.GetTransactionInfo(apiCfg, req.TransactionId)
@ -62,11 +80,17 @@ func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactio
}
// reuse existing attach logic with JWS
attach := NewAttachTransactionLogic(l.ctx, l.svcCtx)
return attach.Attach(&types.AttachAppleTransactionRequest{
resp, e := attach.Attach(&types.AttachAppleTransactionRequest{
SignedTransactionJWS: jws,
SubscribeId: 0,
DurationDays: 0,
Tier: "",
OrderNo: req.OrderNo,
})
if e != nil {
l.Errorw("attach by id commit error", logger.Field("error", e.Error()))
return nil, e
}
l.Infow("attach by transaction id ok", logger.Field("orderNo", req.OrderNo), logger.Field("transactionId", req.TransactionId), logger.Field("expiresAt", resp.ExpiresAt))
return resp, nil
}

View File

@ -27,7 +27,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
// get token from header
token := c.GetHeader("Authorization")
if token == "" {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Token Empty")
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] token empty", logger.Field("path", c.Request.URL.Path))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.ErrorTokenEmpty), "Token Empty"))
c.Abort()
return
@ -35,7 +35,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
// parse token
claims, err := jwt.ParseJwtToken(token, jwtConfig.AccessSecret)
if err != nil {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] ParseJwtToken", logger.Field("error", err.Error()), logger.Field("token", token))
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] parse token failed", logger.Field("error", err.Error()))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.ErrorTokenExpire), "Token Invalid"))
c.Abort()
return
@ -53,7 +53,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId)
value, err := svc.Redis.Get(c, sessionIdCacheKey).Result()
if err != nil {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Redis Get", logger.Field("error", err.Error()), logger.Field("sessionId", sessionId))
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] redis get failed", logger.Field("error", err.Error()), logger.Field("sessionId", sessionId))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access"))
c.Abort()
return
@ -61,7 +61,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
//verify user id
if value != fmt.Sprintf("%v", userId) {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Invalid Access", logger.Field("userId", userId), logger.Field("sessionId", sessionId))
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] user mismatch", logger.Field("userId", userId), logger.Field("sessionId", sessionId), logger.Field("value", value))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access"))
c.Abort()
return
@ -69,7 +69,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
userInfo, err := svc.UserModel.FindOne(c, userId)
if err != nil {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] UserModel FindOne", logger.Field("error", err.Error()), logger.Field("userId", userId))
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] user find failed", logger.Field("error", err.Error()), logger.Field("userId", userId))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Database Query Error"))
c.Abort()
return
@ -77,11 +77,12 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) {
// admin verify
paths := strings.Split(c.Request.URL.Path, "/")
if tool.StringSliceContains(paths, "admin") && !*userInfo.IsAdmin {
logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Not Admin User", logger.Field("userId", userId), logger.Field("sessionId", sessionId))
logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] not admin", logger.Field("userId", userId), logger.Field("sessionId", sessionId))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access"))
c.Abort()
return
}
logger.WithContext(c.Request.Context()).Infow("[AuthMiddleware] auth ok", logger.Field("userId", userId), logger.Field("loginType", loginType), logger.Field("path", c.Request.URL.Path))
ctx = context.WithValue(ctx, constant.LoginType, loginType)
ctx = context.WithValue(ctx, constant.CtxKeyUser, userInfo)
ctx = context.WithValue(ctx, constant.CtxKeySessionID, sessionId)

View File

@ -14,6 +14,7 @@ import (
"github.com/perfect-panel/server/internal/svc"
pkgaes "github.com/perfect-panel/server/pkg/aes"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/result"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
@ -30,11 +31,13 @@ func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) {
return func(c *gin.Context) {
if !srvCtx.Config.Device.Enable {
logger.WithContext(c.Request.Context()).Infow("[DeviceMiddleware] disabled")
c.Next()
return
}
if srvCtx.Config.Device.SecuritySecret == "" {
logger.WithContext(c.Request.Context()).Errorw("[DeviceMiddleware] secret empty")
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.SecretIsEmpty), "Secret is empty"))
c.Abort()
return
@ -48,12 +51,14 @@ func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) {
loginType, ok := ctx.Value(constant.LoginType).(string)
if !ok || loginType != "device" {
logger.WithContext(c.Request.Context()).Infow("[DeviceMiddleware] skip encryption", logger.Field("loginType", loginType))
c.Next()
return
}
rw := NewResponseWriter(c, srvCtx)
if !rw.Decrypt() {
logger.WithContext(c.Request.Context()).Errorw("[DeviceMiddleware] decrypt failed", logger.Field("path", c.Request.URL.Path))
result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidCiphertext), "Invalid ciphertext"))
c.Abort()
return
@ -125,6 +130,7 @@ func (rw *ResponseWriter) Decrypt() bool {
params := map[string]interface{}{}
err = json.Unmarshal([]byte(decrypt), &params)
if err == nil {
logger.WithContext(rw.c.Request.Context()).Infow("[DeviceMiddleware] query decrypt ok", logger.Field("path", rw.c.Request.URL.Path))
for k, v := range params {
query.Set(k, fmt.Sprintf("%v", v))
}
@ -151,23 +157,28 @@ func (rw *ResponseWriter) Decrypt() bool {
data := params["data"]
nonce := params["time"]
if err != nil || data == nil {
logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body parse failed", logger.Field("error", err))
return false
}
str, ok := data.(string)
if !ok {
logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body data type invalid")
return false
}
iv, ok := nonce.(string)
if !ok {
logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body time type invalid")
return false
}
decrypt, err := pkgaes.Decrypt(str, rw.encryptionKey, iv)
if err != nil {
logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body decrypt error", logger.Field("error", err.Error()))
return false
}
rw.c.Request.Body = io.NopCloser(bytes.NewBuffer([]byte(decrypt)))
logger.WithContext(rw.c.Request.Context()).Infow("[DeviceMiddleware] body decrypt ok", logger.Field("path", rw.c.Request.URL.Path))
return true
}

View File

@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"math/big"
"strings"
"time"
"unicode"
@ -21,20 +22,26 @@ func cleanB64(s string) string {
}, trimmed)
}
func decodeB64URL(s string) ([]byte, error) {
s = cleanB64(s)
if b, err := base64.RawURLEncoding.DecodeString(s); err == nil {
return b, nil
}
switch len(s) % 4 {
case 2:
s += "=="
case 3:
s += "="
}
return base64.URLEncoding.DecodeString(s)
}
func ParseTransactionJWS(jws string) (*TransactionPayload, error) {
parts := strings.Split(strings.TrimSpace(jws), ".")
if len(parts) != 3 {
return nil, ErrInvalidJWS
}
payloadB64 := cleanB64(parts[1])
// add padding if required
switch len(payloadB64) % 4 {
case 2:
payloadB64 += "=="
case 3:
payloadB64 += "="
}
data, err := base64.RawURLEncoding.DecodeString(payloadB64)
data, err := decodeB64URL(parts[1])
if err != nil {
return nil, err
}
@ -78,15 +85,8 @@ func VerifyTransactionJWS(jws string) (*TransactionPayload, error) {
if len(parts) != 3 {
return nil, ErrInvalidJWS
}
hdrB64 := cleanB64(parts[0])
switch len(hdrB64) % 4 {
case 2:
hdrB64 += "=="
case 3:
hdrB64 += "="
}
var hdr jwsHeader
hdrBytes, err := base64.RawURLEncoding.DecodeString(hdrB64)
hdrBytes, err := decodeB64URL(parts[0])
if err != nil {
return nil, err
}
@ -109,14 +109,22 @@ func VerifyTransactionJWS(jws string) (*TransactionPayload, error) {
return nil, ErrInvalidJWS
}
signingInput := cleanB64(parts[0]) + "." + cleanB64(parts[1])
sig := cleanB64(parts[2])
sigBytes, err := base64.RawURLEncoding.DecodeString(sig)
sigBytes, err := decodeB64URL(parts[2])
if err != nil {
return nil, err
}
d := sha256.Sum256([]byte(signingInput))
if !ecdsa.VerifyASN1(pub, d[:], sigBytes) {
return nil, ErrInvalidJWS
// Try ASN.1 signature first
if ecdsa.VerifyASN1(pub, d[:], sigBytes) {
return ParseTransactionJWS(jws)
}
return ParseTransactionJWS(jws)
// Fallback: raw R||S (JWS ES256 uses raw signature)
if len(sigBytes) == 64 {
r := new(big.Int).SetBytes(sigBytes[:32])
s := new(big.Int).SetBytes(sigBytes[32:])
if ecdsa.Verify(pub, d[:], r, s) {
return ParseTransactionJWS(jws)
}
}
return nil, ErrInvalidJWS
}