fix(purchase): correct gift amount deduction logic and enhance payment processing comments

This commit is contained in:
Chang lue Tsen 2025-07-07 14:26:53 -04:00
parent 76816ca8ea
commit 9691257bad
7 changed files with 1206 additions and 141 deletions

View File

@ -132,7 +132,7 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P
if u.GiftAmount >= amount {
deductionAmount = amount
amount = 0
u.GiftAmount -= amount
u.GiftAmount -= deductionAmount
} else {
deductionAmount = u.GiftAmount
amount -= u.GiftAmount

View File

@ -28,13 +28,16 @@ import (
"github.com/pkg/errors"
)
// PurchaseCheckoutLogic handles the checkout process for various payment methods
// including EPay, Stripe, Alipay F2F, and balance payments
type PurchaseCheckoutLogic struct {
logger.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
// NewPurchaseCheckoutLogic Purchase Checkout
// NewPurchaseCheckoutLogic creates a new instance of PurchaseCheckoutLogic
// for handling purchase checkout operations across different payment platforms
func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PurchaseCheckoutLogic {
return &PurchaseCheckoutLogic{
Logger: logger.WithContext(ctx),
@ -43,88 +46,104 @@ func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *
}
}
// PurchaseCheckout processes the checkout for an order using the specified payment method
// It validates the order, retrieves payment configuration, and routes to the appropriate payment handler
func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest) (resp *types.CheckoutOrderResponse, err error) {
// Find order
// Validate and retrieve order information
orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
if err != nil {
l.Logger.Error("[PurchaseCheckout] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OrderNo))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OrderNo)
}
// Verify order is in pending payment status (status = 1)
if orderInfo.Status != 1 {
l.Logger.Error("[PurchaseCheckout] Order status error", logger.Field("status", orderInfo.Status))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderStatusError), "order status error: %v", orderInfo.Status)
}
// find payment method
// Retrieve payment method configuration
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, orderInfo.PaymentId)
if err != nil {
l.Logger.Error("[Purchase] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error())
}
// Route to appropriate payment handler based on payment platform
switch paymentPlatform.ParsePlatform(orderInfo.Method) {
case paymentPlatform.EPay:
// Process EPay payment - generates payment URL for redirect
url, err := l.epayPayment(paymentConfig, orderInfo, req.ReturnUrl)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "epayPayment error: %v", err.Error())
}
resp = &types.CheckoutOrderResponse{
CheckoutUrl: url,
Type: "url",
Type: "url", // Client should redirect to URL
}
case paymentPlatform.Stripe:
// Process Stripe payment - creates payment sheet for client-side processing
stripePayment, err := l.stripePayment(paymentConfig.Config, orderInfo, "")
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "stripePayment error: %v", err.Error())
}
resp = &types.CheckoutOrderResponse{
Type: "stripe",
Type: "stripe", // Client should use Stripe SDK
Stripe: stripePayment,
}
case paymentPlatform.AlipayF2F:
// Process Alipay Face-to-Face payment - generates QR code
url, err := l.alipayF2fPayment(paymentConfig, orderInfo)
if err != nil {
l.Errorw("[CheckoutOrderLogic] alipayF2fPayment error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] alipayF2fPayment error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "alipayF2fPayment error: %v", err.Error())
}
resp = &types.CheckoutOrderResponse{
Type: "qr",
Type: "qr", // Client should display QR code
CheckoutUrl: url,
}
case paymentPlatform.Balance:
// Process balance payment - validate user and process payment immediately
if orderInfo.UserId == 0 {
l.Errorw("[CheckoutOrderLogic] user not found", logger.Field("userId", orderInfo.UserId))
l.Errorw("[PurchaseCheckout] user not found", logger.Field("userId", orderInfo.UserId))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user not found")
}
// find user
// Retrieve user information for balance validation
userInfo, err := l.svcCtx.UserModel.FindOne(l.ctx, orderInfo.UserId)
if err != nil {
l.Errorw("[CheckoutOrderLogic] FindOne User error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] FindOne User error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "FindOne error: %s", err.Error())
}
// balance
// Process balance payment with gift amount priority logic
if err = l.balancePayment(userInfo, orderInfo); err != nil {
return nil, err
}
resp = &types.CheckoutOrderResponse{
Type: "balance",
Type: "balance", // Payment completed immediately
}
default:
l.Errorw("[CheckoutOrderLogic] payment method not found", logger.Field("method", orderInfo.Method))
l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found")
}
return
}
// alipay f2f payment
// alipayF2fPayment processes Alipay Face-to-Face payment by generating a QR code
// It handles currency conversion and creates a pre-payment trade for QR code scanning
func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *order.Order) (string, error) {
// Parse Alipay F2F configuration from payment settings
f2FConfig := payment.AlipayF2FConfig{}
if err := json.Unmarshal([]byte(pay.Config), &f2FConfig); err != nil {
l.Errorw("[PurchaseCheckoutLogic] Unmarshal error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Unmarshal Alipay config error", logger.Field("error", err.Error()))
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
}
// Build notification URL for payment status callbacks
notifyUrl := ""
if pay.Domain != "" {
notifyUrl = pay.Domain + "/v1/notify/" + pay.Platform + "/" + pay.Token
@ -135,6 +154,8 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
}
notifyUrl = "https://" + host + "/v1/notify/" + pay.Platform + "/" + pay.Token
}
// Initialize Alipay client with configuration
client := alipay.NewClient(alipay.Config{
AppId: f2FConfig.AppId,
PrivateKey: f2FConfig.PrivateKey,
@ -142,46 +163,53 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord
InvoiceName: f2FConfig.InvoiceName,
NotifyURL: notifyUrl,
})
// Calculate the amount with exchange rate
// Convert order amount to CNY using current exchange rate
amount, err := l.queryExchangeRate("CNY", info.Amount)
if err != nil {
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
}
convertAmount := int64(amount * 100)
// create payment
convertAmount := int64(amount * 100) // Convert to cents for API
// Create pre-payment trade and generate QR code
QRCode, err := client.PreCreateTrade(l.ctx, alipay.Order{
OrderNo: info.OrderNo,
Amount: convertAmount,
})
if err != nil {
l.Errorw("[CheckoutOrderLogic] PreCreateTrade error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] PreCreateTrade error", logger.Field("error", err.Error()))
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "PreCreateTrade error: %s", err.Error())
}
return QRCode, nil
}
// Stripe Payment
// stripePayment processes Stripe payment by creating a payment sheet
// It supports various payment methods including WeChat Pay and Alipay through Stripe
func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, identifier string) (*types.StripePayment, error) {
// stripe WeChat pay or stripe alipay
// Parse Stripe configuration from payment settings
stripeConfig := payment.StripeConfig{}
if err := json.Unmarshal([]byte(config), &stripeConfig); err != nil {
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Unmarshal Stripe config error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
}
// Initialize Stripe client with API credentials
client := stripe.NewClient(stripe.Config{
SecretKey: stripeConfig.SecretKey,
PublicKey: stripeConfig.PublicKey,
WebhookSecret: stripeConfig.WebhookSecret,
})
// Calculate the amount with exchange rate
// Convert order amount to CNY using current exchange rate
amount, err := l.queryExchangeRate("CNY", info.Amount)
if err != nil {
l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error())
}
convertAmount := int64(amount * 100)
// create payment
convertAmount := int64(amount * 100) // Convert to cents for Stripe API
// Create Stripe payment sheet for client-side processing
result, err := client.CreatePaymentSheet(&stripe.Order{
OrderNo: info.OrderNo,
Subscribe: strconv.FormatInt(info.SubscribeId, 10),
@ -193,37 +221,47 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order,
Email: identifier,
})
if err != nil {
l.Errorw("[CheckoutOrderLogic] CreatePaymentSheet error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] CreatePaymentSheet error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "CreatePaymentSheet error: %s", err.Error())
}
tradeNo := result.TradeNo
// Prepare response data for client-side Stripe integration
stripePayment := &types.StripePayment{
PublishableKey: stripeConfig.PublicKey,
ClientSecret: result.ClientSecret,
Method: stripeConfig.Payment,
}
// save payment
info.TradeNo = tradeNo
// Save Stripe trade number to order for tracking
info.TradeNo = result.TradeNo
err = l.svcCtx.OrderModel.Update(l.ctx, info)
if err != nil {
l.Errorw("[CheckoutOrderLogic] Update error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Update order error", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update error: %s", err.Error())
}
return stripePayment, nil
}
// epayPayment processes EPay payment by generating a payment URL for redirect
// It handles currency conversion and creates a payment URL for external payment processing
func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order.Order, returnUrl string) (string, error) {
// Parse EPay configuration from payment settings
epayConfig := payment.EPayConfig{}
if err := json.Unmarshal([]byte(config.Config), &epayConfig); err != nil {
l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Unmarshal EPay config error", logger.Field("error", err.Error()))
return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error())
}
// Initialize EPay client with merchant credentials
client := epay.NewClient(epayConfig.Pid, epayConfig.Url, epayConfig.Key)
// Calculate the amount with exchange rate
// Convert order amount to CNY using current exchange rate
amount, err := l.queryExchangeRate("CNY", info.Amount)
if err != nil {
return "", err
}
// Build notification URL for payment status callbacks
notifyUrl := ""
if config.Domain != "" {
notifyUrl = config.Domain + "/v1/notify/" + config.Platform + "/" + config.Token
@ -234,7 +272,8 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
}
notifyUrl = "https://" + host + "/v1/notify/" + config.Platform + "/" + config.Token
}
// create payment
// Create payment URL for user redirection
url := client.CreatePayUrl(epay.Order{
Name: l.svcCtx.Config.Site.SiteName,
Amount: amount,
@ -246,26 +285,34 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order
return url, nil
}
// Query exchange rate
// queryExchangeRate converts the order amount from system currency to target currency
// It retrieves the current exchange rate and performs currency conversion if needed
func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount float64, err error) {
// Convert cents to decimal amount
amount = float64(src) / float64(100)
// query system currency
// Retrieve system currency configuration
currency, err := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx)
if err != nil {
l.Errorw("[CheckoutOrderLogic] GetCurrencyConfig error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] GetCurrencyConfig error", logger.Field("error", err.Error()))
return 0, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetCurrencyConfig error: %s", err.Error())
}
// Parse currency configuration
configs := struct {
CurrencyUnit string
CurrencySymbol string
AccessKey string
}{}
tool.SystemConfigSliceReflectToStruct(currency, &configs)
// Skip conversion if no exchange rate API key configured
if configs.AccessKey == "" {
return amount, nil
}
// Convert currency if system currency differs from target currency
if configs.CurrencyUnit != to {
// query exchange rate
result, err := exchangeRate.GetExchangeRete(configs.CurrencyUnit, to, configs.AccessKey, 1)
if err != nil {
return 0, err
@ -275,40 +322,76 @@ func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount
return amount, nil
}
// Balance payment
// balancePayment processes balance payment with gift amount priority logic
// It prioritizes using gift amount first, then regular balance, and creates proper audit logs
func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error {
if o.Amount == 0 {
// No payment required for zero-amount orders
return nil
}
var userInfo user.User
err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Retrieve latest user information with row-level locking
err := db.Model(&user.User{}).Where("id = ?", u.Id).First(&userInfo).Error
if err != nil {
return err
}
if o.GiftAmount != 0 {
if userInfo.GiftAmount < o.GiftAmount {
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient gift balance")
}
// deduct gift amount
userInfo.GiftAmount -= o.GiftAmount
// Check if user has sufficient total balance (regular + gift)
totalAvailable := userInfo.Balance + userInfo.GiftAmount
if totalAvailable < o.Amount {
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance),
"Insufficient balance: required %d, available %d", o.Amount, totalAvailable)
}
if userInfo.Balance < o.Amount {
return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient balance")
}
// deduct balance
userInfo.Balance -= o.Amount
// Calculate payment distribution: prioritize gift amount first
var giftUsed, balanceUsed int64
remainingAmount := o.Amount
userInfo.GiftAmount -= o.GiftAmount
if userInfo.GiftAmount >= remainingAmount {
// Gift amount covers the entire payment
giftUsed = remainingAmount
balanceUsed = 0
} else {
// Use all available gift amount, then regular balance
giftUsed = userInfo.GiftAmount
balanceUsed = remainingAmount - giftUsed
}
// Update user balances
userInfo.GiftAmount -= giftUsed
userInfo.Balance -= balanceUsed
// Save updated user information
err = l.svcCtx.UserModel.Update(l.ctx, &userInfo)
if err != nil {
return err
}
// create balance log
balanceLog := &user.BalanceLog{
Id: 0,
// Create gift amount log if gift amount was used
if giftUsed > 0 {
giftLog := &user.GiftAmountLog{
UserId: u.Id,
Amount: o.Amount,
Type: 3,
UserSubscribeId: 0, // Will be updated when subscription is created
OrderNo: o.OrderNo,
Type: 2, // Type 2 represents gift amount decrease/usage
Amount: giftUsed,
Balance: userInfo.GiftAmount,
Remark: "Purchase payment",
}
err = db.Create(giftLog).Error
if err != nil {
return err
}
}
// Create balance log if regular balance was used
if balanceUsed > 0 {
balanceLog := &user.BalanceLog{
UserId: u.Id,
Amount: balanceUsed,
Type: 3, // Type 3 represents payment deduction
OrderId: o.Id,
Balance: userInfo.Balance,
}
@ -316,30 +399,46 @@ func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) err
if err != nil {
return err
}
return db.Model(&order.Order{}).Where("id = ?", o.Id).Updates(map[string]interface{}{
"status": 2, // 2 means paid
}).Error
})
}
// Store gift amount used in order for potential refund tracking
o.GiftAmount = giftUsed
err = l.svcCtx.OrderModel.Update(l.ctx, o)
if err != nil {
l.Errorw("[CheckoutOrderLogic] Transaction error", logger.Field("error", err.Error()), logger.Field("orderNo", o.OrderNo))
return err
}
// create activity order task
// Mark order as paid (status = 2)
return l.svcCtx.OrderModel.UpdateOrderStatus(l.ctx, o.OrderNo, 2, db)
})
if err != nil {
l.Errorw("[PurchaseCheckout] Balance payment transaction error",
logger.Field("error", err.Error()),
logger.Field("orderNo", o.OrderNo),
logger.Field("userId", u.Id))
return err
}
// Enqueue order activation task for immediate processing
payload := queueType.ForthwithActivateOrderPayload{
OrderNo: o.OrderNo,
}
bytes, err := json.Marshal(payload)
if err != nil {
l.Errorw("[CheckoutOrderLogic] Marshal error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Marshal activation payload error", logger.Field("error", err.Error()))
return err
}
task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes)
_, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task)
if err != nil {
l.Errorw("[CheckoutOrderLogic] Enqueue error", logger.Field("error", err.Error()))
l.Errorw("[PurchaseCheckout] Enqueue activation task error", logger.Field("error", err.Error()))
return err
}
l.Logger.Info("[CheckoutOrderLogic] Enqueue success", logger.Field("orderNo", o.OrderNo))
l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully",
logger.Field("orderNo", o.OrderNo),
logger.Field("userId", u.Id))
return nil
}

