fix(purchase): correct gift amount deduction logic and enhance payment processing comments
This commit is contained in:
parent
76816ca8ea
commit
9691257bad
@ -132,7 +132,7 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
|
|||||||
if u.GiftAmount >= amount {
|
if u.GiftAmount >= amount {
|
||||||
deductionAmount = amount
|
deductionAmount = amount
|
||||||
amount = 0
|
amount = 0
|
||||||
u.GiftAmount -= amount
|
u.GiftAmount -= deductionAmount
|
||||||
} else {
|
} else {
|
||||||
deductionAmount = u.GiftAmount
|
deductionAmount = u.GiftAmount
|
||||||
amount -= u.GiftAmount
|
amount -= u.GiftAmount
|
||||||
|
|||||||
@ -28,13 +28,16 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PurchaseCheckoutLogic handles the checkout process for various payment methods
|
||||||
|
// including EPay, Stripe, Alipay F2F, and balance payments
|
||||||
type PurchaseCheckoutLogic struct {
|
type PurchaseCheckoutLogic struct {
|
||||||
logger.Logger
|
logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
svcCtx *svc.ServiceContext
|
svcCtx *svc.ServiceContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPurchaseCheckoutLogic Purchase Checkout
|
// NewPurchaseCheckoutLogic creates a new instance of PurchaseCheckoutLogic
|
||||||
|
// for handling purchase checkout operations across different payment platforms
|
||||||
func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PurchaseCheckoutLogic {
|
func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PurchaseCheckoutLogic {
|
||||||
return &PurchaseCheckoutLogic{
|
return &PurchaseCheckoutLogic{
|
||||||
Logger: logger.WithContext(ctx),
|
Logger: logger.WithContext(ctx),
|
||||||
@ -43,88 +46,104 @@ func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PurchaseCheckout processes the checkout for an order using the specified payment method
|
||||||
|
// It validates the order, retrieves payment configuration, and routes to the appropriate payment handler
|
||||||
func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest) (resp *types.CheckoutOrderResponse, err error) {
|
func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest) (resp *types.CheckoutOrderResponse, err error) {
|
||||||
// Find order
|
// Validate and retrieve order information
|
||||||
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
|
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Logger.Error("[PurchaseCheckout] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OrderNo))
|
l.Logger.Error("[PurchaseCheckout] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OrderNo))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OrderNo)
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OrderNo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verify order is in pending payment status (status = 1)
|
||||||
if orderInfo.Status != 1 {
|
if orderInfo.Status != 1 {
|
||||||
l.Logger.Error("[PurchaseCheckout] Order status error", logger.Field("status", orderInfo.Status))
|
l.Logger.Error("[PurchaseCheckout] Order status error", logger.Field("status", orderInfo.Status))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderStatusError), "order status error: %v", orderInfo.Status)
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderStatusError), "order status error: %v", orderInfo.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find payment method
|
// Retrieve payment method configuration
|
||||||
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, orderInfo.PaymentId)
|
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, orderInfo.PaymentId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Logger.Error("[Purchase] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
|
l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error())
|
||||||
}
|
}
|
||||||
|
// Route to appropriate payment handler based on payment platform
|
||||||
switch paymentPlatform.ParsePlatform(orderInfo.Method) {
|
switch paymentPlatform.ParsePlatform(orderInfo.Method) {
|
||||||
case paymentPlatform.EPay:
|
case paymentPlatform.EPay:
|
||||||
|
// Process EPay payment - generates payment URL for redirect
|
||||||
url, err := l.epayPayment(paymentConfig, orderInfo, req.ReturnUrl)
|
url, err := l.epayPayment(paymentConfig, orderInfo, req.ReturnUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "epayPayment error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "epayPayment error: %v", err.Error())
|
||||||
}
|
}
|
||||||
resp = &types.CheckoutOrderResponse{
|
resp = &types.CheckoutOrderResponse{
|
||||||
CheckoutUrl: url,
|
CheckoutUrl: url,
|
||||||
Type: "url",
|
Type: "url", // Client should redirect to URL
|
||||||
}
|
}
|
||||||
|
|
||||||
case paymentPlatform.Stripe:
|
case paymentPlatform.Stripe:
|
||||||
|
// Process Stripe payment - creates payment sheet for client-side processing
|
||||||
stripePayment, err := l.stripePayment(paymentConfig.Config, orderInfo, "")
|
stripePayment, err := l.stripePayment(paymentConfig.Config, orderInfo, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "stripePayment error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "stripePayment error: %v", err.Error())
|
||||||
}
|
}
|
||||||
resp = &types.CheckoutOrderResponse{
|
resp = &types.CheckoutOrderResponse{
|
||||||
Type: "stripe",
|
Type: "stripe", // Client should use Stripe SDK
|
||||||
Stripe: stripePayment,
|
Stripe: stripePayment,
|
||||||
}
|
}
|
||||||
|
|
||||||
case paymentPlatform.AlipayF2F:
|
case paymentPlatform.AlipayF2F:
|
||||||
|
// Process Alipay Face-to-Face payment - generates QR code
|
||||||
url, err := l.alipayF2fPayment(paymentConfig, orderInfo)
|
url, err := l.alipayF2fPayment(paymentConfig, orderInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] alipayF2fPayment error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] alipayF2fPayment error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "alipayF2fPayment error: %v", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "alipayF2fPayment error: %v", err.Error())
|
||||||
}
|
}
|
||||||
resp = &types.CheckoutOrderResponse{
|
resp = &types.CheckoutOrderResponse{
|
||||||
Type: "qr",
|
Type: "qr", // Client should display QR code
|
||||||
CheckoutUrl: url,
|
CheckoutUrl: url,
|
||||||
}
|
}
|
||||||
|
|
||||||
case paymentPlatform.Balance:
|
case paymentPlatform.Balance:
|
||||||
|
// Process balance payment - validate user and process payment immediately
|
||||||
if orderInfo.UserId == 0 {
|
if orderInfo.UserId == 0 {
|
||||||
l.Errorw("[CheckoutOrderLogic] user not found", logger.Field("userId", orderInfo.UserId))
|
l.Errorw("[PurchaseCheckout] user not found", logger.Field("userId", orderInfo.UserId))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user not found")
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user not found")
|
||||||
}
|
}
|
||||||
// find user
|
|
||||||
|
// Retrieve user information for balance validation
|
||||||
userInfo, err := l.svcCtx.UserModel.FindOne(l.ctx, orderInfo.UserId)
|
userInfo, err := l.svcCtx.UserModel.FindOne(l.ctx, orderInfo.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] FindOne User error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] FindOne User error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "FindOne error: %s", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "FindOne error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// balance
|
// Process balance payment with gift amount priority logic
|
||||||
if err = l.balancePayment(userInfo, orderInfo); err != nil {
|
if err = l.balancePayment(userInfo, orderInfo); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp = &types.CheckoutOrderResponse{
|
resp = &types.CheckoutOrderResponse{
|
||||||
Type: "balance",
|
Type: "balance", // Payment completed immediately
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
l.Errorw("[CheckoutOrderLogic] payment method not found", logger.Field("method", orderInfo.Method))
|
l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found")
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// alipay f2f payment
|
// alipayF2fPayment processes Alipay Face-to-Face payment by generating a QR code
|
||||||
|
// It handles currency conversion and creates a pre-payment trade for QR code scanning
|
||||||
func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *order.Order) (string, error) {
|
func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *order.Order) (string, error) {
|
||||||
|
// Parse Alipay F2F configuration from payment settings
|
||||||
f2FConfig := payment.AlipayF2FConfig{}
|
f2FConfig := payment.AlipayF2FConfig{}
|
||||||
if err := json.Unmarshal([]byte(pay.Config), &f2FConfig); err != nil {
|
if err := json.Unmarshal([]byte(pay.Config), &f2FConfig); err != nil {
|
||||||
l.Errorw("[PurchaseCheckoutLogic] Unmarshal error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Unmarshal Alipay config error", logger.Field("error", err.Error()))
|
||||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build notification URL for payment status callbacks
|
||||||
notifyUrl := ""
|
notifyUrl := ""
|
||||||
if pay.Domain != "" {
|
if pay.Domain != "" {
|
||||||
notifyUrl = pay.Domain + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
notifyUrl = pay.Domain + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
||||||
@ -135,6 +154,8 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
|
|||||||
}
|
}
|
||||||
notifyUrl = "https://" + host + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
notifyUrl = "https://" + host + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Alipay client with configuration
|
||||||
client := alipay.NewClient(alipay.Config{
|
client := alipay.NewClient(alipay.Config{
|
||||||
AppId: f2FConfig.AppId,
|
AppId: f2FConfig.AppId,
|
||||||
PrivateKey: f2FConfig.PrivateKey,
|
PrivateKey: f2FConfig.PrivateKey,
|
||||||
@ -142,46 +163,53 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
|
|||||||
InvoiceName: f2FConfig.InvoiceName,
|
InvoiceName: f2FConfig.InvoiceName,
|
||||||
NotifyURL: notifyUrl,
|
NotifyURL: notifyUrl,
|
||||||
})
|
})
|
||||||
// Calculate the amount with exchange rate
|
|
||||||
|
// Convert order amount to CNY using current exchange rate
|
||||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
||||||
}
|
}
|
||||||
convertAmount := int64(amount * 100)
|
convertAmount := int64(amount * 100) // Convert to cents for API
|
||||||
// create payment
|
|
||||||
|
// Create pre-payment trade and generate QR code
|
||||||
QRCode, err := client.PreCreateTrade(l.ctx, alipay.Order{
|
QRCode, err := client.PreCreateTrade(l.ctx, alipay.Order{
|
||||||
OrderNo: info.OrderNo,
|
OrderNo: info.OrderNo,
|
||||||
Amount: convertAmount,
|
Amount: convertAmount,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] PreCreateTrade error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] PreCreateTrade error", logger.Field("error", err.Error()))
|
||||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "PreCreateTrade error: %s", err.Error())
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "PreCreateTrade error: %s", err.Error())
|
||||||
}
|
}
|
||||||
return QRCode, nil
|
return QRCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stripe Payment
|
// stripePayment processes Stripe payment by creating a payment sheet
|
||||||
|
// It supports various payment methods including WeChat Pay and Alipay through Stripe
|
||||||
func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, identifier string) (*types.StripePayment, error) {
|
func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, identifier string) (*types.StripePayment, error) {
|
||||||
// stripe WeChat pay or stripe alipay
|
// Parse Stripe configuration from payment settings
|
||||||
stripeConfig := payment.StripeConfig{}
|
stripeConfig := payment.StripeConfig{}
|
||||||
if err := json.Unmarshal([]byte(config), &stripeConfig); err != nil {
|
if err := json.Unmarshal([]byte(config), &stripeConfig); err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Unmarshal Stripe config error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Stripe client with API credentials
|
||||||
client := stripe.NewClient(stripe.Config{
|
client := stripe.NewClient(stripe.Config{
|
||||||
SecretKey: stripeConfig.SecretKey,
|
SecretKey: stripeConfig.SecretKey,
|
||||||
PublicKey: stripeConfig.PublicKey,
|
PublicKey: stripeConfig.PublicKey,
|
||||||
WebhookSecret: stripeConfig.WebhookSecret,
|
WebhookSecret: stripeConfig.WebhookSecret,
|
||||||
})
|
})
|
||||||
// Calculate the amount with exchange rate
|
|
||||||
|
// Convert order amount to CNY using current exchange rate
|
||||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
||||||
}
|
}
|
||||||
convertAmount := int64(amount * 100)
|
convertAmount := int64(amount * 100) // Convert to cents for Stripe API
|
||||||
// create payment
|
|
||||||
|
// Create Stripe payment sheet for client-side processing
|
||||||
result, err := client.CreatePaymentSheet(&stripe.Order{
|
result, err := client.CreatePaymentSheet(&stripe.Order{
|
||||||
OrderNo: info.OrderNo,
|
OrderNo: info.OrderNo,
|
||||||
Subscribe: strconv.FormatInt(info.SubscribeId, 10),
|
Subscribe: strconv.FormatInt(info.SubscribeId, 10),
|
||||||
@ -193,37 +221,47 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order,
|
|||||||
Email: identifier,
|
Email: identifier,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] CreatePaymentSheet error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] CreatePaymentSheet error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "CreatePaymentSheet error: %s", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "CreatePaymentSheet error: %s", err.Error())
|
||||||
}
|
}
|
||||||
tradeNo := result.TradeNo
|
|
||||||
|
// Prepare response data for client-side Stripe integration
|
||||||
stripePayment := &types.StripePayment{
|
stripePayment := &types.StripePayment{
|
||||||
PublishableKey: stripeConfig.PublicKey,
|
PublishableKey: stripeConfig.PublicKey,
|
||||||
ClientSecret: result.ClientSecret,
|
ClientSecret: result.ClientSecret,
|
||||||
Method: stripeConfig.Payment,
|
Method: stripeConfig.Payment,
|
||||||
}
|
}
|
||||||
// save payment
|
|
||||||
info.TradeNo = tradeNo
|
// Save Stripe trade number to order for tracking
|
||||||
|
info.TradeNo = result.TradeNo
|
||||||
err = l.svcCtx.OrderModel.Update(l.ctx, info)
|
err = l.svcCtx.OrderModel.Update(l.ctx, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Update error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Update order error", logger.Field("error", err.Error()))
|
||||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update error: %s", err.Error())
|
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update error: %s", err.Error())
|
||||||
}
|
}
|
||||||
return stripePayment, nil
|
return stripePayment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// epayPayment processes EPay payment by generating a payment URL for redirect
|
||||||
|
// It handles currency conversion and creates a payment URL for external payment processing
|
||||||
func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order.Order, returnUrl string) (string, error) {
|
func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order.Order, returnUrl string) (string, error) {
|
||||||
|
// Parse EPay configuration from payment settings
|
||||||
epayConfig := payment.EPayConfig{}
|
epayConfig := payment.EPayConfig{}
|
||||||
if err := json.Unmarshal([]byte(config.Config), &epayConfig); err != nil {
|
if err := json.Unmarshal([]byte(config.Config), &epayConfig); err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Unmarshal EPay config error", logger.Field("error", err.Error()))
|
||||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize EPay client with merchant credentials
|
||||||
client := epay.NewClient(epayConfig.Pid, epayConfig.Url, epayConfig.Key)
|
client := epay.NewClient(epayConfig.Pid, epayConfig.Url, epayConfig.Key)
|
||||||
// Calculate the amount with exchange rate
|
|
||||||
|
// Convert order amount to CNY using current exchange rate
|
||||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build notification URL for payment status callbacks
|
||||||
notifyUrl := ""
|
notifyUrl := ""
|
||||||
if config.Domain != "" {
|
if config.Domain != "" {
|
||||||
notifyUrl = config.Domain + "/v1/notify/" + config.Platform + "/" + config.Token
|
notifyUrl = config.Domain + "/v1/notify/" + config.Platform + "/" + config.Token
|
||||||
@ -234,7 +272,8 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
|
|||||||
}
|
}
|
||||||
notifyUrl = "https://" + host + "/v1/notify/" + config.Platform + "/" + config.Token
|
notifyUrl = "https://" + host + "/v1/notify/" + config.Platform + "/" + config.Token
|
||||||
}
|
}
|
||||||
// create payment
|
|
||||||
|
// Create payment URL for user redirection
|
||||||
url := client.CreatePayUrl(epay.Order{
|
url := client.CreatePayUrl(epay.Order{
|
||||||
Name: l.svcCtx.Config.Site.SiteName,
|
Name: l.svcCtx.Config.Site.SiteName,
|
||||||
Amount: amount,
|
Amount: amount,
|
||||||
@ -246,26 +285,34 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
|
|||||||
return url, nil
|
return url, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query exchange rate
|
// queryExchangeRate converts the order amount from system currency to target currency
|
||||||
|
// It retrieves the current exchange rate and performs currency conversion if needed
|
||||||
func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount float64, err error) {
|
func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount float64, err error) {
|
||||||
|
// Convert cents to decimal amount
|
||||||
amount = float64(src) / float64(100)
|
amount = float64(src) / float64(100)
|
||||||
// query system currency
|
|
||||||
|
// Retrieve system currency configuration
|
||||||
currency, err := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx)
|
currency, err := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] GetCurrencyConfig error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] GetCurrencyConfig error", logger.Field("error", err.Error()))
|
||||||
return 0, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetCurrencyConfig error: %s", err.Error())
|
return 0, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetCurrencyConfig error: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse currency configuration
|
||||||
configs := struct {
|
configs := struct {
|
||||||
CurrencyUnit string
|
CurrencyUnit string
|
||||||
CurrencySymbol string
|
CurrencySymbol string
|
||||||
AccessKey string
|
AccessKey string
|
||||||
}{}
|
}{}
|
||||||
tool.SystemConfigSliceReflectToStruct(currency, &configs)
|
tool.SystemConfigSliceReflectToStruct(currency, &configs)
|
||||||
|
|
||||||
|
// Skip conversion if no exchange rate API key configured
|
||||||
if configs.AccessKey == "" {
|
if configs.AccessKey == "" {
|
||||||
return amount, nil
|
return amount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert currency if system currency differs from target currency
|
||||||
if configs.CurrencyUnit != to {
|
if configs.CurrencyUnit != to {
|
||||||
// query exchange rate
|
|
||||||
result, err := exchangeRate.GetExchangeRete(configs.CurrencyUnit, to, configs.AccessKey, 1)
|
result, err := exchangeRate.GetExchangeRete(configs.CurrencyUnit, to, configs.AccessKey, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -275,71 +322,123 @@ func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount
|
|||||||
return amount, nil
|
return amount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Balance payment
|
// balancePayment processes balance payment with gift amount priority logic
|
||||||
|
// It prioritizes using gift amount first, then regular balance, and creates proper audit logs
|
||||||
func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error {
|
func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error {
|
||||||
|
if o.Amount == 0 {
|
||||||
|
// No payment required for zero-amount orders
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var userInfo user.User
|
var userInfo user.User
|
||||||
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||||
|
// Retrieve latest user information with row-level locking
|
||||||
err := db.Model(&user.User{}).Where("id = ?", u.Id).First(&userInfo).Error
|
err := db.Model(&user.User{}).Where("id = ?", u.Id).First(&userInfo).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.GiftAmount != 0 {
|
// Check if user has sufficient total balance (regular + gift)
|
||||||
if userInfo.GiftAmount < o.GiftAmount {
|
totalAvailable := userInfo.Balance + userInfo.GiftAmount
|
||||||
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient gift balance")
|
if totalAvailable < o.Amount {
|
||||||
}
|
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance),
|
||||||
// deduct gift amount
|
"Insufficient balance: required %d, available %d", o.Amount, totalAvailable)
|
||||||
userInfo.GiftAmount -= o.GiftAmount
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userInfo.Balance < o.Amount {
|
// Calculate payment distribution: prioritize gift amount first
|
||||||
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient balance")
|
var giftUsed, balanceUsed int64
|
||||||
}
|
remainingAmount := o.Amount
|
||||||
// deduct balance
|
|
||||||
userInfo.Balance -= o.Amount
|
|
||||||
|
|
||||||
userInfo.GiftAmount -= o.GiftAmount
|
if userInfo.GiftAmount >= remainingAmount {
|
||||||
|
// Gift amount covers the entire payment
|
||||||
|
giftUsed = remainingAmount
|
||||||
|
balanceUsed = 0
|
||||||
|
} else {
|
||||||
|
// Use all available gift amount, then regular balance
|
||||||
|
giftUsed = userInfo.GiftAmount
|
||||||
|
balanceUsed = remainingAmount - giftUsed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update user balances
|
||||||
|
userInfo.GiftAmount -= giftUsed
|
||||||
|
userInfo.Balance -= balanceUsed
|
||||||
|
|
||||||
|
// Save updated user information
|
||||||
err = l.svcCtx.UserModel.Update(l.ctx, &userInfo)
|
err = l.svcCtx.UserModel.Update(l.ctx, &userInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// create balance log
|
|
||||||
balanceLog := &user.BalanceLog{
|
// Create gift amount log if gift amount was used
|
||||||
Id: 0,
|
if giftUsed > 0 {
|
||||||
UserId: u.Id,
|
giftLog := &user.GiftAmountLog{
|
||||||
Amount: o.Amount,
|
UserId: u.Id,
|
||||||
Type: 3,
|
UserSubscribeId: 0, // Will be updated when subscription is created
|
||||||
OrderId: o.Id,
|
OrderNo: o.OrderNo,
|
||||||
Balance: userInfo.Balance,
|
Type: 2, // Type 2 represents gift amount decrease/usage
|
||||||
|
Amount: giftUsed,
|
||||||
|
Balance: userInfo.GiftAmount,
|
||||||
|
Remark: "Purchase payment",
|
||||||
|
}
|
||||||
|
err = db.Create(giftLog).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = db.Create(balanceLog).Error
|
|
||||||
|
// Create balance log if regular balance was used
|
||||||
|
if balanceUsed > 0 {
|
||||||
|
balanceLog := &user.BalanceLog{
|
||||||
|
UserId: u.Id,
|
||||||
|
Amount: balanceUsed,
|
||||||
|
Type: 3, // Type 3 represents payment deduction
|
||||||
|
OrderId: o.Id,
|
||||||
|
Balance: userInfo.Balance,
|
||||||
|
}
|
||||||
|
err = db.Create(balanceLog).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store gift amount used in order for potential refund tracking
|
||||||
|
o.GiftAmount = giftUsed
|
||||||
|
err = l.svcCtx.OrderModel.Update(l.ctx, o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return db.Model(&order.Order{}).Where("id = ?", o.Id).Updates(map[string]interface{}{
|
|
||||||
"status": 2, // 2 means paid
|
// Mark order as paid (status = 2)
|
||||||
}).Error
|
return l.svcCtx.OrderModel.UpdateOrderStatus(l.ctx, o.OrderNo, 2, db)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Transaction error", logger.Field("error", err.Error()), logger.Field("orderNo", o.OrderNo))
|
l.Errorw("[PurchaseCheckout] Balance payment transaction error",
|
||||||
|
logger.Field("error", err.Error()),
|
||||||
|
logger.Field("orderNo", o.OrderNo),
|
||||||
|
logger.Field("userId", u.Id))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// create activity order task
|
|
||||||
|
// Enqueue order activation task for immediate processing
|
||||||
payload := queueType.ForthwithActivateOrderPayload{
|
payload := queueType.ForthwithActivateOrderPayload{
|
||||||
OrderNo: o.OrderNo,
|
OrderNo: o.OrderNo,
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(payload)
|
bytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Marshal error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Marshal activation payload error", logger.Field("error", err.Error()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes)
|
task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes)
|
||||||
_, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task)
|
_, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorw("[CheckoutOrderLogic] Enqueue error", logger.Field("error", err.Error()))
|
l.Errorw("[PurchaseCheckout] Enqueue activation task error", logger.Field("error", err.Error()))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
l.Logger.Info("[CheckoutOrderLogic] Enqueue success", logger.Field("orderNo", o.OrderNo))
|
|
||||||
|
l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully",
|
||||||
|
logger.Field("orderNo", o.OrderNo),
|
||||||
|
logger.Field("userId", u.Id))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -93,7 +93,6 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
|
|||||||
}
|
}
|
||||||
// Calculate the handling fee
|
// Calculate the handling fee
|
||||||
amount -= couponAmount
|
amount -= couponAmount
|
||||||
var deductionAmount int64
|
|
||||||
// find payment method
|
// find payment method
|
||||||
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, req.Payment)
|
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, req.Payment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -118,7 +117,7 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
|
|||||||
Price: price,
|
Price: price,
|
||||||
Amount: amount,
|
Amount: amount,
|
||||||
Discount: discountAmount,
|
Discount: discountAmount,
|
||||||
GiftAmount: deductionAmount,
|
GiftAmount: 0,
|
||||||
Coupon: req.Coupon,
|
Coupon: req.Coupon,
|
||||||
CouponDiscount: couponAmount,
|
CouponDiscount: couponAmount,
|
||||||
PaymentId: req.Payment,
|
PaymentId: req.Payment,
|
||||||
|
|||||||
@ -38,6 +38,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
|||||||
orderQuantity := orderDetails.Quantity
|
orderQuantity := orderDetails.Quantity
|
||||||
// Calculate Order Amount
|
// Calculate Order Amount
|
||||||
orderAmount := orderDetails.Amount + orderDetails.GiftAmount
|
orderAmount := orderDetails.Amount + orderDetails.GiftAmount
|
||||||
|
|
||||||
if len(orderDetails.SubOrders) > 0 {
|
if len(orderDetails.SubOrders) > 0 {
|
||||||
for _, subOrder := range orderDetails.SubOrders {
|
for _, subOrder := range orderDetails.SubOrders {
|
||||||
if subOrder.Status == 2 || subOrder.Status == 5 {
|
if subOrder.Status == 2 || subOrder.Status == 5 {
|
||||||
@ -47,7 +48,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Calculate Remaining Amount
|
// Calculate Remaining Amount
|
||||||
remainingAmount := deduction.CalculateRemainingAmount(
|
remainingAmount, err := deduction.CalculateRemainingAmount(
|
||||||
deduction.Subscribe{
|
deduction.Subscribe{
|
||||||
StartTime: userSubscribe.StartTime,
|
StartTime: userSubscribe.StartTime,
|
||||||
ExpireTime: userSubscribe.ExpireTime,
|
ExpireTime: userSubscribe.ExpireTime,
|
||||||
@ -64,5 +65,8 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
|||||||
Quantity: orderQuantity,
|
Quantity: orderQuantity,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrapf(xerr.NewErrCode(500), "CalculateRemainingAmount failed, userSubscribeId: %d, err: %v", userSubscribeId, err)
|
||||||
|
}
|
||||||
return remainingAmount, nil
|
return remainingAmount, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -21,7 +21,7 @@ type UnsubscribeLogic struct {
|
|||||||
svcCtx *svc.ServiceContext
|
svcCtx *svc.ServiceContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUnsubscribeLogic Unsubscribe
|
// NewUnsubscribeLogic creates a new instance of UnsubscribeLogic for handling subscription cancellation
|
||||||
func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UnsubscribeLogic {
|
func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UnsubscribeLogic {
|
||||||
return &UnsubscribeLogic{
|
return &UnsubscribeLogic{
|
||||||
Logger: logger.WithContext(ctx),
|
Logger: logger.WithContext(ctx),
|
||||||
@ -30,39 +30,90 @@ func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Unsub
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unsubscribe handles the subscription cancellation process with proper refund distribution
|
||||||
|
// It prioritizes refunding to gift amount for balance-paid orders, then to regular balance
|
||||||
func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error {
|
func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error {
|
||||||
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Error("current user is not found in context")
|
logger.Error("current user is not found in context")
|
||||||
return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
||||||
}
|
}
|
||||||
|
// Calculate the remaining amount to refund based on unused subscription time/traffic
|
||||||
remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id)
|
remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// update user subscribe
|
|
||||||
|
// Process unsubscription in a database transaction to ensure data consistency
|
||||||
err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||||
|
// Find and update subscription status to cancelled (status = 4)
|
||||||
var userSub user.Subscribe
|
var userSub user.Subscribe
|
||||||
if err := db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
|
if err = db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
userSub.Status = 4
|
userSub.Status = 4 // Set status to cancelled
|
||||||
if err := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
|
if err = l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
balance := remainingAmount + u.Balance
|
|
||||||
// insert deduction log
|
// Query the original order information to determine refund strategy
|
||||||
balanceLog := user.BalanceLog{
|
orderInfo, err := l.svcCtx.OrderModel.FindOne(l.ctx, userSub.OrderId)
|
||||||
UserId: userSub.UserId,
|
if err != nil {
|
||||||
OrderId: userSub.OrderId,
|
|
||||||
Amount: remainingAmount,
|
|
||||||
Type: 4,
|
|
||||||
Balance: balance,
|
|
||||||
}
|
|
||||||
if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// update user balance
|
// Calculate refund distribution based on payment method and gift amount priority
|
||||||
|
var balance, gift int64
|
||||||
|
if orderInfo.Method == "balance" {
|
||||||
|
// For balance-paid orders, prioritize refunding to gift amount first
|
||||||
|
if orderInfo.GiftAmount >= remainingAmount {
|
||||||
|
// Gift amount covers the entire refund - refund all to gift balance
|
||||||
|
gift = remainingAmount
|
||||||
|
balance = u.Balance // Regular balance remains unchanged
|
||||||
|
} else {
|
||||||
|
// Gift amount insufficient - refund to gift first, remainder to regular balance
|
||||||
|
gift = orderInfo.GiftAmount
|
||||||
|
balance = u.Balance + (remainingAmount - orderInfo.GiftAmount)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For non-balance payment orders, refund entirely to regular balance
|
||||||
|
balance = remainingAmount + u.Balance
|
||||||
|
gift = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create balance log entry only if there's an actual regular balance refund
|
||||||
|
balanceRefundAmount := balance - u.Balance
|
||||||
|
if balanceRefundAmount > 0 {
|
||||||
|
balanceLog := user.BalanceLog{
|
||||||
|
UserId: userSub.UserId,
|
||||||
|
OrderId: userSub.OrderId,
|
||||||
|
Amount: balanceRefundAmount,
|
||||||
|
Type: 4, // Type 4 represents refund transaction
|
||||||
|
Balance: balance,
|
||||||
|
}
|
||||||
|
if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create gift amount log entry if there's a gift balance refund
|
||||||
|
if gift > 0 {
|
||||||
|
giftLog := user.GiftAmountLog{
|
||||||
|
UserId: userSub.UserId,
|
||||||
|
UserSubscribeId: userSub.Id,
|
||||||
|
OrderNo: orderInfo.OrderNo,
|
||||||
|
Type: 1, // Type 1 represents gift amount increase
|
||||||
|
Amount: gift,
|
||||||
|
Balance: u.GiftAmount + gift,
|
||||||
|
Remark: "Unsubscribe refund",
|
||||||
|
}
|
||||||
|
if err := db.Model(&user.GiftAmountLog{}).Create(&giftLog).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Update user's gift amount
|
||||||
|
u.GiftAmount += gift
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update user's regular balance and save changes to database
|
||||||
u.Balance = balance
|
u.Balance = balance
|
||||||
return l.svcCtx.UserModel.Update(l.ctx, u)
|
return l.svcCtx.UserModel.Update(l.ctx, u)
|
||||||
})
|
})
|
||||||
|
|||||||
@ -1,90 +1,302 @@
|
|||||||
|
// Package deduction provides functionality for calculating remaining amounts
|
||||||
|
// in subscription billing systems, supporting various time units and traffic-based calculations.
|
||||||
package deduction
|
package deduction
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/perfect-panel/server/pkg/tool"
|
"github.com/perfect-panel/server/pkg/tool"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnitTimeNoLimit = "NoLimit"
|
// Time unit constants for subscription billing
|
||||||
UnitTimeYear = "Year"
|
UnitTimeNoLimit = "NoLimit" // Unlimited time subscription
|
||||||
UnitTimeMonth = "Month"
|
UnitTimeYear = "Year" // Annual subscription
|
||||||
UnitTimeDay = "Day"
|
UnitTimeMonth = "Month" // Monthly subscription
|
||||||
UintTimeHour = "Hour"
|
UnitTimeDay = "Day" // Daily subscription
|
||||||
UintTimeMinute = "Minute"
|
UnitTimeHour = "Hour" // Hourly subscription
|
||||||
|
UnitTimeMinute = "Minute" // Per-minute subscription
|
||||||
|
|
||||||
ResetCycleNone = 0
|
// Reset cycle constants for traffic resets
|
||||||
ResetCycle1st = 1
|
ResetCycleNone = 0 // No reset cycle
|
||||||
ResetCycleMonthly = 2
|
ResetCycle1st = 1 // Reset on 1st of each month
|
||||||
ResetCycleYear = 3
|
ResetCycleMonthly = 2 // Reset monthly based on start date
|
||||||
|
ResetCycleYear = 3 // Reset yearly based on start date
|
||||||
|
|
||||||
|
// Safety limits for overflow protection
|
||||||
|
maxInt64 = math.MaxInt64
|
||||||
|
minInt64 = math.MinInt64
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Error definitions for validation and calculation failures
|
||||||
|
var (
|
||||||
|
ErrInvalidQuantity = errors.New("order quantity cannot be zero or negative")
|
||||||
|
ErrInvalidAmount = errors.New("order amount cannot be negative")
|
||||||
|
ErrInvalidTraffic = errors.New("traffic values cannot be negative")
|
||||||
|
ErrInvalidTimeRange = errors.New("expire time must be after start time")
|
||||||
|
ErrInvalidUnitTime = errors.New("invalid unit time")
|
||||||
|
ErrInvalidDeductionRatio = errors.New("deduction ratio must be between 0 and 100")
|
||||||
|
ErrOverflow = errors.New("calculation overflow")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Subscribe represents a subscription with time and traffic limits
|
||||||
type Subscribe struct {
|
type Subscribe struct {
|
||||||
StartTime time.Time
|
StartTime time.Time // Subscription start time
|
||||||
ExpireTime time.Time
|
ExpireTime time.Time // Subscription expiration time
|
||||||
Traffic int64
|
Traffic int64 // Total traffic allowance in bytes
|
||||||
Download int64
|
Download int64 // Downloaded traffic in bytes
|
||||||
Upload int64
|
Upload int64 // Uploaded traffic in bytes
|
||||||
UnitTime string
|
UnitTime string // Time unit for billing (Year, Month, Day, etc.)
|
||||||
UnitPrice int64
|
UnitPrice int64 // Price per unit time
|
||||||
ResetCycle int64
|
ResetCycle int64 // Traffic reset cycle
|
||||||
DeductionRatio int64
|
DeductionRatio int64 // Deduction ratio for weighted calculations (0-100)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Order represents a purchase order for subscription calculation
|
||||||
type Order struct {
|
type Order struct {
|
||||||
Amount int64
|
Amount int64 // Total order amount
|
||||||
Quantity int64
|
Quantity int64 // Order quantity
|
||||||
}
|
}
|
||||||
|
|
||||||
func CalculateRemainingAmount(sub Subscribe, order Order) int64 {
|
// Validate checks if the Subscribe struct contains valid data
|
||||||
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
|
func (s *Subscribe) Validate() error {
|
||||||
return 0
|
if s.Traffic < 0 || s.Download < 0 || s.Upload < 0 {
|
||||||
|
return ErrInvalidTraffic
|
||||||
}
|
}
|
||||||
// 实际单价
|
|
||||||
sub.UnitPrice = order.Amount / order.Quantity
|
if s.Download+s.Upload > s.Traffic {
|
||||||
now := time.Now()
|
return fmt.Errorf("download + upload (%d) cannot exceed total traffic (%d)", s.Download+s.Upload, s.Traffic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.ExpireTime.After(s.StartTime) {
|
||||||
|
return ErrInvalidTimeRange
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.DeductionRatio < 0 || s.DeductionRatio > 100 {
|
||||||
|
return ErrInvalidDeductionRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
validUnitTimes := []string{UnitTimeNoLimit, UnitTimeYear, UnitTimeMonth, UnitTimeDay, UnitTimeHour, UnitTimeMinute}
|
||||||
|
valid := false
|
||||||
|
for _, ut := range validUnitTimes {
|
||||||
|
if s.UnitTime == ut {
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return ErrInvalidUnitTime
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks if the Order struct contains valid data
|
||||||
|
func (o *Order) Validate() error {
|
||||||
|
if o.Quantity <= 0 {
|
||||||
|
return ErrInvalidQuantity
|
||||||
|
}
|
||||||
|
if o.Amount < 0 {
|
||||||
|
return ErrInvalidAmount
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeMultiply performs multiplication with overflow protection
|
||||||
|
func safeMultiply(a, b int64) (int64, error) {
|
||||||
|
if a == 0 || b == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if a > 0 && b > 0 {
|
||||||
|
if a > maxInt64/b {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
} else if a < 0 && b < 0 {
|
||||||
|
if a < maxInt64/b {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (a > 0 && b < minInt64/a) || (a < 0 && b > minInt64/a) {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return a * b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeAdd performs addition with overflow protection
|
||||||
|
func safeAdd(a, b int64) (int64, error) {
|
||||||
|
if (b > 0 && a > maxInt64-b) || (b < 0 && a < minInt64-b) {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
return a + b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeDivide performs division with zero-division protection
|
||||||
|
func safeDivide(a, b int64) (int64, error) {
|
||||||
|
if b == 0 {
|
||||||
|
return 0, errors.New("division by zero")
|
||||||
|
}
|
||||||
|
return a / b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateRemainingAmount calculates the remaining refund amount for a subscription
|
||||||
|
// based on unused time and traffic. Returns the amount and any calculation errors.
|
||||||
|
func CalculateRemainingAmount(sub Subscribe, order Order) (int64, error) {
|
||||||
|
if err := sub.Validate(); err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := order.Validate(); err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid order: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
unitPrice, err := safeDivide(order.Amount, order.Quantity)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to calculate unit price: %w", err)
|
||||||
|
}
|
||||||
|
sub.UnitPrice = unitPrice
|
||||||
|
|
||||||
|
loc, err := time.LoadLocation(sub.StartTime.Location().String())
|
||||||
|
if err != nil {
|
||||||
|
loc = time.UTC
|
||||||
|
}
|
||||||
|
now := time.Now().In(loc)
|
||||||
|
|
||||||
switch sub.UnitTime {
|
switch sub.UnitTime {
|
||||||
case UnitTimeNoLimit:
|
case UnitTimeNoLimit:
|
||||||
usedTraffic := sub.Traffic - sub.Download - sub.Upload
|
return calculateNoLimitAmount(sub, order)
|
||||||
unitPrice := float64(order.Amount) / float64(sub.Traffic)
|
|
||||||
return int64(float64(usedTraffic) * unitPrice)
|
|
||||||
|
|
||||||
case UnitTimeYear:
|
case UnitTimeYear:
|
||||||
remainingYears := tool.YearDiff(now, sub.ExpireTime)
|
remainingYears := tool.YearDiff(now, sub.ExpireTime)
|
||||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||||
return int64(remainingYears)*sub.UnitPrice + remainingUnitTimeAmount
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
yearAmount, err := safeMultiply(int64(remainingYears), sub.UnitPrice)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("year calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := safeAdd(yearAmount, remainingUnitTimeAmount)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return total, nil
|
||||||
|
|
||||||
case UnitTimeMonth:
|
case UnitTimeMonth:
|
||||||
remainingMonths := tool.MonthDiff(now, sub.ExpireTime)
|
remainingMonths := tool.MonthDiff(now, sub.ExpireTime)
|
||||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||||
return int64(remainingMonths)*sub.UnitPrice + remainingUnitTimeAmount
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
monthAmount, err := safeMultiply(int64(remainingMonths), sub.UnitPrice)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("month calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := safeAdd(monthAmount, remainingUnitTimeAmount)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return total, nil
|
||||||
|
|
||||||
case UnitTimeDay:
|
case UnitTimeDay:
|
||||||
remainingDays := tool.DayDiff(now, sub.ExpireTime)
|
remainingDays := tool.DayDiff(now, sub.ExpireTime)
|
||||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||||
return remainingDays*sub.UnitPrice + remainingUnitTimeAmount
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dayAmount, err := safeMultiply(remainingDays, sub.UnitPrice)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("day calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := safeAdd(dayAmount, remainingUnitTimeAmount)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateRemainingUnitTimeAmount(sub Subscribe) int64 {
|
// calculateNoLimitAmount calculates refund amount for unlimited time subscriptions
|
||||||
|
// based on unused traffic only
|
||||||
|
func calculateNoLimitAmount(sub Subscribe, order Order) (int64, error) {
|
||||||
|
if sub.Traffic == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usedTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||||
|
if usedTraffic < 0 {
|
||||||
|
usedTraffic = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
unitPrice := float64(order.Amount) / float64(sub.Traffic)
|
||||||
|
result := float64(usedTraffic) * unitPrice
|
||||||
|
|
||||||
|
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(result), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateRemainingUnitTimeAmount calculates the remaining amount based on
|
||||||
|
// both time and traffic usage, applying deduction ratios when specified
|
||||||
|
func calculateRemainingUnitTimeAmount(sub Subscribe) (int64, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
trafficWeight, timeWeight := calculateWeights(sub.DeductionRatio)
|
trafficWeight, timeWeight := calculateWeights(sub.DeductionRatio)
|
||||||
remainingDays, totalDays := getRemainingAndTotalDays(sub, now)
|
remainingDays, totalDays := getRemainingAndTotalDays(sub, now)
|
||||||
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
|
|
||||||
remainingTimeAmount := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
|
if totalDays == 0 {
|
||||||
remainingTrafficAmount := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
|
return 0, nil
|
||||||
if sub.Traffic == 0 {
|
|
||||||
return remainingTimeAmount
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||||
|
if remainingTraffic < 0 {
|
||||||
|
remainingTraffic = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
remainingTimeAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("time amount calculation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub.Traffic == 0 {
|
||||||
|
return remainingTimeAmount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
remainingTrafficAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("traffic amount calculation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if sub.DeductionRatio != 0 {
|
if sub.DeductionRatio != 0 {
|
||||||
return calculateWeightedAmount(sub.UnitPrice, remainingTraffic, sub.Traffic, remainingDays, totalDays, trafficWeight, timeWeight)
|
return calculateWeightedAmount(sub.UnitPrice, remainingTraffic, sub.Traffic, remainingDays, totalDays, trafficWeight, timeWeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
return min(remainingTimeAmount, remainingTrafficAmount)
|
return min(remainingTimeAmount, remainingTrafficAmount), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calculateWeights converts deduction ratio to traffic and time weights
|
||||||
|
// for weighted calculations
|
||||||
func calculateWeights(deductionRatio int64) (float64, float64) {
|
func calculateWeights(deductionRatio int64) (float64, float64) {
|
||||||
if deductionRatio == 0 {
|
if deductionRatio == 0 {
|
||||||
return 0, 0
|
return 0, 0
|
||||||
@ -94,20 +306,32 @@ func calculateWeights(deductionRatio int64) (float64, float64) {
|
|||||||
return trafficWeight, timeWeight
|
return trafficWeight, timeWeight
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRemainingAndTotalDays calculates remaining and total days based on
|
||||||
|
// the subscription's reset cycle configuration
|
||||||
func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
|
func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
|
||||||
switch sub.ResetCycle {
|
switch sub.ResetCycle {
|
||||||
case ResetCycleNone:
|
case ResetCycleNone:
|
||||||
|
|
||||||
remaining := sub.ExpireTime.Sub(now).Hours() / 24
|
remaining := sub.ExpireTime.Sub(now).Hours() / 24
|
||||||
total := sub.ExpireTime.Sub(sub.StartTime).Hours() / 24
|
total := sub.ExpireTime.Sub(sub.StartTime).Hours() / 24
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
if total < 0 {
|
||||||
|
total = 0
|
||||||
|
}
|
||||||
return int64(remaining), int64(total)
|
return int64(remaining), int64(total)
|
||||||
|
|
||||||
case ResetCycle1st:
|
case ResetCycle1st:
|
||||||
return tool.DaysToNextMonth(now), tool.GetLastDayOfMonth(now)
|
return tool.DaysToNextMonth(now), tool.GetLastDayOfMonth(now)
|
||||||
|
|
||||||
case ResetCycleMonthly:
|
case ResetCycleMonthly:
|
||||||
// -1 to include the current day
|
remaining := tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1
|
||||||
return tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1, tool.DaysToMonthDay(now, sub.StartTime.Day())
|
total := tool.DaysToMonthDay(now, sub.StartTime.Day())
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
return remaining, total
|
||||||
|
|
||||||
case ResetCycleYear:
|
case ResetCycleYear:
|
||||||
return tool.DaysToYearDay(now, int(sub.StartTime.Month()), sub.StartTime.Day()),
|
return tool.DaysToYearDay(now, int(sub.StartTime.Month()), sub.StartTime.Day()),
|
||||||
tool.GetYearDays(now, int(sub.StartTime.Month()), sub.StartTime.Day())
|
tool.GetYearDays(now, int(sub.StartTime.Month()), sub.StartTime.Day())
|
||||||
@ -115,13 +339,36 @@ func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
|
|||||||
return 0, 0
|
return 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) int64 {
|
// calculateWeightedAmount applies weighted calculation combining both time and traffic
|
||||||
|
// remaining ratios based on the specified weights
|
||||||
|
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) (int64, error) {
|
||||||
|
if totalDays == 0 || totalTraffic == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
remainingTimeRatio := float64(remainingDays) / float64(totalDays)
|
remainingTimeRatio := float64(remainingDays) / float64(totalDays)
|
||||||
remainingTrafficRatio := float64(remainingTraffic) / float64(totalTraffic)
|
remainingTrafficRatio := float64(remainingTraffic) / float64(totalTraffic)
|
||||||
weightedRemainingRatio := (timeWeight * remainingTimeRatio) + (trafficWeight * remainingTrafficRatio)
|
weightedRemainingRatio := (timeWeight * remainingTimeRatio) + (trafficWeight * remainingTrafficRatio)
|
||||||
return int64(float64(unitPrice) * weightedRemainingRatio)
|
|
||||||
|
result := float64(unitPrice) * weightedRemainingRatio
|
||||||
|
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(result), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateProportionalAmount(unitPrice, remaining, total int64) int64 {
|
// calculateProportionalAmount calculates proportional amount based on
|
||||||
return int64(float64(unitPrice) * (float64(remaining) / float64(total)))
|
// remaining vs total ratio with overflow protection
|
||||||
|
func calculateProportionalAmount(unitPrice, remaining, total int64) (int64, error) {
|
||||||
|
if total == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := float64(unitPrice) * (float64(remaining) / float64(total))
|
||||||
|
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||||
|
return 0, ErrOverflow
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(result), nil
|
||||||
}
|
}
|
||||||
|
|||||||
665
pkg/deduction/deduction_test.go
Normal file
665
pkg/deduction/deduction_test.go
Normal file
@ -0,0 +1,665 @@
|
|||||||
|
package deduction
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubscribe_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sub Subscribe
|
||||||
|
wantErr bool
|
||||||
|
errType error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid subscription",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative traffic",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: -1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidTraffic,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative download",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: -100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidTraffic,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "download + upload exceeds traffic",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 600,
|
||||||
|
Upload: 500,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expire time before start time",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(-24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidTimeRange,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid deduction ratio - negative",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: -10,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidDeductionRatio,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid deduction ratio - over 100",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 150,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidDeductionRatio,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid unit time",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: time.Now(),
|
||||||
|
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 100,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: "InvalidUnit",
|
||||||
|
DeductionRatio: 50,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidUnitTime,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.sub.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Subscribe.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.errType != nil && err != tt.errType {
|
||||||
|
t.Errorf("Subscribe.Validate() error = %v, want %v", err, tt.errType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrder_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
order Order
|
||||||
|
wantErr bool
|
||||||
|
errType error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid order",
|
||||||
|
order: Order{Amount: 1000, Quantity: 2},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero quantity",
|
||||||
|
order: Order{Amount: 1000, Quantity: 0},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidQuantity,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative quantity",
|
||||||
|
order: Order{Amount: 1000, Quantity: -1},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidQuantity,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative amount",
|
||||||
|
order: Order{Amount: -1000, Quantity: 2},
|
||||||
|
wantErr: true,
|
||||||
|
errType: ErrInvalidAmount,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero amount is valid",
|
||||||
|
order: Order{Amount: 0, Quantity: 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.order.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Order.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.errType != nil && err != tt.errType {
|
||||||
|
t.Errorf("Order.Validate() error = %v, want %v", err, tt.errType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSafeMultiply(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a, b int64
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal multiplication",
|
||||||
|
a: 10,
|
||||||
|
b: 20,
|
||||||
|
want: 200,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero multiplication",
|
||||||
|
a: 10,
|
||||||
|
b: 0,
|
||||||
|
want: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative multiplication",
|
||||||
|
a: -10,
|
||||||
|
b: 20,
|
||||||
|
want: -200,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overflow case",
|
||||||
|
a: math.MaxInt64,
|
||||||
|
b: 2,
|
||||||
|
want: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large numbers no overflow",
|
||||||
|
a: 1000000,
|
||||||
|
b: 1000000,
|
||||||
|
want: 1000000000000,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := safeMultiply(tt.a, tt.b)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("safeMultiply() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("safeMultiply() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSafeAdd(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a, b int64
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal addition",
|
||||||
|
a: 10,
|
||||||
|
b: 20,
|
||||||
|
want: 30,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative addition",
|
||||||
|
a: -10,
|
||||||
|
b: 5,
|
||||||
|
want: -5,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overflow case",
|
||||||
|
a: math.MaxInt64,
|
||||||
|
b: 1,
|
||||||
|
want: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "underflow case",
|
||||||
|
a: math.MinInt64,
|
||||||
|
b: -1,
|
||||||
|
want: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := safeAdd(tt.a, tt.b)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("safeAdd() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("safeAdd() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSafeDivide(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a, b int64
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal division",
|
||||||
|
a: 20,
|
||||||
|
b: 10,
|
||||||
|
want: 2,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "division by zero",
|
||||||
|
a: 20,
|
||||||
|
b: 0,
|
||||||
|
want: 0,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative division",
|
||||||
|
a: -20,
|
||||||
|
b: 10,
|
||||||
|
want: -2,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero dividend",
|
||||||
|
a: 0,
|
||||||
|
b: 10,
|
||||||
|
want: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := safeDivide(tt.a, tt.b)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("safeDivide() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("safeDivide() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateWeights(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
deductionRatio int64
|
||||||
|
wantTrafficWeight float64
|
||||||
|
wantTimeWeight float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "zero ratio",
|
||||||
|
deductionRatio: 0,
|
||||||
|
wantTrafficWeight: 0,
|
||||||
|
wantTimeWeight: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "50% ratio",
|
||||||
|
deductionRatio: 50,
|
||||||
|
wantTrafficWeight: 0.5,
|
||||||
|
wantTimeWeight: 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "75% ratio",
|
||||||
|
deductionRatio: 75,
|
||||||
|
wantTrafficWeight: 0.75,
|
||||||
|
wantTimeWeight: 0.25,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100% ratio",
|
||||||
|
deductionRatio: 100,
|
||||||
|
wantTrafficWeight: 1.0,
|
||||||
|
wantTimeWeight: 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotTrafficWeight, gotTimeWeight := calculateWeights(tt.deductionRatio)
|
||||||
|
if gotTrafficWeight != tt.wantTrafficWeight {
|
||||||
|
t.Errorf("calculateWeights() trafficWeight = %v, want %v", gotTrafficWeight, tt.wantTrafficWeight)
|
||||||
|
}
|
||||||
|
if gotTimeWeight != tt.wantTimeWeight {
|
||||||
|
t.Errorf("calculateWeights() timeWeight = %v, want %v", gotTimeWeight, tt.wantTimeWeight)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateProportionalAmount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
unitPrice int64
|
||||||
|
remaining int64
|
||||||
|
total int64
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal calculation",
|
||||||
|
unitPrice: 100,
|
||||||
|
remaining: 50,
|
||||||
|
total: 100,
|
||||||
|
want: 50,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero total",
|
||||||
|
unitPrice: 100,
|
||||||
|
remaining: 50,
|
||||||
|
total: 0,
|
||||||
|
want: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero remaining",
|
||||||
|
unitPrice: 100,
|
||||||
|
remaining: 0,
|
||||||
|
total: 100,
|
||||||
|
want: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "quarter remaining",
|
||||||
|
unitPrice: 200,
|
||||||
|
remaining: 25,
|
||||||
|
total: 100,
|
||||||
|
want: 50,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := calculateProportionalAmount(tt.unitPrice, tt.remaining, tt.total)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("calculateProportionalAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("calculateProportionalAmount() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateNoLimitAmount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sub Subscribe
|
||||||
|
order Order
|
||||||
|
want int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal no limit calculation",
|
||||||
|
sub: Subscribe{
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
},
|
||||||
|
want: 500, // (1000 - 300 - 200) / 1000 * 1000 = 500
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero traffic",
|
||||||
|
sub: Subscribe{
|
||||||
|
Traffic: 0,
|
||||||
|
Download: 0,
|
||||||
|
Upload: 0,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
},
|
||||||
|
want: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overused traffic",
|
||||||
|
sub: Subscribe{
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 600,
|
||||||
|
Upload: 500,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
},
|
||||||
|
want: 0, // usedTraffic would be negative, clamped to 0
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := calculateNoLimitAmount(tt.sub, tt.order)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("calculateNoLimitAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("calculateNoLimitAmount() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateRemainingAmount(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sub Subscribe
|
||||||
|
order Order
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid no limit subscription",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: now.Add(-24 * time.Hour),
|
||||||
|
ExpireTime: now.Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeNoLimit,
|
||||||
|
ResetCycle: ResetCycleNone,
|
||||||
|
DeductionRatio: 0,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 1,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid subscription",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: now,
|
||||||
|
ExpireTime: now.Add(-24 * time.Hour), // Invalid: expire before start
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 0,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 1,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid order",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: now.Add(-24 * time.Hour),
|
||||||
|
ExpireTime: now.Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
DeductionRatio: 0,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 0, // Invalid: zero quantity
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no limit with reset cycle",
|
||||||
|
sub: Subscribe{
|
||||||
|
StartTime: now.Add(-24 * time.Hour),
|
||||||
|
ExpireTime: now.Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeNoLimit,
|
||||||
|
ResetCycle: ResetCycleMonthly, // Should return 0
|
||||||
|
DeductionRatio: 0,
|
||||||
|
},
|
||||||
|
order: Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 1,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := CalculateRemainingAmount(tt.sub, tt.order)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CalculateRemainingAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateRemainingAmount_NoLimitWithResetCycle(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
sub := Subscribe{
|
||||||
|
StartTime: now.Add(-24 * time.Hour),
|
||||||
|
ExpireTime: now.Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeNoLimit,
|
||||||
|
ResetCycle: ResetCycleMonthly,
|
||||||
|
DeductionRatio: 0,
|
||||||
|
}
|
||||||
|
order := Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := CalculateRemainingAmount(sub, order)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("CalculateRemainingAmount() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != 0 {
|
||||||
|
t.Errorf("CalculateRemainingAmount() = %v, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkCalculateRemainingAmount(b *testing.B) {
|
||||||
|
now := time.Now()
|
||||||
|
sub := Subscribe{
|
||||||
|
StartTime: now.Add(-24 * time.Hour),
|
||||||
|
ExpireTime: now.Add(24 * time.Hour),
|
||||||
|
Traffic: 1000,
|
||||||
|
Download: 300,
|
||||||
|
Upload: 200,
|
||||||
|
UnitTime: UnitTimeMonth,
|
||||||
|
ResetCycle: ResetCycleNone,
|
||||||
|
DeductionRatio: 50,
|
||||||
|
}
|
||||||
|
order := Order{
|
||||||
|
Amount: 1000,
|
||||||
|
Quantity: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = CalculateRemainingAmount(sub, order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSafeMultiply(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = safeMultiply(12345, 67890)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user