diff --git a/internal/logic/public/iap/apple/attachTransactionByIdLogic.go b/internal/logic/public/iap/apple/attachTransactionByIdLogic.go index d245620..5d3ca00 100644 --- a/internal/logic/public/iap/apple/attachTransactionByIdLogic.go +++ b/internal/logic/public/iap/apple/attachTransactionByIdLogic.go @@ -2,7 +2,10 @@ package apple import ( "context" + "strings" + "github.com/perfect-panel/server/internal/model/payment" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/constant" @@ -27,20 +30,34 @@ func NewAttachTransactionByIdLogic(ctx context.Context, svcCtx *svc.ServiceConte } func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactionByIdRequest) (*types.AttachAppleTransactionResponse, error) { - _, ok := l.ctx.Value(constant.CtxKeyUser).(*struct{ Id int64 }) - if !ok { + l.Infow("attach by transaction id start", logger.Field("orderNo", req.OrderNo), logger.Field("transactionId", req.TransactionId)) + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok || u == nil { + l.Errorw("attach by id invalid access") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") } ord, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo) if err != nil { + l.Errorw("attach by id order not exist", logger.Field("orderNo", req.OrderNo)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist") } pay, err := l.svcCtx.PaymentModel.FindOne(l.ctx, ord.PaymentId) if err != nil { + l.Errorw("attach by id payment not found", logger.Field("paymentId", ord.PaymentId)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.PaymentMethodNotFound), "payment not found") } + hasKey := false + if pay.Config != "" && (strings.Contains(pay.Config, "-----BEGIN PRIVATE KEY-----") || strings.Contains(pay.Config, "BEGIN PRIVATE KEY")) { + hasKey = true + } + l.Infow("attach by id payment config meta", logger.Field("paymentId", pay.Id), logger.Field("platform", pay.Platform), logger.Field("config_len", len(pay.Config)), logger.Field("has_private_key", hasKey)) + if pay.Config == "" { + l.Errorw("attach by id iap config empty", logger.Field("paymentId", pay.Id)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "iap config is empty") + } var cfg payment.AppleIAPConfig if err := cfg.Unmarshal([]byte(pay.Config)); err != nil { + l.Errorw("attach by id iap config error", logger.Field("error", err.Error()), logger.Field("paymentId", pay.Id), logger.Field("platform", pay.Platform), logger.Field("config_len", len(pay.Config)), logger.Field("has_private_key", hasKey)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "iap config error") } apiCfg := iapapple.ServerAPIConfig{ @@ -53,6 +70,7 @@ func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactio apiCfg.Sandbox = *req.Sandbox } if apiCfg.KeyID == "" || apiCfg.IssuerID == "" || apiCfg.PrivateKey == "" { + l.Errorw("attach by id credential missing") return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "apple server api credential missing") } jws, err := iapapple.GetTransactionInfo(apiCfg, req.TransactionId) @@ -62,11 +80,17 @@ func (l *AttachTransactionByIdLogic) AttachById(req *types.AttachAppleTransactio } // reuse existing attach logic with JWS attach := NewAttachTransactionLogic(l.ctx, l.svcCtx) - return attach.Attach(&types.AttachAppleTransactionRequest{ + resp, e := attach.Attach(&types.AttachAppleTransactionRequest{ SignedTransactionJWS: jws, SubscribeId: 0, DurationDays: 0, Tier: "", OrderNo: req.OrderNo, }) + if e != nil { + l.Errorw("attach by id commit error", logger.Field("error", e.Error())) + return nil, e + } + l.Infow("attach by transaction id ok", logger.Field("orderNo", req.OrderNo), logger.Field("transactionId", req.TransactionId), logger.Field("expiresAt", resp.ExpiresAt)) + return resp, nil } diff --git a/internal/middleware/authMiddleware.go b/internal/middleware/authMiddleware.go index fbf3758..e44206b 100644 --- a/internal/middleware/authMiddleware.go +++ b/internal/middleware/authMiddleware.go @@ -27,7 +27,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { // get token from header token := c.GetHeader("Authorization") if token == "" { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Token Empty") + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] token empty", logger.Field("path", c.Request.URL.Path)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.ErrorTokenEmpty), "Token Empty")) c.Abort() return @@ -35,7 +35,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { // parse token claims, err := jwt.ParseJwtToken(token, jwtConfig.AccessSecret) if err != nil { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] ParseJwtToken", logger.Field("error", err.Error()), logger.Field("token", token)) + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] parse token failed", logger.Field("error", err.Error())) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.ErrorTokenExpire), "Token Invalid")) c.Abort() return @@ -53,7 +53,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) value, err := svc.Redis.Get(c, sessionIdCacheKey).Result() if err != nil { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Redis Get", logger.Field("error", err.Error()), logger.Field("sessionId", sessionId)) + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] redis get failed", logger.Field("error", err.Error()), logger.Field("sessionId", sessionId)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")) c.Abort() return @@ -61,7 +61,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { //verify user id if value != fmt.Sprintf("%v", userId) { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Invalid Access", logger.Field("userId", userId), logger.Field("sessionId", sessionId)) + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] user mismatch", logger.Field("userId", userId), logger.Field("sessionId", sessionId), logger.Field("value", value)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")) c.Abort() return @@ -69,7 +69,7 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { userInfo, err := svc.UserModel.FindOne(c, userId) if err != nil { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] UserModel FindOne", logger.Field("error", err.Error()), logger.Field("userId", userId)) + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] user find failed", logger.Field("error", err.Error()), logger.Field("userId", userId)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Database Query Error")) c.Abort() return @@ -77,11 +77,12 @@ func AuthMiddleware(svc *svc.ServiceContext) func(c *gin.Context) { // admin verify paths := strings.Split(c.Request.URL.Path, "/") if tool.StringSliceContains(paths, "admin") && !*userInfo.IsAdmin { - logger.WithContext(c.Request.Context()).Debug("[AuthMiddleware] Not Admin User", logger.Field("userId", userId), logger.Field("sessionId", sessionId)) + logger.WithContext(c.Request.Context()).Errorw("[AuthMiddleware] not admin", logger.Field("userId", userId), logger.Field("sessionId", sessionId)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")) c.Abort() return } + logger.WithContext(c.Request.Context()).Infow("[AuthMiddleware] auth ok", logger.Field("userId", userId), logger.Field("loginType", loginType), logger.Field("path", c.Request.URL.Path)) ctx = context.WithValue(ctx, constant.LoginType, loginType) ctx = context.WithValue(ctx, constant.CtxKeyUser, userInfo) ctx = context.WithValue(ctx, constant.CtxKeySessionID, sessionId) diff --git a/internal/middleware/deviceMiddleware.go b/internal/middleware/deviceMiddleware.go index b66ccb0..19e2b3c 100644 --- a/internal/middleware/deviceMiddleware.go +++ b/internal/middleware/deviceMiddleware.go @@ -14,6 +14,7 @@ import ( "github.com/perfect-panel/server/internal/svc" pkgaes "github.com/perfect-panel/server/pkg/aes" "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/result" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" @@ -30,11 +31,13 @@ func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) { return func(c *gin.Context) { if !srvCtx.Config.Device.Enable { + logger.WithContext(c.Request.Context()).Infow("[DeviceMiddleware] disabled") c.Next() return } if srvCtx.Config.Device.SecuritySecret == "" { + logger.WithContext(c.Request.Context()).Errorw("[DeviceMiddleware] secret empty") result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.SecretIsEmpty), "Secret is empty")) c.Abort() return @@ -48,12 +51,14 @@ func DeviceMiddleware(srvCtx *svc.ServiceContext) func(c *gin.Context) { loginType, ok := ctx.Value(constant.LoginType).(string) if !ok || loginType != "device" { + logger.WithContext(c.Request.Context()).Infow("[DeviceMiddleware] skip encryption", logger.Field("loginType", loginType)) c.Next() return } rw := NewResponseWriter(c, srvCtx) if !rw.Decrypt() { + logger.WithContext(c.Request.Context()).Errorw("[DeviceMiddleware] decrypt failed", logger.Field("path", c.Request.URL.Path)) result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidCiphertext), "Invalid ciphertext")) c.Abort() return @@ -125,6 +130,7 @@ func (rw *ResponseWriter) Decrypt() bool { params := map[string]interface{}{} err = json.Unmarshal([]byte(decrypt), ¶ms) if err == nil { + logger.WithContext(rw.c.Request.Context()).Infow("[DeviceMiddleware] query decrypt ok", logger.Field("path", rw.c.Request.URL.Path)) for k, v := range params { query.Set(k, fmt.Sprintf("%v", v)) } @@ -151,23 +157,28 @@ func (rw *ResponseWriter) Decrypt() bool { data := params["data"] nonce := params["time"] if err != nil || data == nil { + logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body parse failed", logger.Field("error", err)) return false } str, ok := data.(string) if !ok { + logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body data type invalid") return false } iv, ok := nonce.(string) if !ok { + logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body time type invalid") return false } decrypt, err := pkgaes.Decrypt(str, rw.encryptionKey, iv) if err != nil { + logger.WithContext(rw.c.Request.Context()).Errorw("[DeviceMiddleware] body decrypt error", logger.Field("error", err.Error())) return false } rw.c.Request.Body = io.NopCloser(bytes.NewBuffer([]byte(decrypt))) + logger.WithContext(rw.c.Request.Context()).Infow("[DeviceMiddleware] body decrypt ok", logger.Field("path", rw.c.Request.URL.Path)) return true } diff --git a/pkg/iap/apple/jws.go b/pkg/iap/apple/jws.go index 7f93a9e..18aaf41 100644 --- a/pkg/iap/apple/jws.go +++ b/pkg/iap/apple/jws.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" + "math/big" "strings" "time" "unicode" @@ -21,20 +22,26 @@ func cleanB64(s string) string { }, trimmed) } +func decodeB64URL(s string) ([]byte, error) { + s = cleanB64(s) + if b, err := base64.RawURLEncoding.DecodeString(s); err == nil { + return b, nil + } + switch len(s) % 4 { + case 2: + s += "==" + case 3: + s += "=" + } + return base64.URLEncoding.DecodeString(s) +} + func ParseTransactionJWS(jws string) (*TransactionPayload, error) { parts := strings.Split(strings.TrimSpace(jws), ".") if len(parts) != 3 { return nil, ErrInvalidJWS } - payloadB64 := cleanB64(parts[1]) - // add padding if required - switch len(payloadB64) % 4 { - case 2: - payloadB64 += "==" - case 3: - payloadB64 += "=" - } - data, err := base64.RawURLEncoding.DecodeString(payloadB64) + data, err := decodeB64URL(parts[1]) if err != nil { return nil, err } @@ -78,15 +85,8 @@ func VerifyTransactionJWS(jws string) (*TransactionPayload, error) { if len(parts) != 3 { return nil, ErrInvalidJWS } - hdrB64 := cleanB64(parts[0]) - switch len(hdrB64) % 4 { - case 2: - hdrB64 += "==" - case 3: - hdrB64 += "=" - } var hdr jwsHeader - hdrBytes, err := base64.RawURLEncoding.DecodeString(hdrB64) + hdrBytes, err := decodeB64URL(parts[0]) if err != nil { return nil, err } @@ -109,14 +109,22 @@ func VerifyTransactionJWS(jws string) (*TransactionPayload, error) { return nil, ErrInvalidJWS } signingInput := cleanB64(parts[0]) + "." + cleanB64(parts[1]) - sig := cleanB64(parts[2]) - sigBytes, err := base64.RawURLEncoding.DecodeString(sig) + sigBytes, err := decodeB64URL(parts[2]) if err != nil { return nil, err } d := sha256.Sum256([]byte(signingInput)) - if !ecdsa.VerifyASN1(pub, d[:], sigBytes) { - return nil, ErrInvalidJWS + // Try ASN.1 signature first + if ecdsa.VerifyASN1(pub, d[:], sigBytes) { + return ParseTransactionJWS(jws) } - return ParseTransactionJWS(jws) + // Fallback: raw R||S (JWS ES256 uses raw signature) + if len(sigBytes) == 64 { + r := new(big.Int).SetBytes(sigBytes[:32]) + s := new(big.Int).SetBytes(sigBytes[32:]) + if ecdsa.Verify(pub, d[:], r, s) { + return ParseTransactionJWS(jws) + } + } + return nil, ErrInvalidJWS }