diff --git a/internal/logic/admin/payment/createPaymentMethodLogic.go b/internal/logic/admin/payment/createPaymentMethodLogic.go index 11441c5..9124086 100644 --- a/internal/logic/admin/payment/createPaymentMethodLogic.go +++ b/internal/logic/admin/payment/createPaymentMethodLogic.go @@ -36,11 +36,13 @@ func NewCreatePaymentMethodLogic(ctx context.Context, svcCtx *svc.ServiceContext } func (l *CreatePaymentMethodLogic) CreatePaymentMethod(req *types.CreatePaymentMethodRequest) (resp *types.PaymentConfig, err error) { - if payment.ParsePlatform(req.Platform) == payment.UNSUPPORTED { + platformType := payment.ParsePlatform(req.Platform) + if platformType == payment.UNSUPPORTED { l.Errorw("unsupported payment platform", logger.Field("mark", req.Platform)) return nil, errors.Wrapf(xerr.NewErrCodeMsg(400, "UNSUPPORTED_PAYMENT_PLATFORM"), "unsupported payment platform: %s", req.Platform) } - config := parsePaymentPlatformConfig(l.ctx, payment.ParsePlatform(req.Platform), req.Config) + req.Platform = platformType.String() + config := parsePaymentPlatformConfig(l.ctx, platformType, req.Config) var paymentMethod = &paymentModel.Payment{ Name: req.Name, Platform: req.Platform, @@ -55,7 +57,7 @@ func (l *CreatePaymentMethodLogic) CreatePaymentMethod(req *types.CreatePaymentM Token: random.KeyNew(8, 1), } err = l.svcCtx.PaymentModel.Transaction(l.ctx, func(tx *gorm.DB) error { - if req.Platform == "Stripe" { + if platformType == payment.Stripe { var cfg paymentModel.StripeConfig if err = cfg.Unmarshal([]byte(paymentMethod.Config)); err != nil { l.Errorf("[CreatePaymentMethod] unmarshal stripe config error: %s", err.Error()) diff --git a/internal/logic/admin/payment/updatePaymentMethodLogic.go b/internal/logic/admin/payment/updatePaymentMethodLogic.go index 7c2dda2..a2a4b1c 100644 --- a/internal/logic/admin/payment/updatePaymentMethodLogic.go +++ b/internal/logic/admin/payment/updatePaymentMethodLogic.go @@ -29,7 +29,8 @@ func NewUpdatePaymentMethodLogic(ctx context.Context, svcCtx *svc.ServiceContext } func (l *UpdatePaymentMethodLogic) UpdatePaymentMethod(req *types.UpdatePaymentMethodRequest) (resp *types.PaymentConfig, err error) { - if payment.ParsePlatform(req.Platform) == payment.UNSUPPORTED { + platformType := payment.ParsePlatform(req.Platform) + if platformType == payment.UNSUPPORTED { l.Errorw("unsupported payment platform", logger.Field("mark", req.Platform)) return nil, errors.Wrapf(xerr.NewErrCodeMsg(400, "UNSUPPORTED_PAYMENT_PLATFORM"), "unsupported payment platform: %s", req.Platform) } @@ -38,7 +39,13 @@ func (l *UpdatePaymentMethodLogic) UpdatePaymentMethod(req *types.UpdatePaymentM l.Errorw("find payment method error", logger.Field("id", req.Id), logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %s", err.Error()) } - config := parsePaymentPlatformConfig(l.ctx, payment.ParsePlatform(req.Platform), req.Config) + existingPlatformType := payment.ParsePlatform(method.Platform) + if existingPlatformType != payment.UNSUPPORTED && existingPlatformType != platformType { + l.Errorw("payment platform mismatch", logger.Field("id", req.Id), logger.Field("current", method.Platform), logger.Field("request", req.Platform)) + return nil, errors.Wrapf(xerr.NewErrCodeMsg(xerr.InvalidParams, "payment platform mismatch"), "payment platform mismatch: %s -> %s", method.Platform, req.Platform) + } + req.Platform = platformType.String() + config := parsePaymentPlatformConfig(l.ctx, platformType, req.Config) tool.DeepCopy(method, req, tool.CopyWithIgnoreEmpty(false)) method.Config = config if err := l.svcCtx.PaymentModel.Update(l.ctx, method); err != nil { diff --git a/internal/logic/public/order/paymentMethod.go b/internal/logic/public/order/paymentMethod.go new file mode 100644 index 0000000..7315342 --- /dev/null +++ b/internal/logic/public/order/paymentMethod.go @@ -0,0 +1,11 @@ +package order + +import paymentPlatform "github.com/perfect-panel/server/pkg/payment" + +func canonicalOrderMethod(method string) string { + platform := paymentPlatform.ParsePlatform(method) + if platform == paymentPlatform.UNSUPPORTED { + return method + } + return platform.String() +} diff --git a/internal/logic/public/order/purchaseLogic.go b/internal/logic/public/order/purchaseLogic.go index 4c1c3c3..404eb96 100644 --- a/internal/logic/public/order/purchaseLogic.go +++ b/internal/logic/public/order/purchaseLogic.go @@ -224,7 +224,7 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P Coupon: req.Coupon, CouponDiscount: coupon, PaymentId: payment.Id, - Method: payment.Platform, + Method: canonicalOrderMethod(payment.Platform), FeeAmount: feeAmount, Status: 1, IsNew: isNew, diff --git a/internal/logic/public/order/rechargeLogic.go b/internal/logic/public/order/rechargeLogic.go index a27053c..c84edfd 100644 --- a/internal/logic/public/order/rechargeLogic.go +++ b/internal/logic/public/order/rechargeLogic.go @@ -88,7 +88,7 @@ func (l *RechargeLogic) Recharge(req *types.RechargeOrderRequest) (resp *types.R Amount: totalAmount, FeeAmount: feeAmount, PaymentId: payment.Id, - Method: payment.Platform, + Method: canonicalOrderMethod(payment.Platform), Status: 1, IsNew: isNew, } diff --git a/internal/logic/public/order/renewalLogic.go b/internal/logic/public/order/renewalLogic.go index 18766c7..1898652 100644 --- a/internal/logic/public/order/renewalLogic.go +++ b/internal/logic/public/order/renewalLogic.go @@ -175,7 +175,7 @@ func (l *RenewalLogic) Renewal(req *types.RenewalOrderRequest) (resp *types.Rene Coupon: req.Coupon, CouponDiscount: coupon, PaymentId: payment.Id, - Method: payment.Platform, + Method: canonicalOrderMethod(payment.Platform), FeeAmount: feeAmount, Status: 1, SubscribeId: userSubscribe.SubscribeId, diff --git a/internal/logic/public/order/resetTrafficLogic.go b/internal/logic/public/order/resetTrafficLogic.go index 1fc9b57..a3a0669 100644 --- a/internal/logic/public/order/resetTrafficLogic.go +++ b/internal/logic/public/order/resetTrafficLogic.go @@ -90,7 +90,7 @@ func (l *ResetTrafficLogic) ResetTraffic(req *types.ResetTrafficOrderRequest) (r GiftAmount: deductionAmount, FeeAmount: feeAmount, PaymentId: payment.Id, - Method: payment.Platform, + Method: canonicalOrderMethod(payment.Platform), Status: 1, SubscribeId: userSubscribe.SubscribeId, SubscribeToken: userSubscribe.Token, diff --git a/internal/middleware/deviceMiddleware.go b/internal/middleware/deviceMiddleware.go index 7b7c385..f36e62b 100644 --- a/internal/middleware/deviceMiddleware.go +++ b/internal/middleware/deviceMiddleware.go @@ -89,6 +89,8 @@ func NewResponseWriter(c *gin.Context, srvCtx *svc.ServiceContext) (rw *Response rw = &ResponseWriter{ c: c, body: new(bytes.Buffer), + size: noWritten, + status: defaultStatus, ResponseWriter: c.Writer, } rw.encryptionKey = srvCtx.Config.Device.SecuritySecret diff --git a/pkg/payment/platform.go b/pkg/payment/platform.go index ff092ef..5e18088 100644 --- a/pkg/payment/platform.go +++ b/pkg/payment/platform.go @@ -1,6 +1,10 @@ package payment -import "github.com/perfect-panel/server/internal/types" +import ( + "strings" + + "github.com/perfect-panel/server/internal/types" +) type Platform int @@ -14,32 +18,74 @@ const ( UNSUPPORTED Platform = -1 ) -var platformNames = map[string]Platform{ - "CryptoSaaS": CryptoSaaS, - "Stripe": Stripe, - "AlipayF2F": AlipayF2F, - "EPay": EPay, - "AppleIAP": AppleIAP, +const ( + platformNameStripe = "Stripe" + platformNameAlipayF2F = "AlipayF2F" + platformNameEPay = "EPay" + platformNameBalance = "balance" + platformNameCryptoSaaS = "CryptoSaaS" + platformNameAppleIAP = "AppleIAP" + platformNameUnsupported = "unsupported" +) + +var platformAliasToType = map[string]Platform{ + "stripe": Stripe, + "alipayf2f": AlipayF2F, + "alipay_f2f": AlipayF2F, + "epay": EPay, "balance": Balance, - "unsupported": UNSUPPORTED, + "cryptosaas": CryptoSaaS, + "crypto_saas": CryptoSaaS, + "appleiap": AppleIAP, + "apple_iap": AppleIAP, } func (p Platform) String() string { - for k, v := range platformNames { - if v == p { - return k - } + switch p { + case Stripe: + return platformNameStripe + case AlipayF2F: + return platformNameAlipayF2F + case EPay: + return platformNameEPay + case Balance: + return platformNameBalance + case CryptoSaaS: + return platformNameCryptoSaaS + case AppleIAP: + return platformNameAppleIAP + default: + return platformNameUnsupported } - return "unsupported" } func ParsePlatform(s string) Platform { - if p, ok := platformNames[s]; ok { + normalized := normalizePlatformAlias(s) + if p, ok := platformAliasToType[normalized]; ok { + return p + } + compact := strings.ReplaceAll(normalized, "_", "") + if p, ok := platformAliasToType[compact]; ok { return p } return UNSUPPORTED } +func CanonicalPlatformName(input string) (string, bool) { + platform := ParsePlatform(input) + if platform == UNSUPPORTED { + return "", false + } + return platform.String(), true +} + +func normalizePlatformAlias(input string) string { + normalized := strings.ToLower(strings.TrimSpace(input)) + normalized = strings.ReplaceAll(normalized, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "") + return normalized +} + func GetSupportedPlatforms() []types.PlatformInfo { return []types.PlatformInfo{ { diff --git a/pkg/payment/platform_test.go b/pkg/payment/platform_test.go new file mode 100644 index 0000000..95ba93a --- /dev/null +++ b/pkg/payment/platform_test.go @@ -0,0 +1,69 @@ +package payment + +import "testing" + +func TestParsePlatform(t *testing.T) { + testCases := []struct { + name string + input string + expected Platform + }{ + {name: "exact AppleIAP", input: "AppleIAP", expected: AppleIAP}, + {name: "snake apple_iap", input: "apple_iap", expected: AppleIAP}, + {name: "kebab apple-iap", input: "apple-iap", expected: AppleIAP}, + {name: "compact appleiap", input: "appleiap", expected: AppleIAP}, + {name: "trimmed value", input: " apple_iap ", expected: AppleIAP}, + {name: "legacy exact CryptoSaaS", input: "CryptoSaaS", expected: CryptoSaaS}, + {name: "snake crypto_saas", input: "crypto_saas", expected: CryptoSaaS}, + {name: "unsupported", input: "unknown_gateway", expected: UNSUPPORTED}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + got := ParsePlatform(testCase.input) + if got != testCase.expected { + t.Fatalf("ParsePlatform(%q) = %v, expected %v", testCase.input, got, testCase.expected) + } + }) + } +} + +func TestPlatformStringIsCanonical(t *testing.T) { + testCases := []struct { + name string + input Platform + expected string + }{ + {name: "stripe", input: Stripe, expected: "Stripe"}, + {name: "alipay", input: AlipayF2F, expected: "AlipayF2F"}, + {name: "epay", input: EPay, expected: "EPay"}, + {name: "balance", input: Balance, expected: "balance"}, + {name: "crypto", input: CryptoSaaS, expected: "CryptoSaaS"}, + {name: "apple", input: AppleIAP, expected: "AppleIAP"}, + {name: "unsupported", input: UNSUPPORTED, expected: "unsupported"}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + got := testCase.input.String() + if got != testCase.expected { + t.Fatalf("Platform.String() = %q, expected %q", got, testCase.expected) + } + }) + } +} + +func TestCanonicalPlatformName(t *testing.T) { + canonical, ok := CanonicalPlatformName("apple_iap") + if !ok { + t.Fatalf("expected apple_iap to be supported") + } + if canonical != "AppleIAP" { + t.Fatalf("canonical name mismatch: got %q", canonical) + } + + _, ok = CanonicalPlatformName("not_exists") + if ok { + t.Fatalf("expected unsupported platform to return ok=false") + } +}