View File

@ -93,7 +93,6 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
}
// Calculate the handling fee
amount -= couponAmount
var deductionAmount int64
// find payment method
paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, req.Payment)
if err != nil {
@ -118,7 +117,7 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types.
Price: price,
Amount: amount,
Discount: discountAmount,
GiftAmount: deductionAmount,
GiftAmount: 0,
Coupon: req.Coupon,
CouponDiscount: couponAmount,
PaymentId: req.Payment,

View File

@ -38,6 +38,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
orderQuantity := orderDetails.Quantity
// Calculate Order Amount
orderAmount := orderDetails.Amount + orderDetails.GiftAmount
if len(orderDetails.SubOrders) > 0 {
for _, subOrder := range orderDetails.SubOrders {
if subOrder.Status == 2 || subOrder.Status == 5 {
@ -47,7 +48,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
}
}
// Calculate Remaining Amount
remainingAmount := deduction.CalculateRemainingAmount(
remainingAmount, err := deduction.CalculateRemainingAmount(
deduction.Subscribe{
StartTime: userSubscribe.StartTime,
ExpireTime: userSubscribe.ExpireTime,
@ -64,5 +65,8 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u
Quantity: orderQuantity,
},
)
if err != nil {
return 0, errors.Wrapf(xerr.NewErrCode(500), "CalculateRemainingAmount failed, userSubscribeId: %d, err: %v", userSubscribeId, err)
}
return remainingAmount, nil
}

View File

@ -21,7 +21,7 @@ type UnsubscribeLogic struct {
svcCtx *svc.ServiceContext
}
// NewUnsubscribeLogic Unsubscribe
// NewUnsubscribeLogic creates a new instance of UnsubscribeLogic for handling subscription cancellation
func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UnsubscribeLogic {
return &UnsubscribeLogic{
Logger: logger.WithContext(ctx),
@ -30,39 +30,90 @@ func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Unsub
}
}
// Unsubscribe handles the subscription cancellation process with proper refund distribution
// It prioritizes refunding to gift amount for balance-paid orders, then to regular balance
func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error {
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
if !ok {
logger.Error("current user is not found in context")
return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
}
// Calculate the remaining amount to refund based on unused subscription time/traffic
remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id)
if err != nil {
return err
}
// update user subscribe
// Process unsubscription in a database transaction to ensure data consistency
err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error {
// Find and update subscription status to cancelled (status = 4)
var userSub user.Subscribe
if err := db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
if err = db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil {
return err
}
userSub.Status = 4
if err := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
userSub.Status = 4 // Set status to cancelled
if err = l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil {
return err
}
balance := remainingAmount + u.Balance
// insert deduction log
// Query the original order information to determine refund strategy
orderInfo, err := l.svcCtx.OrderModel.FindOne(l.ctx, userSub.OrderId)
if err != nil {
return err
}
// Calculate refund distribution based on payment method and gift amount priority
var balance, gift int64
if orderInfo.Method == "balance" {
// For balance-paid orders, prioritize refunding to gift amount first
if orderInfo.GiftAmount >= remainingAmount {
// Gift amount covers the entire refund - refund all to gift balance
gift = remainingAmount
balance = u.Balance // Regular balance remains unchanged
} else {
// Gift amount insufficient - refund to gift first, remainder to regular balance
gift = orderInfo.GiftAmount
balance = u.Balance + (remainingAmount - orderInfo.GiftAmount)
}
} else {
// For non-balance payment orders, refund entirely to regular balance
balance = remainingAmount + u.Balance
gift = 0
}
// Create balance log entry only if there's an actual regular balance refund
balanceRefundAmount := balance - u.Balance
if balanceRefundAmount > 0 {
balanceLog := user.BalanceLog{
UserId: userSub.UserId,
OrderId: userSub.OrderId,
Amount: remainingAmount,
Type: 4,
Amount: balanceRefundAmount,
Type: 4, // Type 4 represents refund transaction
Balance: balance,
}
if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil {
return err
}
// update user balance
}
// Create gift amount log entry if there's a gift balance refund
if gift > 0 {
giftLog := user.GiftAmountLog{
UserId: userSub.UserId,
UserSubscribeId: userSub.Id,
OrderNo: orderInfo.OrderNo,
Type: 1, // Type 1 represents gift amount increase
Amount: gift,
Balance: u.GiftAmount + gift,
Remark: "Unsubscribe refund",
}
if err := db.Model(&user.GiftAmountLog{}).Create(&giftLog).Error; err != nil {
return err
}
// Update user's gift amount
u.GiftAmount += gift
}
// Update user's regular balance and save changes to database
u.Balance = balance
return l.svcCtx.UserModel.Update(l.ctx, u)
})

View File

@ -1,90 +1,302 @@
// Package deduction provides functionality for calculating remaining amounts
// in subscription billing systems, supporting various time units and traffic-based calculations.
package deduction
import (
"errors"
"fmt"
"math"
"time"
"github.com/perfect-panel/server/pkg/tool"
)
const (
UnitTimeNoLimit = "NoLimit"
UnitTimeYear = "Year"
UnitTimeMonth = "Month"
UnitTimeDay = "Day"
UintTimeHour = "Hour"
UintTimeMinute = "Minute"
// Time unit constants for subscription billing
UnitTimeNoLimit = "NoLimit" // Unlimited time subscription
UnitTimeYear = "Year" // Annual subscription
UnitTimeMonth = "Month" // Monthly subscription
UnitTimeDay = "Day" // Daily subscription
UnitTimeHour = "Hour" // Hourly subscription
UnitTimeMinute = "Minute" // Per-minute subscription
ResetCycleNone = 0
ResetCycle1st = 1
ResetCycleMonthly = 2
ResetCycleYear = 3
// Reset cycle constants for traffic resets
ResetCycleNone = 0 // No reset cycle
ResetCycle1st = 1 // Reset on 1st of each month
ResetCycleMonthly = 2 // Reset monthly based on start date
ResetCycleYear = 3 // Reset yearly based on start date
// Safety limits for overflow protection
maxInt64 = math.MaxInt64
minInt64 = math.MinInt64
)
// Error definitions for validation and calculation failures
var (
ErrInvalidQuantity = errors.New("order quantity cannot be zero or negative")
ErrInvalidAmount = errors.New("order amount cannot be negative")
ErrInvalidTraffic = errors.New("traffic values cannot be negative")
ErrInvalidTimeRange = errors.New("expire time must be after start time")
ErrInvalidUnitTime = errors.New("invalid unit time")
ErrInvalidDeductionRatio = errors.New("deduction ratio must be between 0 and 100")
ErrOverflow = errors.New("calculation overflow")
)
// Subscribe represents a subscription with time and traffic limits
type Subscribe struct {
StartTime time.Time
ExpireTime time.Time
Traffic int64
Download int64
Upload int64
UnitTime string
UnitPrice int64
ResetCycle int64
DeductionRatio int64
StartTime time.Time // Subscription start time
ExpireTime time.Time // Subscription expiration time
Traffic int64 // Total traffic allowance in bytes
Download int64 // Downloaded traffic in bytes
Upload int64 // Uploaded traffic in bytes
UnitTime string // Time unit for billing (Year, Month, Day, etc.)
UnitPrice int64 // Price per unit time
ResetCycle int64 // Traffic reset cycle
DeductionRatio int64 // Deduction ratio for weighted calculations (0-100)
}
// Order represents a purchase order for subscription calculation
type Order struct {
Amount int64
Quantity int64
Amount int64 // Total order amount
Quantity int64 // Order quantity
}
func CalculateRemainingAmount(sub Subscribe, order Order) int64 {
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
return 0
// Validate checks if the Subscribe struct contains valid data
func (s *Subscribe) Validate() error {
if s.Traffic < 0 || s.Download < 0 || s.Upload < 0 {
return ErrInvalidTraffic
}
// 实际单价
sub.UnitPrice = order.Amount / order.Quantity
now := time.Now()
if s.Download+s.Upload > s.Traffic {
return fmt.Errorf("download + upload (%d) cannot exceed total traffic (%d)", s.Download+s.Upload, s.Traffic)
}
if !s.ExpireTime.After(s.StartTime) {
return ErrInvalidTimeRange
}
if s.DeductionRatio < 0 || s.DeductionRatio > 100 {
return ErrInvalidDeductionRatio
}
validUnitTimes := []string{UnitTimeNoLimit, UnitTimeYear, UnitTimeMonth, UnitTimeDay, UnitTimeHour, UnitTimeMinute}
valid := false
for _, ut := range validUnitTimes {
if s.UnitTime == ut {
valid = true
break
}
}
if !valid {
return ErrInvalidUnitTime
}
return nil
}
// Validate checks if the Order struct contains valid data
func (o *Order) Validate() error {
if o.Quantity <= 0 {
return ErrInvalidQuantity
}
if o.Amount < 0 {
return ErrInvalidAmount
}
return nil
}
// safeMultiply performs multiplication with overflow protection
func safeMultiply(a, b int64) (int64, error) {
if a == 0 || b == 0 {
return 0, nil
}
if a > 0 && b > 0 {
if a > maxInt64/b {
return 0, ErrOverflow
}
} else if a < 0 && b < 0 {
if a < maxInt64/b {
return 0, ErrOverflow
}
} else {
if (a > 0 && b < minInt64/a) || (a < 0 && b > minInt64/a) {
return 0, ErrOverflow
}
}
return a * b, nil
}
// safeAdd performs addition with overflow protection
func safeAdd(a, b int64) (int64, error) {
if (b > 0 && a > maxInt64-b) || (b < 0 && a < minInt64-b) {
return 0, ErrOverflow
}
return a + b, nil
}
// safeDivide performs division with zero-division protection
func safeDivide(a, b int64) (int64, error) {
if b == 0 {
return 0, errors.New("division by zero")
}
return a / b, nil
}
// CalculateRemainingAmount calculates the remaining refund amount for a subscription
// based on unused time and traffic. Returns the amount and any calculation errors.
func CalculateRemainingAmount(sub Subscribe, order Order) (int64, error) {
if err := sub.Validate(); err != nil {
return 0, fmt.Errorf("invalid subscription: %w", err)
}
if err := order.Validate(); err != nil {
return 0, fmt.Errorf("invalid order: %w", err)
}
if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 {
return 0, nil
}
unitPrice, err := safeDivide(order.Amount, order.Quantity)
if err != nil {
return 0, fmt.Errorf("failed to calculate unit price: %w", err)
}
sub.UnitPrice = unitPrice
loc, err := time.LoadLocation(sub.StartTime.Location().String())
if err != nil {
loc = time.UTC
}
now := time.Now().In(loc)
switch sub.UnitTime {
case UnitTimeNoLimit:
usedTraffic := sub.Traffic - sub.Download - sub.Upload
unitPrice := float64(order.Amount) / float64(sub.Traffic)
return int64(float64(usedTraffic) * unitPrice)
return calculateNoLimitAmount(sub, order)
case UnitTimeYear:
remainingYears := tool.YearDiff(now, sub.ExpireTime)
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
return int64(remainingYears)*sub.UnitPrice + remainingUnitTimeAmount
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
if err != nil {
return 0, err
}
yearAmount, err := safeMultiply(int64(remainingYears), sub.UnitPrice)
if err != nil {
return 0, fmt.Errorf("year calculation overflow: %w", err)
}
total, err := safeAdd(yearAmount, remainingUnitTimeAmount)
if err != nil {
return 0, fmt.Errorf("total calculation overflow: %w", err)
}
return total, nil
case UnitTimeMonth:
remainingMonths := tool.MonthDiff(now, sub.ExpireTime)
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
return int64(remainingMonths)*sub.UnitPrice + remainingUnitTimeAmount
case UnitTimeDay:
remainingDays := tool.DayDiff(now, sub.ExpireTime)
remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub)
return remainingDays*sub.UnitPrice + remainingUnitTimeAmount
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
if err != nil {
return 0, err
}
return 0
monthAmount, err := safeMultiply(int64(remainingMonths), sub.UnitPrice)
if err != nil {
return 0, fmt.Errorf("month calculation overflow: %w", err)
}
total, err := safeAdd(monthAmount, remainingUnitTimeAmount)
if err != nil {
return 0, fmt.Errorf("total calculation overflow: %w", err)
}
return total, nil
case UnitTimeDay:
remainingDays := tool.DayDiff(now, sub.ExpireTime)
remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub)
if err != nil {
return 0, err
}
dayAmount, err := safeMultiply(remainingDays, sub.UnitPrice)
if err != nil {
return 0, fmt.Errorf("day calculation overflow: %w", err)
}
total, err := safeAdd(dayAmount, remainingUnitTimeAmount)
if err != nil {
return 0, fmt.Errorf("total calculation overflow: %w", err)
}
return total, nil
}
return 0, nil
}
func calculateRemainingUnitTimeAmount(sub Subscribe) int64 {
// calculateNoLimitAmount calculates refund amount for unlimited time subscriptions
// based on unused traffic only
func calculateNoLimitAmount(sub Subscribe, order Order) (int64, error) {
if sub.Traffic == 0 {
return 0, nil
}
usedTraffic := sub.Traffic - sub.Download - sub.Upload
if usedTraffic < 0 {
usedTraffic = 0
}
unitPrice := float64(order.Amount) / float64(sub.Traffic)
result := float64(usedTraffic) * unitPrice
if result > float64(maxInt64) || result < float64(minInt64) {
return 0, ErrOverflow
}
return int64(result), nil
}
// calculateRemainingUnitTimeAmount calculates the remaining amount based on
// both time and traffic usage, applying deduction ratios when specified
func calculateRemainingUnitTimeAmount(sub Subscribe) (int64, error) {
now := time.Now()
trafficWeight, timeWeight := calculateWeights(sub.DeductionRatio)
remainingDays, totalDays := getRemainingAndTotalDays(sub, now)
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
remainingTimeAmount := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
remainingTrafficAmount := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
if sub.Traffic == 0 {
return remainingTimeAmount
if totalDays == 0 {
return 0, nil
}
remainingTraffic := sub.Traffic - sub.Download - sub.Upload
if remainingTraffic < 0 {
remainingTraffic = 0
}
remainingTimeAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays)
if err != nil {
return 0, fmt.Errorf("time amount calculation failed: %w", err)
}
if sub.Traffic == 0 {
return remainingTimeAmount, nil
}
remainingTrafficAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic)
if err != nil {
return 0, fmt.Errorf("traffic amount calculation failed: %w", err)
}
if sub.DeductionRatio != 0 {
return calculateWeightedAmount(sub.UnitPrice, remainingTraffic, sub.Traffic, remainingDays, totalDays, trafficWeight, timeWeight)
}
return min(remainingTimeAmount, remainingTrafficAmount)
return min(remainingTimeAmount, remainingTrafficAmount), nil
}
// calculateWeights converts deduction ratio to traffic and time weights
// for weighted calculations
func calculateWeights(deductionRatio int64) (float64, float64) {
if deductionRatio == 0 {
return 0, 0
@ -94,20 +306,32 @@ func calculateWeights(deductionRatio int64) (float64, float64) {
return trafficWeight, timeWeight
}
// getRemainingAndTotalDays calculates remaining and total days based on
// the subscription's reset cycle configuration
func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
switch sub.ResetCycle {
case ResetCycleNone:
remaining := sub.ExpireTime.Sub(now).Hours() / 24
total := sub.ExpireTime.Sub(sub.StartTime).Hours() / 24
if remaining < 0 {
remaining = 0
}
if total < 0 {
total = 0
}
return int64(remaining), int64(total)
case ResetCycle1st:
return tool.DaysToNextMonth(now), tool.GetLastDayOfMonth(now)
case ResetCycleMonthly:
// -1 to include the current day
return tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1, tool.DaysToMonthDay(now, sub.StartTime.Day())
remaining := tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1
total := tool.DaysToMonthDay(now, sub.StartTime.Day())
if remaining < 0 {
remaining = 0
}
return remaining, total
case ResetCycleYear:
return tool.DaysToYearDay(now, int(sub.StartTime.Month()), sub.StartTime.Day()),
tool.GetYearDays(now, int(sub.StartTime.Month()), sub.StartTime.Day())
@ -115,13 +339,36 @@ func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) {
return 0, 0
}
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) int64 {
// calculateWeightedAmount applies weighted calculation combining both time and traffic
// remaining ratios based on the specified weights
func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) (int64, error) {
if totalDays == 0 || totalTraffic == 0 {
return 0, nil
}
remainingTimeRatio := float64(remainingDays) / float64(totalDays)
remainingTrafficRatio := float64(remainingTraffic) / float64(totalTraffic)
weightedRemainingRatio := (timeWeight * remainingTimeRatio) + (trafficWeight * remainingTrafficRatio)
return int64(float64(unitPrice) * weightedRemainingRatio)
result := float64(unitPrice) * weightedRemainingRatio
if result > float64(maxInt64) || result < float64(minInt64) {
return 0, ErrOverflow
}
return int64(result), nil
}
func calculateProportionalAmount(unitPrice, remaining, total int64) int64 {
return int64(float64(unitPrice) * (float64(remaining) / float64(total)))
// calculateProportionalAmount calculates proportional amount based on
// remaining vs total ratio with overflow protection
func calculateProportionalAmount(unitPrice, remaining, total int64) (int64, error) {
if total == 0 {
return 0, nil
}
result := float64(unitPrice) * (float64(remaining) / float64(total))
if result > float64(maxInt64) || result < float64(minInt64) {
return 0, ErrOverflow
}
return int64(result), nil
}

View File

@ -0,0 +1,665 @@
package deduction
import (
"math"
"testing"
"time"
)
func TestSubscribe_Validate(t *testing.T) {
tests := []struct {
name string
sub Subscribe
wantErr bool
errType error
}{
{
name: "valid subscription",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: 100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 50,
},
wantErr: false,
},
{
name: "negative traffic",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: -1000,
Download: 100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 50,
},
wantErr: true,
errType: ErrInvalidTraffic,
},
{
name: "negative download",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: -100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 50,
},
wantErr: true,
errType: ErrInvalidTraffic,
},
{
name: "download + upload exceeds traffic",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: 600,
Upload: 500,
UnitTime: UnitTimeMonth,
DeductionRatio: 50,
},
wantErr: true,
},
{
name: "expire time before start time",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(-24 * time.Hour),
Traffic: 1000,
Download: 100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 50,
},
wantErr: true,
errType: ErrInvalidTimeRange,
},
{
name: "invalid deduction ratio - negative",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: 100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: -10,
},
wantErr: true,
errType: ErrInvalidDeductionRatio,
},
{
name: "invalid deduction ratio - over 100",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: 100,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 150,
},
wantErr: true,
errType: ErrInvalidDeductionRatio,
},
{
name: "invalid unit time",
sub: Subscribe{
StartTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Traffic: 1000,
Download: 100,
Upload: 200,
UnitTime: "InvalidUnit",
DeductionRatio: 50,
},
wantErr: true,
errType: ErrInvalidUnitTime,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.sub.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Subscribe.Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.errType != nil && err != tt.errType {
t.Errorf("Subscribe.Validate() error = %v, want %v", err, tt.errType)
}
})
}
}
func TestOrder_Validate(t *testing.T) {
tests := []struct {
name string
order Order
wantErr bool
errType error
}{
{
name: "valid order",
order: Order{Amount: 1000, Quantity: 2},
wantErr: false,
},
{
name: "zero quantity",
order: Order{Amount: 1000, Quantity: 0},
wantErr: true,
errType: ErrInvalidQuantity,
},
{
name: "negative quantity",
order: Order{Amount: 1000, Quantity: -1},
wantErr: true,
errType: ErrInvalidQuantity,
},
{
name: "negative amount",
order: Order{Amount: -1000, Quantity: 2},
wantErr: true,
errType: ErrInvalidAmount,
},
{
name: "zero amount is valid",
order: Order{Amount: 0, Quantity: 1},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.order.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Order.Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.errType != nil && err != tt.errType {
t.Errorf("Order.Validate() error = %v, want %v", err, tt.errType)
}
})
}
}
func TestSafeMultiply(t *testing.T) {
tests := []struct {
name string
a, b int64
want int64
wantErr bool
}{
{
name: "normal multiplication",
a: 10,
b: 20,
want: 200,
wantErr: false,
},
{
name: "zero multiplication",
a: 10,
b: 0,
want: 0,
wantErr: false,
},
{
name: "negative multiplication",
a: -10,
b: 20,
want: -200,
wantErr: false,
},
{
name: "overflow case",
a: math.MaxInt64,
b: 2,
want: 0,
wantErr: true,
},
{
name: "large numbers no overflow",
a: 1000000,
b: 1000000,
want: 1000000000000,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := safeMultiply(tt.a, tt.b)
if (err != nil) != tt.wantErr {
t.Errorf("safeMultiply() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("safeMultiply() = %v, want %v", got, tt.want)
}
})
}
}
func TestSafeAdd(t *testing.T) {
tests := []struct {
name string
a, b int64
want int64
wantErr bool
}{
{
name: "normal addition",
a: 10,
b: 20,
want: 30,
wantErr: false,
},
{
name: "negative addition",
a: -10,
b: 5,
want: -5,
wantErr: false,
},
{
name: "overflow case",
a: math.MaxInt64,
b: 1,
want: 0,
wantErr: true,
},
{
name: "underflow case",
a: math.MinInt64,
b: -1,
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := safeAdd(tt.a, tt.b)
if (err != nil) != tt.wantErr {
t.Errorf("safeAdd() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("safeAdd() = %v, want %v", got, tt.want)
}
})
}
}
func TestSafeDivide(t *testing.T) {
tests := []struct {
name string
a, b int64
want int64
wantErr bool
}{
{
name: "normal division",
a: 20,
b: 10,
want: 2,
wantErr: false,
},
{
name: "division by zero",
a: 20,
b: 0,
want: 0,
wantErr: true,
},
{
name: "negative division",
a: -20,
b: 10,
want: -2,
wantErr: false,
},
{
name: "zero dividend",
a: 0,
b: 10,
want: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := safeDivide(tt.a, tt.b)
if (err != nil) != tt.wantErr {
t.Errorf("safeDivide() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("safeDivide() = %v, want %v", got, tt.want)
}
})
}
}
func TestCalculateWeights(t *testing.T) {
tests := []struct {
name string
deductionRatio int64
wantTrafficWeight float64
wantTimeWeight float64
}{
{
name: "zero ratio",
deductionRatio: 0,
wantTrafficWeight: 0,
wantTimeWeight: 0,
},
{
name: "50% ratio",
deductionRatio: 50,
wantTrafficWeight: 0.5,
wantTimeWeight: 0.5,
},
{
name: "75% ratio",
deductionRatio: 75,
wantTrafficWeight: 0.75,
wantTimeWeight: 0.25,
},
{
name: "100% ratio",
deductionRatio: 100,
wantTrafficWeight: 1.0,
wantTimeWeight: 0.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotTrafficWeight, gotTimeWeight := calculateWeights(tt.deductionRatio)
if gotTrafficWeight != tt.wantTrafficWeight {
t.Errorf("calculateWeights() trafficWeight = %v, want %v", gotTrafficWeight, tt.wantTrafficWeight)
}
if gotTimeWeight != tt.wantTimeWeight {
t.Errorf("calculateWeights() timeWeight = %v, want %v", gotTimeWeight, tt.wantTimeWeight)
}
})
}
}
func TestCalculateProportionalAmount(t *testing.T) {
tests := []struct {
name string
unitPrice int64
remaining int64
total int64
want int64
wantErr bool
}{
{
name: "normal calculation",
unitPrice: 100,
remaining: 50,
total: 100,
want: 50,
wantErr: false,
},
{
name: "zero total",
unitPrice: 100,
remaining: 50,
total: 0,
want: 0,
wantErr: false,
},
{
name: "zero remaining",
unitPrice: 100,
remaining: 0,
total: 100,
want: 0,
wantErr: false,
},
{
name: "quarter remaining",
unitPrice: 200,
remaining: 25,
total: 100,
want: 50,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := calculateProportionalAmount(tt.unitPrice, tt.remaining, tt.total)
if (err != nil) != tt.wantErr {
t.Errorf("calculateProportionalAmount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("calculateProportionalAmount() = %v, want %v", got, tt.want)
}
})
}
}
func TestCalculateNoLimitAmount(t *testing.T) {
tests := []struct {
name string
sub Subscribe
order Order
want int64
wantErr bool
}{
{
name: "normal no limit calculation",
sub: Subscribe{
Traffic: 1000,
Download: 300,
Upload: 200,
},
order: Order{
Amount: 1000,
},
want: 500, // (1000 - 300 - 200) / 1000 * 1000 = 500
wantErr: false,
},
{
name: "zero traffic",
sub: Subscribe{
Traffic: 0,
Download: 0,
Upload: 0,
},
order: Order{
Amount: 1000,
},
want: 0,
wantErr: false,
},
{
name: "overused traffic",
sub: Subscribe{
Traffic: 1000,
Download: 600,
Upload: 500,
},
order: Order{
Amount: 1000,
},
want: 0, // usedTraffic would be negative, clamped to 0
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := calculateNoLimitAmount(tt.sub, tt.order)
if (err != nil) != tt.wantErr {
t.Errorf("calculateNoLimitAmount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("calculateNoLimitAmount() = %v, want %v", got, tt.want)
}
})
}
}
func TestCalculateRemainingAmount(t *testing.T) {
now := time.Now()
tests := []struct {
name string
sub Subscribe
order Order
wantErr bool
}{
{
name: "valid no limit subscription",
sub: Subscribe{
StartTime: now.Add(-24 * time.Hour),
ExpireTime: now.Add(24 * time.Hour),
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeNoLimit,
ResetCycle: ResetCycleNone,
DeductionRatio: 0,
},
order: Order{
Amount: 1000,
Quantity: 1,
},
wantErr: false,
},
{
name: "invalid subscription",
sub: Subscribe{
StartTime: now,
ExpireTime: now.Add(-24 * time.Hour), // Invalid: expire before start
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 0,
},
order: Order{
Amount: 1000,
Quantity: 1,
},
wantErr: true,
},
{
name: "invalid order",
sub: Subscribe{
StartTime: now.Add(-24 * time.Hour),
ExpireTime: now.Add(24 * time.Hour),
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeMonth,
DeductionRatio: 0,
},
order: Order{
Amount: 1000,
Quantity: 0, // Invalid: zero quantity
},
wantErr: true,
},
{
name: "no limit with reset cycle",
sub: Subscribe{
StartTime: now.Add(-24 * time.Hour),
ExpireTime: now.Add(24 * time.Hour),
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeNoLimit,
ResetCycle: ResetCycleMonthly, // Should return 0
DeductionRatio: 0,
},
order: Order{
Amount: 1000,
Quantity: 1,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := CalculateRemainingAmount(tt.sub, tt.order)
if (err != nil) != tt.wantErr {
t.Errorf("CalculateRemainingAmount() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCalculateRemainingAmount_NoLimitWithResetCycle(t *testing.T) {
now := time.Now()
sub := Subscribe{
StartTime: now.Add(-24 * time.Hour),
ExpireTime: now.Add(24 * time.Hour),
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeNoLimit,
ResetCycle: ResetCycleMonthly,
DeductionRatio: 0,
}
order := Order{
Amount: 1000,
Quantity: 1,
}
got, err := CalculateRemainingAmount(sub, order)
if err != nil {
t.Errorf("CalculateRemainingAmount() error = %v", err)
return
}
if got != 0 {
t.Errorf("CalculateRemainingAmount() = %v, want 0", got)
}
}
// Benchmark tests
func BenchmarkCalculateRemainingAmount(b *testing.B) {
now := time.Now()
sub := Subscribe{
StartTime: now.Add(-24 * time.Hour),
ExpireTime: now.Add(24 * time.Hour),
Traffic: 1000,
Download: 300,
Upload: 200,
UnitTime: UnitTimeMonth,
ResetCycle: ResetCycleNone,
DeductionRatio: 50,
}
order := Order{
Amount: 1000,
Quantity: 1,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = CalculateRemainingAmount(sub, order)
}
}
func BenchmarkSafeMultiply(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = safeMultiply(12345, 67890)
}
}