fix(purchase): correct gift amount deduction logic and enhance payment processing comments
This commit is contained in:
parent
76816ca8ea
commit
9691257bad
@ -132,7 +132,7 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
|
||||
if u.GiftAmount >= amount {
|
||||
deductionAmount = amount
|
||||
amount = 0
|
||||
u.GiftAmount -= amount
|
||||
u.GiftAmount -= deductionAmount
|
||||
} else {
|
||||
deductionAmount = u.GiftAmount
|
||||
amount -= u.GiftAmount
|
||||
|
||||
@ -28,13 +28,16 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// PurchaseCheckoutLogic handles the checkout process for various payment methods
|
||||
// including EPay, Stripe, Alipay F2F, and balance payments
|
||||
type PurchaseCheckoutLogic struct {
|
||||
logger.Logger
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
}
|
||||
|
||||
// NewPurchaseCheckoutLogic Purchase Checkout
|
||||
// NewPurchaseCheckoutLogic creates a new instance of PurchaseCheckoutLogic
|
||||
// for handling purchase checkout operations across different payment platforms
|
||||
func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PurchaseCheckoutLogic {
|
||||
return &PurchaseCheckoutLogic{
|
||||
Logger: logger.WithContext(ctx),
|
||||
@ -43,88 +46,104 @@ func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *
|
||||
}
|
||||
}
|
||||
|
||||
// PurchaseCheckout processes the checkout for an order using the specified payment method
|
||||
// It validates the order, retrieves payment configuration, and routes to the appropriate payment handler
|
||||
func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest) (resp *types.CheckoutOrderResponse, err error) {
|
||||
// Find order
|
||||
// Validate and retrieve order information
|
||||
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
|
||||
if err != nil {
|
||||
l.Logger.Error("[PurchaseCheckout] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OrderNo))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OrderNo)
|
||||
}
|
||||
|
||||
// Verify order is in pending payment status (status = 1)
|
||||
if orderInfo.Status != 1 {
|
||||
l.Logger.Error("[PurchaseCheckout] Order status error", logger.Field("status", orderInfo.Status))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderStatusError), "order status error: %v", orderInfo.Status)
|
||||
}
|
||||
|
||||
// find payment method
|
||||
// Retrieve payment method configuration
|
||||
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, orderInfo.PaymentId)
|
||||
if err != nil {
|
||||
l.Logger.Error("[Purchase] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
|
||||
l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error())
|
||||
}
|
||||
// Route to appropriate payment handler based on payment platform
|
||||
switch paymentPlatform.ParsePlatform(orderInfo.Method) {
|
||||
case paymentPlatform.EPay:
|
||||
// Process EPay payment - generates payment URL for redirect
|
||||
url, err := l.epayPayment(paymentConfig, orderInfo, req.ReturnUrl)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "epayPayment error: %v", err.Error())
|
||||
}
|
||||
resp = &types.CheckoutOrderResponse{
|
||||
CheckoutUrl: url,
|
||||
Type: "url",
|
||||
Type: "url", // Client should redirect to URL
|
||||
}
|
||||
|
||||
case paymentPlatform.Stripe:
|
||||
// Process Stripe payment - creates payment sheet for client-side processing
|
||||
stripePayment, err := l.stripePayment(paymentConfig.Config, orderInfo, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "stripePayment error: %v", err.Error())
|
||||
}
|
||||
resp = &types.CheckoutOrderResponse{
|
||||
Type: "stripe",
|
||||
Type: "stripe", // Client should use Stripe SDK
|
||||
Stripe: stripePayment,
|
||||
}
|
||||
|
||||
case paymentPlatform.AlipayF2F:
|
||||
// Process Alipay Face-to-Face payment - generates QR code
|
||||
url, err := l.alipayF2fPayment(paymentConfig, orderInfo)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] alipayF2fPayment error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] alipayF2fPayment error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "alipayF2fPayment error: %v", err.Error())
|
||||
}
|
||||
resp = &types.CheckoutOrderResponse{
|
||||
Type: "qr",
|
||||
Type: "qr", // Client should display QR code
|
||||
CheckoutUrl: url,
|
||||
}
|
||||
|
||||
case paymentPlatform.Balance:
|
||||
// Process balance payment - validate user and process payment immediately
|
||||
if orderInfo.UserId == 0 {
|
||||
l.Errorw("[CheckoutOrderLogic] user not found", logger.Field("userId", orderInfo.UserId))
|
||||
l.Errorw("[PurchaseCheckout] user not found", logger.Field("userId", orderInfo.UserId))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user not found")
|
||||
}
|
||||
// find user
|
||||
|
||||
// Retrieve user information for balance validation
|
||||
userInfo, err := l.svcCtx.UserModel.FindOne(l.ctx, orderInfo.UserId)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] FindOne User error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] FindOne User error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "FindOne error: %s", err.Error())
|
||||
}
|
||||
|
||||
// balance
|
||||
// Process balance payment with gift amount priority logic
|
||||
if err = l.balancePayment(userInfo, orderInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = &types.CheckoutOrderResponse{
|
||||
Type: "balance",
|
||||
Type: "balance", // Payment completed immediately
|
||||
}
|
||||
|
||||
default:
|
||||
l.Errorw("[CheckoutOrderLogic] payment method not found", logger.Field("method", orderInfo.Method))
|
||||
l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// alipay f2f payment
|
||||
// alipayF2fPayment processes Alipay Face-to-Face payment by generating a QR code
|
||||
// It handles currency conversion and creates a pre-payment trade for QR code scanning
|
||||
func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *order.Order) (string, error) {
|
||||
// Parse Alipay F2F configuration from payment settings
|
||||
f2FConfig := payment.AlipayF2FConfig{}
|
||||
if err := json.Unmarshal([]byte(pay.Config), &f2FConfig); err != nil {
|
||||
l.Errorw("[PurchaseCheckoutLogic] Unmarshal error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Unmarshal Alipay config error", logger.Field("error", err.Error()))
|
||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||
}
|
||||
|
||||
// Build notification URL for payment status callbacks
|
||||
notifyUrl := ""
|
||||
if pay.Domain != "" {
|
||||
notifyUrl = pay.Domain + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
||||
@ -135,6 +154,8 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
|
||||
}
|
||||
notifyUrl = "https://" + host + "/v1/notify/" + pay.Platform + "/" + pay.Token
|
||||
}
|
||||
|
||||
// Initialize Alipay client with configuration
|
||||
client := alipay.NewClient(alipay.Config{
|
||||
AppId: f2FConfig.AppId,
|
||||
PrivateKey: f2FConfig.PrivateKey,
|
||||
@ -142,46 +163,53 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
|
||||
InvoiceName: f2FConfig.InvoiceName,
|
||||
NotifyURL: notifyUrl,
|
||||
})
|
||||
// Calculate the amount with exchange rate
|
||||
|
||||
// Convert order amount to CNY using current exchange rate
|
||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
||||
}
|
||||
convertAmount := int64(amount * 100)
|
||||
// create payment
|
||||
convertAmount := int64(amount * 100) // Convert to cents for API
|
||||
|
||||
// Create pre-payment trade and generate QR code
|
||||
QRCode, err := client.PreCreateTrade(l.ctx, alipay.Order{
|
||||
OrderNo: info.OrderNo,
|
||||
Amount: convertAmount,
|
||||
})
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] PreCreateTrade error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] PreCreateTrade error", logger.Field("error", err.Error()))
|
||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "PreCreateTrade error: %s", err.Error())
|
||||
}
|
||||
return QRCode, nil
|
||||
}
|
||||
|
||||
// Stripe Payment
|
||||
// stripePayment processes Stripe payment by creating a payment sheet
|
||||
// It supports various payment methods including WeChat Pay and Alipay through Stripe
|
||||
func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, identifier string) (*types.StripePayment, error) {
|
||||
// stripe WeChat pay or stripe alipay
|
||||
// Parse Stripe configuration from payment settings
|
||||
stripeConfig := payment.StripeConfig{}
|
||||
if err := json.Unmarshal([]byte(config), &stripeConfig); err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Unmarshal Stripe config error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||
}
|
||||
|
||||
// Initialize Stripe client with API credentials
|
||||
client := stripe.NewClient(stripe.Config{
|
||||
SecretKey: stripeConfig.SecretKey,
|
||||
PublicKey: stripeConfig.PublicKey,
|
||||
WebhookSecret: stripeConfig.WebhookSecret,
|
||||
})
|
||||
// Calculate the amount with exchange rate
|
||||
|
||||
// Convert order amount to CNY using current exchange rate
|
||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
|
||||
}
|
||||
convertAmount := int64(amount * 100)
|
||||
// create payment
|
||||
convertAmount := int64(amount * 100) // Convert to cents for Stripe API
|
||||
|
||||
// Create Stripe payment sheet for client-side processing
|
||||
result, err := client.CreatePaymentSheet(&stripe.Order{
|
||||
OrderNo: info.OrderNo,
|
||||
Subscribe: strconv.FormatInt(info.SubscribeId, 10),
|
||||
@ -193,37 +221,47 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order,
|
||||
Email: identifier,
|
||||
})
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] CreatePaymentSheet error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] CreatePaymentSheet error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "CreatePaymentSheet error: %s", err.Error())
|
||||
}
|
||||
tradeNo := result.TradeNo
|
||||
|
||||
// Prepare response data for client-side Stripe integration
|
||||
stripePayment := &types.StripePayment{
|
||||
PublishableKey: stripeConfig.PublicKey,
|
||||
ClientSecret: result.ClientSecret,
|
||||
Method: stripeConfig.Payment,
|
||||
}
|
||||
// save payment
|
||||
info.TradeNo = tradeNo
|
||||
|
||||
// Save Stripe trade number to order for tracking
|
||||
info.TradeNo = result.TradeNo
|
||||
err = l.svcCtx.OrderModel.Update(l.ctx, info)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Update error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Update order error", logger.Field("error", err.Error()))
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update error: %s", err.Error())
|
||||
}
|
||||
return stripePayment, nil
|
||||
}
|
||||
|
||||
// epayPayment processes EPay payment by generating a payment URL for redirect
|
||||
// It handles currency conversion and creates a payment URL for external payment processing
|
||||
func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order.Order, returnUrl string) (string, error) {
|
||||
// Parse EPay configuration from payment settings
|
||||
epayConfig := payment.EPayConfig{}
|
||||
if err := json.Unmarshal([]byte(config.Config), &epayConfig); err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Unmarshal EPay config error", logger.Field("error", err.Error()))
|
||||
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
|
||||
}
|
||||
|
||||
// Initialize EPay client with merchant credentials
|
||||
client := epay.NewClient(epayConfig.Pid, epayConfig.Url, epayConfig.Key)
|
||||
// Calculate the amount with exchange rate
|
||||
|
||||
// Convert order amount to CNY using current exchange rate
|
||||
amount, err := l.queryExchangeRate("CNY", info.Amount)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Build notification URL for payment status callbacks
|
||||
notifyUrl := ""
|
||||
if config.Domain != "" {
|
||||
notifyUrl = config.Domain + "/v1/notify/" + config.Platform + "/" + config.Token
|
||||
@ -234,7 +272,8 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
|
||||
}
|
||||
notifyUrl = "https://" + host + "/v1/notify/" + config.Platform + "/" + config.Token
|
||||
}
|
||||
// create payment
|
||||
|
||||
// Create payment URL for user redirection
|
||||
url := client.CreatePayUrl(epay.Order{
|
||||
Name: l.svcCtx.Config.Site.SiteName,
|
||||
Amount: amount,
|
||||
@ -246,26 +285,34 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// Query exchange rate
|
||||
// queryExchangeRate converts the order amount from system currency to target currency
|
||||
// It retrieves the current exchange rate and performs currency conversion if needed
|
||||
func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount float64, err error) {
|
||||
// Convert cents to decimal amount
|
||||
amount = float64(src) / float64(100)
|
||||
// query system currency
|
||||
|
||||
// Retrieve system currency configuration
|
||||
currency, err := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] GetCurrencyConfig error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] GetCurrencyConfig error", logger.Field("error", err.Error()))
|
||||
return 0, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetCurrencyConfig error: %s", err.Error())
|
||||
}
|
||||
|
||||
// Parse currency configuration
|
||||
configs := struct {
|
||||
CurrencyUnit string
|
||||
CurrencySymbol string
|
||||
AccessKey string
|
||||
}{}
|
||||
tool.SystemConfigSliceReflectToStruct(currency, &configs)
|
||||
|
||||
// Skip conversion if no exchange rate API key configured
|
||||
if configs.AccessKey == "" {
|
||||
return amount, nil
|
||||
}
|
||||
|
||||
// Convert currency if system currency differs from target currency
|
||||
if configs.CurrencyUnit != to {
|
||||
// query exchange rate
|
||||
result, err := exchangeRate.GetExchangeRete(configs.CurrencyUnit, to, configs.AccessKey, 1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -275,71 +322,123 @@ func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount
|
||||
return amount, nil
|
||||
}
|
||||
|
||||
// Balance payment
|
||||
// balancePayment processes balance payment with gift amount priority logic
|
||||
// It prioritizes using gift amount first, then regular balance, and creates proper audit logs
|
||||
func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error {
|
||||
if o.Amount == 0 {
|
||||
// No payment required for zero-amount orders
|
||||
return nil
|
||||
}
|
||||
|
||||
var userInfo user.User
|
||||
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||
// Retrieve latest user information with row-level locking
|
||||
err := db.Model(&user.User{}).Where("id = ?", u.Id).First(&userInfo).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o.GiftAmount != 0 {
|
||||
if userInfo.GiftAmount < o.GiftAmount {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient gift balance")
|
||||
}
|
||||
// deduct gift amount
|
||||
userInfo.GiftAmount -= o.GiftAmount
|
||||
// Check if user has sufficient total balance (regular + gift)
|
||||
totalAvailable := userInfo.Balance + userInfo.GiftAmount
|
||||
if totalAvailable < o.Amount {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance),
|
||||
"Insufficient balance: required %d, available %d", o.Amount, totalAvailable)
|
||||
}
|
||||
|
||||
if userInfo.Balance < o.Amount {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient balance")
|
||||
}
|
||||
// deduct balance
|
||||
userInfo.Balance -= o.Amount
|
||||
// Calculate payment distribution: prioritize gift amount first
|
||||
var giftUsed, balanceUsed int64
|
||||
remainingAmount := o.Amount
|
||||
|
||||
userInfo.GiftAmount -= o.GiftAmount
|
||||
if userInfo.GiftAmount >= remainingAmount {
|
||||
// Gift amount covers the entire payment
|
||||
giftUsed = remainingAmount
|
||||
balanceUsed = 0
|
||||
} else {
|
||||
// Use all available gift amount, then regular balance
|
||||
giftUsed = userInfo.GiftAmount
|
||||
balanceUsed = remainingAmount - giftUsed
|
||||
}
|
||||
|
||||
// Update user balances
|
||||
userInfo.GiftAmount -= giftUsed
|
||||
userInfo.Balance -= balanceUsed
|
||||
|
||||
// Save updated user information
|
||||
err = l.svcCtx.UserModel.Update(l.ctx, &userInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// create balance log
|
||||
balanceLog := &user.BalanceLog{
|
||||
Id: 0,
|
||||
UserId: u.Id,
|
||||
Amount: o.Amount,
|
||||
Type: 3,
|
||||
OrderId: o.Id,
|
||||
Balance: userInfo.Balance,
|
||||
|
||||
// Create gift amount log if gift amount was used
|
||||
if giftUsed > 0 {
|
||||
giftLog := &user.GiftAmountLog{
|
||||
UserId: u.Id,
|
||||
UserSubscribeId: 0, // Will be updated when subscription is created
|
||||
OrderNo: o.OrderNo,
|
||||
Type: 2, // Type 2 represents gift amount decrease/usage
|
||||
Amount: giftUsed,
|
||||
Balance: userInfo.GiftAmount,
|
||||
Remark: "Purchase payment",
|
||||
}
|
||||
err = db.Create(giftLog).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = db.Create(balanceLog).Error
|
||||
|
||||
// Create balance log if regular balance was used
|
||||
if balanceUsed > 0 {
|
||||
balanceLog := &user.BalanceLog{
|
||||
UserId: u.Id,
|
||||
Amount: balanceUsed,
|
||||
Type: 3, // Type 3 represents payment deduction
|
||||
OrderId: o.Id,
|
||||
Balance: userInfo.Balance,
|
||||
}
|
||||
err = db.Create(balanceLog).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store gift amount used in order for potential refund tracking
|
||||
o.GiftAmount = giftUsed
|
||||
err = l.svcCtx.OrderModel.Update(l.ctx, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Model(&order.Order{}).Where("id = ?", o.Id).Updates(map[string]interface{}{
|
||||
"status": 2, // 2 means paid
|
||||
}).Error
|
||||
|
||||
// Mark order as paid (status = 2)
|
||||
return l.svcCtx.OrderModel.UpdateOrderStatus(l.ctx, o.OrderNo, 2, db)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Transaction error", logger.Field("error", err.Error()), logger.Field("orderNo", o.OrderNo))
|
||||
l.Errorw("[PurchaseCheckout] Balance payment transaction error",
|
||||
logger.Field("error", err.Error()),
|
||||
logger.Field("orderNo", o.OrderNo),
|
||||
logger.Field("userId", u.Id))
|
||||
return err
|
||||
}
|
||||
// create activity order task
|
||||
|
||||
// Enqueue order activation task for immediate processing
|
||||
payload := queueType.ForthwithActivateOrderPayload{
|
||||
OrderNo: o.OrderNo,
|
||||
}
|
||||
bytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Marshal error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Marshal activation payload error", logger.Field("error", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes)
|
||||
_, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task)
|
||||
if err != nil {
|
||||
l.Errorw("[CheckoutOrderLogic] Enqueue error", logger.Field("error", err.Error()))
|
||||
l.Errorw("[PurchaseCheckout] Enqueue activation task error", logger.Field("error", err.Error()))
|
||||
return err
|
||||
}
|
||||
l.Logger.Info("[CheckoutOrderLogic] Enqueue success", logger.Field("orderNo", o.OrderNo))
|
||||
|
||||
l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully",
|
||||
logger.Field("orderNo", o.OrderNo),
|
||||
logger.Field("userId", u.Id))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -93,7 +93,6 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
|
||||
}
|
||||
// Calculate the handling fee
|
||||
amount -= couponAmount
|
||||
var deductionAmount int64
|
||||
// find payment method
|
||||
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, req.Payment)
|
||||
if err != nil {
|
||||
@ -118,7 +117,7 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
|
||||
Price: price,
|
||||
Amount: amount,
|
||||
Discount: discountAmount,
|
||||
GiftAmount: deductionAmount,
|
||||
GiftAmount: 0,
|
||||
Coupon: req.Coupon,
|
||||
CouponDiscount: couponAmount,
|
||||
PaymentId: req.Payment,
|
||||
|
||||
@ -38,6 +38,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
||||
orderQuantity := orderDetails.Quantity
|
||||
// Calculate Order Amount
|
||||
orderAmount := orderDetails.Amount + orderDetails.GiftAmount
|
||||
|
||||
if len(orderDetails.SubOrders) > 0 {
|
||||
for _, subOrder := range orderDetails.SubOrders {
|
||||
if subOrder.Status == 2 || subOrder.Status == 5 {
|
||||
@ -47,7 +48,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
||||
}
|
||||
}
|
||||
// Calculate Remaining Amount
|
||||
remainingAmount := deduction.CalculateRemainingAmount(
|
||||
remainingAmount, err := deduction.CalculateRemainingAmount(
|
||||
deduction.Subscribe{
|
||||
StartTime: userSubscribe.StartTime,
|
||||
ExpireTime: userSubscribe.ExpireTime,
|
||||
@ -64,5 +65,8 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
|
||||
Quantity: orderQuantity,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return 0, errors.Wrapf(xerr.NewErrCode(500), "CalculateRemainingAmount failed, userSubscribeId: %d, err: %v", userSubscribeId, err)
|
||||
}
|
||||
return remainingAmount, nil
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ type UnsubscribeLogic struct {
|
||||
svcCtx *svc.ServiceContext
|
||||
}
|
||||
|
||||
// NewUnsubscribeLogic Unsubscribe
|
||||
// NewUnsubscribeLogic creates a new instance of UnsubscribeLogic for handling subscription cancellation
|
||||
func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UnsubscribeLogic {
|
||||
return &UnsubscribeLogic{
|
||||
Logger: logger.WithContext(ctx),
|
||||
@ -30,39 +30,90 @@ func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Unsub
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe handles the subscription cancellation process with proper refund distribution
|
||||
// It prioritizes refunding to gift amount for balance-paid orders, then to regular balance
|
||||
func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error {
|
||||
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
||||
if !ok {
|
||||
logger.Error("current user is not found in context")
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
||||
}
|
||||
// Calculate the remaining amount to refund based on unused subscription time/traffic
|
||||
remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// update user subscribe
|
||||
|
||||
// Process unsubscription in a database transaction to ensure data consistency
|
||||
err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
|
||||
// Find and update subscription status to cancelled (status = 4)
|
||||
var userSub user.Subscribe
|
||||
if err := db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
|
||||
if err = db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
userSub.Status = 4
|
||||
if err := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
|
||||
userSub.Status = 4 // Set status to cancelled
|
||||
if err = l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
|
||||
return err
|
||||
}
|
||||
balance := remainingAmount + u.Balance
|
||||
// insert deduction log
|
||||
balanceLog := user.BalanceLog{
|
||||
UserId: userSub.UserId,
|
||||
OrderId: userSub.OrderId,
|
||||
Amount: remainingAmount,
|
||||
Type: 4,
|
||||
Balance: balance,
|
||||
}
|
||||
if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil {
|
||||
|
||||
// Query the original order information to determine refund strategy
|
||||
orderInfo, err := l.svcCtx.OrderModel.FindOne(l.ctx, userSub.OrderId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// update user balance
|
||||
// Calculate refund distribution based on payment method and gift amount priority
|
||||
var balance, gift int64
|
||||
if orderInfo.Method == "balance" {
|
||||
// For balance-paid orders, prioritize refunding to gift amount first
|
||||
if orderInfo.GiftAmount >= remainingAmount {
|
||||
// Gift amount covers the entire refund - refund all to gift balance
|
||||
gift = remainingAmount
|
||||
balance = u.Balance // Regular balance remains unchanged
|
||||
} else {
|
||||
// Gift amount insufficient - refund to gift first, remainder to regular balance
|
||||
gift = orderInfo.GiftAmount
|
||||
balance = u.Balance + (remainingAmount - orderInfo.GiftAmount)
|
||||
}
|
||||
} else {
|
||||
// For non-balance payment orders, refund entirely to regular balance
|
||||
balance = remainingAmount + u.Balance
|
||||
gift = 0
|
||||
}
|
||||
|
||||
// Create balance log entry only if there's an actual regular balance refund
|
||||
balanceRefundAmount := balance - u.Balance
|
||||
if balanceRefundAmount > 0 {
|
||||
balanceLog := user.BalanceLog{
|
||||
UserId: userSub.UserId,
|
||||
OrderId: userSub.OrderId,
|
||||
Amount: balanceRefundAmount,
|
||||
Type: 4, // Type 4 represents refund transaction
|
||||
Balance: balance,
|
||||
}
|
||||
if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create gift amount log entry if there's a gift balance refund
|
||||
if gift > 0 {
|
||||
giftLog := user.GiftAmountLog{
|
||||
UserId: userSub.UserId,
|
||||
UserSubscribeId: userSub.Id,
|
||||
OrderNo: orderInfo.OrderNo,
|
||||
Type: 1, // Type 1 represents gift amount increase
|
||||
Amount: gift,
|
||||
Balance: u.GiftAmount + gift,
|
||||
Remark: "Unsubscribe refund",
|
||||
}
|
||||
if err := db.Model(&user.GiftAmountLog{}).Create(&giftLog).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// Update user's gift amount
|
||||
u.GiftAmount += gift
|
||||
}
|
||||
|
||||
// Update user's regular balance and save changes to database
|
||||
u.Balance = balance
|
||||
return l.svcCtx.UserModel.Update(l.ctx, u)
|
||||
})
|
||||
|
||||
@ -1,90 +1,302 @@
|
||||
// Package deduction provides functionality for calculating remaining amounts
|
||||
// in subscription billing systems, supporting various time units and traffic-based calculations.
|
||||
package deduction
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/perfect-panel/server/pkg/tool"
|
||||
)
|
||||
|
||||
const (
|
||||
UnitTimeNoLimit = "NoLimit"
|
||||
UnitTimeYear = "Year"
|
||||
UnitTimeMonth = "Month"
|
||||
UnitTimeDay = "Day"
|
||||
UintTimeHour = "Hour"
|
||||
UintTimeMinute = "Minute"
|
||||
// Time unit constants for subscription billing
|
||||
UnitTimeNoLimit = "NoLimit" // Unlimited time subscription
|
||||
UnitTimeYear = "Year" // Annual subscription
|
||||
UnitTimeMonth = "Month" // Monthly subscription
|
||||
UnitTimeDay = "Day" // Daily subscription
|
||||
UnitTimeHour = "Hour" // Hourly subscription
|
||||
UnitTimeMinute = "Minute" // Per-minute subscription
|
||||
|
||||
ResetCycleNone = 0
|
||||
ResetCycle1st = 1
|
||||
ResetCycleMonthly = 2
|
||||
ResetCycleYear = 3
|
||||
// Reset cycle constants for traffic resets
|
||||
ResetCycleNone = 0 // No reset cycle
|
||||
ResetCycle1st = 1 // Reset on 1st of each month
|
||||
ResetCycleMonthly = 2 // Reset monthly based on start date
|
||||
ResetCycleYear = 3 // Reset yearly based on start date
|
||||
|
||||
// Safety limits for overflow protection
|
||||
maxInt64 = math.MaxInt64
|
||||
minInt64 = math.MinInt64
|
||||
)
|
||||
|
||||
// Error definitions for validation and calculation failures
|
||||
var (
|
||||
ErrInvalidQuantity = errors.New("order quantity cannot be zero or negative")
|
||||
ErrInvalidAmount = errors.New("order amount cannot be negative")
|
||||
ErrInvalidTraffic = errors.New("traffic values cannot be negative")
|
||||
ErrInvalidTimeRange = errors.New("expire time must be after start time")
|
||||
ErrInvalidUnitTime = errors.New("invalid unit time")
|
||||
ErrInvalidDeductionRatio = errors.New("deduction ratio must be between 0 and 100")
|
||||
ErrOverflow = errors.New("calculation overflow")
|
||||
)
|
||||
|
||||
// Subscribe represents a subscription with time and traffic limits
|
||||
type Subscribe struct {
|
||||
StartTime time.Time
|
||||
ExpireTime time.Time
|
||||
Traffic int64
|
||||
Download int64
|
||||
Upload int64
|
||||
UnitTime string
|
||||
UnitPrice int64
|
||||
ResetCycle int64
|
||||
DeductionRatio int64
|
||||
StartTime time.Time // Subscription start time
|
||||
ExpireTime time.Time // Subscription expiration time
|
||||
Traffic int64 // Total traffic allowance in bytes
|
||||
Download int64 // Downloaded traffic in bytes
|
||||
Upload int64 // Uploaded traffic in bytes
|
||||
UnitTime string // Time unit for billing (Year, Month, Day, etc.)
|
||||
UnitPrice int64 // Price per unit time
|
||||
ResetCycle int64 // Traffic reset cycle
|
||||
DeductionRatio int64 // Deduction ratio for weighted calculations (0-100)
|
||||
}
|
||||
|
||||
// Order represents a purchase order for subscription calculation
|
||||
type Order struct {
|
||||
Amount int64
|
||||
Quantity int64
|
||||
Amount int64 // Total order amount
|
||||
Quantity int64 // Order quantity
|
||||
}
|
||||
|
||||
func CalculateRemainingAmount(sub Subscribe, order Order) int64 {
|
||||
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
|
||||
return 0
|
||||
// Validate checks if the Subscribe struct contains valid data
|
||||
func (s *Subscribe) Validate() error {
|
||||
if s.Traffic < 0 || s.Download < 0 || s.Upload < 0 {
|
||||
return ErrInvalidTraffic
|
||||
}
|
||||
// 实际单价
|
||||
sub.UnitPrice = order.Amount / order.Quantity
|
||||
now := time.Now()
|
||||
|
||||
if s.Download+s.Upload > s.Traffic {
|
||||
return fmt.Errorf("download + upload (%d) cannot exceed total traffic (%d)", s.Download+s.Upload, s.Traffic)
|
||||
}
|
||||
|
||||
if !s.ExpireTime.After(s.StartTime) {
|
||||
return ErrInvalidTimeRange
|
||||
}
|
||||
|
||||
if s.DeductionRatio < 0 || s.DeductionRatio > 100 {
|
||||
return ErrInvalidDeductionRatio
|
||||
}
|
||||
|
||||
validUnitTimes := []string{UnitTimeNoLimit, UnitTimeYear, UnitTimeMonth, UnitTimeDay, UnitTimeHour, UnitTimeMinute}
|
||||
valid := false
|
||||
for _, ut := range validUnitTimes {
|
||||
if s.UnitTime == ut {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return ErrInvalidUnitTime
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the Order struct contains valid data
|
||||
func (o *Order) Validate() error {
|
||||
if o.Quantity <= 0 {
|
||||
return ErrInvalidQuantity
|
||||
}
|
||||
if o.Amount < 0 {
|
||||
return ErrInvalidAmount
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// safeMultiply performs multiplication with overflow protection
|
||||
func safeMultiply(a, b int64) (int64, error) {
|
||||
if a == 0 || b == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if a > 0 && b > 0 {
|
||||
if a > maxInt64/b {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
} else if a < 0 && b < 0 {
|
||||
if a < maxInt64/b {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
} else {
|
||||
if (a > 0 && b < minInt64/a) || (a < 0 && b > minInt64/a) {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
}
|
||||
|
||||
return a * b, nil
|
||||
}
|
||||
|
||||
// safeAdd performs addition with overflow protection
|
||||
func safeAdd(a, b int64) (int64, error) {
|
||||
if (b > 0 && a > maxInt64-b) || (b < 0 && a < minInt64-b) {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
return a + b, nil
|
||||
}
|
||||
|
||||
// safeDivide performs division with zero-division protection
|
||||
func safeDivide(a, b int64) (int64, error) {
|
||||
if b == 0 {
|
||||
return 0, errors.New("division by zero")
|
||||
}
|
||||
return a / b, nil
|
||||
}
|
||||
|
||||
// CalculateRemainingAmount calculates the remaining refund amount for a subscription
|
||||
// based on unused time and traffic. Returns the amount and any calculation errors.
|
||||
func CalculateRemainingAmount(sub Subscribe, order Order) (int64, error) {
|
||||
if err := sub.Validate(); err != nil {
|
||||
return 0, fmt.Errorf("invalid subscription: %w", err)
|
||||
}
|
||||
|
||||
if err := order.Validate(); err != nil {
|
||||
return 0, fmt.Errorf("invalid order: %w", err)
|
||||
}
|
||||
|
||||
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
unitPrice, err := safeDivide(order.Amount, order.Quantity)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to calculate unit price: %w", err)
|
||||
}
|
||||
sub.UnitPrice = unitPrice
|
||||
|
||||
loc, err := time.LoadLocation(sub.StartTime.Location().String())
|
||||
if err != nil {
|
||||
loc = time.UTC
|
||||
}
|
||||
now := time.Now().In(loc)
|
||||
|
||||
switch sub.UnitTime {
|
||||
case UnitTimeNoLimit:
|
||||
usedTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||
unitPrice := float64(order.Amount) / float64(sub.Traffic)
|
||||
return int64(float64(usedTraffic) * unitPrice)
|
||||
return calculateNoLimitAmount(sub, order)
|
||||
|
||||
case UnitTimeYear:
|
||||
remainingYears := tool.YearDiff(now, sub.ExpireTime)
|
||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
||||
return int64(remainingYears)*sub.UnitPrice + remainingUnitTimeAmount
|
||||
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
yearAmount, err := safeMultiply(int64(remainingYears), sub.UnitPrice)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("year calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
total, err := safeAdd(yearAmount, remainingUnitTimeAmount)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
|
||||
case UnitTimeMonth:
|
||||
remainingMonths := tool.MonthDiff(now, sub.ExpireTime)
|
||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
||||
return int64(remainingMonths)*sub.UnitPrice + remainingUnitTimeAmount
|
||||
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
monthAmount, err := safeMultiply(int64(remainingMonths), sub.UnitPrice)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("month calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
total, err := safeAdd(monthAmount, remainingUnitTimeAmount)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
|
||||
case UnitTimeDay:
|
||||
remainingDays := tool.DayDiff(now, sub.ExpireTime)
|
||||
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
|
||||
return remainingDays*sub.UnitPrice + remainingUnitTimeAmount
|
||||
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
dayAmount, err := safeMultiply(remainingDays, sub.UnitPrice)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("day calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
total, err := safeAdd(dayAmount, remainingUnitTimeAmount)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("total calculation overflow: %w", err)
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
return 0
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func calculateRemainingUnitTimeAmount(sub Subscribe) int64 {
|
||||
// calculateNoLimitAmount calculates refund amount for unlimited time subscriptions
|
||||
// based on unused traffic only
|
||||
func calculateNoLimitAmount(sub Subscribe, order Order) (int64, error) {
|
||||
if sub.Traffic == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
usedTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||
if usedTraffic < 0 {
|
||||
usedTraffic = 0
|
||||
}
|
||||
|
||||
unitPrice := float64(order.Amount) / float64(sub.Traffic)
|
||||
result := float64(usedTraffic) * unitPrice
|
||||
|
||||
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
|
||||
return int64(result), nil
|
||||
}
|
||||
|
||||
// calculateRemainingUnitTimeAmount calculates the remaining amount based on
|
||||
// both time and traffic usage, applying deduction ratios when specified
|
||||
func calculateRemainingUnitTimeAmount(sub Subscribe) (int64, error) {
|
||||
now := time.Now()
|
||||
trafficWeight, timeWeight := calculateWeights(sub.DeductionRatio)
|
||||
remainingDays, totalDays := getRemainingAndTotalDays(sub, now)
|
||||
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||
remainingTimeAmount := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
|
||||
remainingTrafficAmount := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
|
||||
if sub.Traffic == 0 {
|
||||
return remainingTimeAmount
|
||||
|
||||
if totalDays == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
|
||||
if remainingTraffic < 0 {
|
||||
remainingTraffic = 0
|
||||
}
|
||||
|
||||
remainingTimeAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("time amount calculation failed: %w", err)
|
||||
}
|
||||
|
||||
if sub.Traffic == 0 {
|
||||
return remainingTimeAmount, nil
|
||||
}
|
||||
|
||||
remainingTrafficAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("traffic amount calculation failed: %w", err)
|
||||
}
|
||||
|
||||
if sub.DeductionRatio != 0 {
|
||||
return calculateWeightedAmount(sub.UnitPrice, remainingTraffic, sub.Traffic, remainingDays, totalDays, trafficWeight, timeWeight)
|
||||
}
|
||||
|
||||
return min(remainingTimeAmount, remainingTrafficAmount)
|
||||
return min(remainingTimeAmount, remainingTrafficAmount), nil
|
||||
}
|
||||
|
||||
// calculateWeights converts deduction ratio to traffic and time weights
|
||||
// for weighted calculations
|
||||
func calculateWeights(deductionRatio int64) (float64, float64) {
|
||||
if deductionRatio == 0 {
|
||||
return 0, 0
|
||||
@ -94,20 +306,32 @@ func calculateWeights(deductionRatio int64) (float64, float64) {
|
||||
return trafficWeight, timeWeight
|
||||
}
|
||||
|
||||
// getRemainingAndTotalDays calculates remaining and total days based on
|
||||
// the subscription's reset cycle configuration
|
||||
func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
|
||||
switch sub.ResetCycle {
|
||||
case ResetCycleNone:
|
||||
|
||||
remaining := sub.ExpireTime.Sub(now).Hours() / 24
|
||||
total := sub.ExpireTime.Sub(sub.StartTime).Hours() / 24
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
if total < 0 {
|
||||
total = 0
|
||||
}
|
||||
return int64(remaining), int64(total)
|
||||
|
||||
case ResetCycle1st:
|
||||
return tool.DaysToNextMonth(now), tool.GetLastDayOfMonth(now)
|
||||
|
||||
case ResetCycleMonthly:
|
||||
// -1 to include the current day
|
||||
return tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1, tool.DaysToMonthDay(now, sub.StartTime.Day())
|
||||
remaining := tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1
|
||||
total := tool.DaysToMonthDay(now, sub.StartTime.Day())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
return remaining, total
|
||||
|
||||
case ResetCycleYear:
|
||||
return tool.DaysToYearDay(now, int(sub.StartTime.Month()), sub.StartTime.Day()),
|
||||
tool.GetYearDays(now, int(sub.StartTime.Month()), sub.StartTime.Day())
|
||||
@ -115,13 +339,36 @@ func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) int64 {
|
||||
// calculateWeightedAmount applies weighted calculation combining both time and traffic
|
||||
// remaining ratios based on the specified weights
|
||||
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) (int64, error) {
|
||||
if totalDays == 0 || totalTraffic == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
remainingTimeRatio := float64(remainingDays) / float64(totalDays)
|
||||
remainingTrafficRatio := float64(remainingTraffic) / float64(totalTraffic)
|
||||
weightedRemainingRatio := (timeWeight * remainingTimeRatio) + (trafficWeight * remainingTrafficRatio)
|
||||
return int64(float64(unitPrice) * weightedRemainingRatio)
|
||||
|
||||
result := float64(unitPrice) * weightedRemainingRatio
|
||||
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
|
||||
return int64(result), nil
|
||||
}
|
||||
|
||||
func calculateProportionalAmount(unitPrice, remaining, total int64) int64 {
|
||||
return int64(float64(unitPrice) * (float64(remaining) / float64(total)))
|
||||
// calculateProportionalAmount calculates proportional amount based on
|
||||
// remaining vs total ratio with overflow protection
|
||||
func calculateProportionalAmount(unitPrice, remaining, total int64) (int64, error) {
|
||||
if total == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
result := float64(unitPrice) * (float64(remaining) / float64(total))
|
||||
if result > float64(maxInt64) || result < float64(minInt64) {
|
||||
return 0, ErrOverflow
|
||||
}
|
||||
|
||||
return int64(result), nil
|
||||
}
|
||||
|
||||
665
pkg/deduction/deduction_test.go
Normal file
665
pkg/deduction/deduction_test.go
Normal file
@ -0,0 +1,665 @@
|
||||
package deduction
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSubscribe_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sub Subscribe
|
||||
wantErr bool
|
||||
errType error
|
||||
}{
|
||||
{
|
||||
name: "valid subscription",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "negative traffic",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: -1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidTraffic,
|
||||
},
|
||||
{
|
||||
name: "negative download",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: -100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidTraffic,
|
||||
},
|
||||
{
|
||||
name: "download + upload exceeds traffic",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 600,
|
||||
Upload: 500,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expire time before start time",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(-24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidTimeRange,
|
||||
},
|
||||
{
|
||||
name: "invalid deduction ratio - negative",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: -10,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidDeductionRatio,
|
||||
},
|
||||
{
|
||||
name: "invalid deduction ratio - over 100",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 150,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidDeductionRatio,
|
||||
},
|
||||
{
|
||||
name: "invalid unit time",
|
||||
sub: Subscribe{
|
||||
StartTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 100,
|
||||
Upload: 200,
|
||||
UnitTime: "InvalidUnit",
|
||||
DeductionRatio: 50,
|
||||
},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidUnitTime,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.sub.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Subscribe.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.errType != nil && err != tt.errType {
|
||||
t.Errorf("Subscribe.Validate() error = %v, want %v", err, tt.errType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrder_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
order Order
|
||||
wantErr bool
|
||||
errType error
|
||||
}{
|
||||
{
|
||||
name: "valid order",
|
||||
order: Order{Amount: 1000, Quantity: 2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero quantity",
|
||||
order: Order{Amount: 1000, Quantity: 0},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidQuantity,
|
||||
},
|
||||
{
|
||||
name: "negative quantity",
|
||||
order: Order{Amount: 1000, Quantity: -1},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidQuantity,
|
||||
},
|
||||
{
|
||||
name: "negative amount",
|
||||
order: Order{Amount: -1000, Quantity: 2},
|
||||
wantErr: true,
|
||||
errType: ErrInvalidAmount,
|
||||
},
|
||||
{
|
||||
name: "zero amount is valid",
|
||||
order: Order{Amount: 0, Quantity: 1},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.order.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Order.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.errType != nil && err != tt.errType {
|
||||
t.Errorf("Order.Validate() error = %v, want %v", err, tt.errType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeMultiply(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b int64
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal multiplication",
|
||||
a: 10,
|
||||
b: 20,
|
||||
want: 200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero multiplication",
|
||||
a: 10,
|
||||
b: 0,
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "negative multiplication",
|
||||
a: -10,
|
||||
b: 20,
|
||||
want: -200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "overflow case",
|
||||
a: math.MaxInt64,
|
||||
b: 2,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "large numbers no overflow",
|
||||
a: 1000000,
|
||||
b: 1000000,
|
||||
want: 1000000000000,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := safeMultiply(tt.a, tt.b)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("safeMultiply() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("safeMultiply() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeAdd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b int64
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal addition",
|
||||
a: 10,
|
||||
b: 20,
|
||||
want: 30,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "negative addition",
|
||||
a: -10,
|
||||
b: 5,
|
||||
want: -5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "overflow case",
|
||||
a: math.MaxInt64,
|
||||
b: 1,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "underflow case",
|
||||
a: math.MinInt64,
|
||||
b: -1,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := safeAdd(tt.a, tt.b)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("safeAdd() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("safeAdd() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeDivide(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b int64
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal division",
|
||||
a: 20,
|
||||
b: 10,
|
||||
want: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "division by zero",
|
||||
a: 20,
|
||||
b: 0,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative division",
|
||||
a: -20,
|
||||
b: 10,
|
||||
want: -2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero dividend",
|
||||
a: 0,
|
||||
b: 10,
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := safeDivide(tt.a, tt.b)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("safeDivide() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("safeDivide() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateWeights(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
deductionRatio int64
|
||||
wantTrafficWeight float64
|
||||
wantTimeWeight float64
|
||||
}{
|
||||
{
|
||||
name: "zero ratio",
|
||||
deductionRatio: 0,
|
||||
wantTrafficWeight: 0,
|
||||
wantTimeWeight: 0,
|
||||
},
|
||||
{
|
||||
name: "50% ratio",
|
||||
deductionRatio: 50,
|
||||
wantTrafficWeight: 0.5,
|
||||
wantTimeWeight: 0.5,
|
||||
},
|
||||
{
|
||||
name: "75% ratio",
|
||||
deductionRatio: 75,
|
||||
wantTrafficWeight: 0.75,
|
||||
wantTimeWeight: 0.25,
|
||||
},
|
||||
{
|
||||
name: "100% ratio",
|
||||
deductionRatio: 100,
|
||||
wantTrafficWeight: 1.0,
|
||||
wantTimeWeight: 0.0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotTrafficWeight, gotTimeWeight := calculateWeights(tt.deductionRatio)
|
||||
if gotTrafficWeight != tt.wantTrafficWeight {
|
||||
t.Errorf("calculateWeights() trafficWeight = %v, want %v", gotTrafficWeight, tt.wantTrafficWeight)
|
||||
}
|
||||
if gotTimeWeight != tt.wantTimeWeight {
|
||||
t.Errorf("calculateWeights() timeWeight = %v, want %v", gotTimeWeight, tt.wantTimeWeight)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateProportionalAmount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitPrice int64
|
||||
remaining int64
|
||||
total int64
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal calculation",
|
||||
unitPrice: 100,
|
||||
remaining: 50,
|
||||
total: 100,
|
||||
want: 50,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero total",
|
||||
unitPrice: 100,
|
||||
remaining: 50,
|
||||
total: 0,
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero remaining",
|
||||
unitPrice: 100,
|
||||
remaining: 0,
|
||||
total: 100,
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "quarter remaining",
|
||||
unitPrice: 200,
|
||||
remaining: 25,
|
||||
total: 100,
|
||||
want: 50,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := calculateProportionalAmount(tt.unitPrice, tt.remaining, tt.total)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("calculateProportionalAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("calculateProportionalAmount() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateNoLimitAmount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sub Subscribe
|
||||
order Order
|
||||
want int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal no limit calculation",
|
||||
sub: Subscribe{
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
},
|
||||
want: 500, // (1000 - 300 - 200) / 1000 * 1000 = 500
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero traffic",
|
||||
sub: Subscribe{
|
||||
Traffic: 0,
|
||||
Download: 0,
|
||||
Upload: 0,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
},
|
||||
want: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "overused traffic",
|
||||
sub: Subscribe{
|
||||
Traffic: 1000,
|
||||
Download: 600,
|
||||
Upload: 500,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
},
|
||||
want: 0, // usedTraffic would be negative, clamped to 0
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := calculateNoLimitAmount(tt.sub, tt.order)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("calculateNoLimitAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("calculateNoLimitAmount() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateRemainingAmount(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub Subscribe
|
||||
order Order
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid no limit subscription",
|
||||
sub: Subscribe{
|
||||
StartTime: now.Add(-24 * time.Hour),
|
||||
ExpireTime: now.Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeNoLimit,
|
||||
ResetCycle: ResetCycleNone,
|
||||
DeductionRatio: 0,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
Quantity: 1,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid subscription",
|
||||
sub: Subscribe{
|
||||
StartTime: now,
|
||||
ExpireTime: now.Add(-24 * time.Hour), // Invalid: expire before start
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 0,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
Quantity: 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid order",
|
||||
sub: Subscribe{
|
||||
StartTime: now.Add(-24 * time.Hour),
|
||||
ExpireTime: now.Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
DeductionRatio: 0,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
Quantity: 0, // Invalid: zero quantity
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no limit with reset cycle",
|
||||
sub: Subscribe{
|
||||
StartTime: now.Add(-24 * time.Hour),
|
||||
ExpireTime: now.Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeNoLimit,
|
||||
ResetCycle: ResetCycleMonthly, // Should return 0
|
||||
DeductionRatio: 0,
|
||||
},
|
||||
order: Order{
|
||||
Amount: 1000,
|
||||
Quantity: 1,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := CalculateRemainingAmount(tt.sub, tt.order)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CalculateRemainingAmount() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateRemainingAmount_NoLimitWithResetCycle(t *testing.T) {
|
||||
now := time.Now()
|
||||
sub := Subscribe{
|
||||
StartTime: now.Add(-24 * time.Hour),
|
||||
ExpireTime: now.Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeNoLimit,
|
||||
ResetCycle: ResetCycleMonthly,
|
||||
DeductionRatio: 0,
|
||||
}
|
||||
order := Order{
|
||||
Amount: 1000,
|
||||
Quantity: 1,
|
||||
}
|
||||
|
||||
got, err := CalculateRemainingAmount(sub, order)
|
||||
if err != nil {
|
||||
t.Errorf("CalculateRemainingAmount() error = %v", err)
|
||||
return
|
||||
}
|
||||
if got != 0 {
|
||||
t.Errorf("CalculateRemainingAmount() = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkCalculateRemainingAmount(b *testing.B) {
|
||||
now := time.Now()
|
||||
sub := Subscribe{
|
||||
StartTime: now.Add(-24 * time.Hour),
|
||||
ExpireTime: now.Add(24 * time.Hour),
|
||||
Traffic: 1000,
|
||||
Download: 300,
|
||||
Upload: 200,
|
||||
UnitTime: UnitTimeMonth,
|
||||
ResetCycle: ResetCycleNone,
|
||||
DeductionRatio: 50,
|
||||
}
|
||||
order := Order{
|
||||
Amount: 1000,
|
||||
Quantity: 1,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = CalculateRemainingAmount(sub, order)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSafeMultiply(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = safeMultiply(12345, 67890)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user