diff --git a/internal/logic/public/order/purchaseLogic.go b/internal/logic/public/order/purchaseLogic.go index 5030e96..98263d2 100644 --- a/internal/logic/public/order/purchaseLogic.go +++ b/internal/logic/public/order/purchaseLogic.go @@ -132,7 +132,7 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P if u.GiftAmount >= amount { deductionAmount = amount amount = 0 - u.GiftAmount -= amount + u.GiftAmount -= deductionAmount } else { deductionAmount = u.GiftAmount amount -= u.GiftAmount diff --git a/internal/logic/public/portal/purchaseCheckoutLogic.go b/internal/logic/public/portal/purchaseCheckoutLogic.go index e7e293a..7ce1f9e 100644 --- a/internal/logic/public/portal/purchaseCheckoutLogic.go +++ b/internal/logic/public/portal/purchaseCheckoutLogic.go @@ -28,13 +28,16 @@ import ( "github.com/pkg/errors" ) +// PurchaseCheckoutLogic handles the checkout process for various payment methods +// including EPay, Stripe, Alipay F2F, and balance payments type PurchaseCheckoutLogic struct { logger.Logger ctx context.Context svcCtx *svc.ServiceContext } -// NewPurchaseCheckoutLogic Purchase Checkout +// NewPurchaseCheckoutLogic creates a new instance of PurchaseCheckoutLogic +// for handling purchase checkout operations across different payment platforms func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PurchaseCheckoutLogic { return &PurchaseCheckoutLogic{ Logger: logger.WithContext(ctx), @@ -43,88 +46,104 @@ func NewPurchaseCheckoutLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } } +// PurchaseCheckout processes the checkout for an order using the specified payment method +// It validates the order, retrieves payment configuration, and routes to the appropriate payment handler func (l *PurchaseCheckoutLogic) PurchaseCheckout(req *types.CheckoutOrderRequest) (resp *types.CheckoutOrderResponse, err error) { - // Find order + // Validate and retrieve order information orderInfo, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo) if err != nil { l.Logger.Error("[PurchaseCheckout] Find order failed", logger.Field("error", err.Error()), logger.Field("orderNo", req.OrderNo)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderNotExist), "order not exist: %v", req.OrderNo) } + // Verify order is in pending payment status (status = 1) if orderInfo.Status != 1 { l.Logger.Error("[PurchaseCheckout] Order status error", logger.Field("status", orderInfo.Status)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.OrderStatusError), "order status error: %v", orderInfo.Status) } - // find payment method + // Retrieve payment method configuration paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, orderInfo.PaymentId) if err != nil { - l.Logger.Error("[Purchase] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method)) + l.Logger.Error("[PurchaseCheckout] Database query error", logger.Field("error", err.Error()), logger.Field("payment", orderInfo.Method)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find payment method error: %v", err.Error()) } + // Route to appropriate payment handler based on payment platform switch paymentPlatform.ParsePlatform(orderInfo.Method) { case paymentPlatform.EPay: + // Process EPay payment - generates payment URL for redirect url, err := l.epayPayment(paymentConfig, orderInfo, req.ReturnUrl) if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "epayPayment error: %v", err.Error()) } resp = &types.CheckoutOrderResponse{ CheckoutUrl: url, - Type: "url", + Type: "url", // Client should redirect to URL } + case paymentPlatform.Stripe: + // Process Stripe payment - creates payment sheet for client-side processing stripePayment, err := l.stripePayment(paymentConfig.Config, orderInfo, "") if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "stripePayment error: %v", err.Error()) } resp = &types.CheckoutOrderResponse{ - Type: "stripe", + Type: "stripe", // Client should use Stripe SDK Stripe: stripePayment, } + case paymentPlatform.AlipayF2F: + // Process Alipay Face-to-Face payment - generates QR code url, err := l.alipayF2fPayment(paymentConfig, orderInfo) if err != nil { - l.Errorw("[CheckoutOrderLogic] alipayF2fPayment error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] alipayF2fPayment error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "alipayF2fPayment error: %v", err.Error()) } resp = &types.CheckoutOrderResponse{ - Type: "qr", + Type: "qr", // Client should display QR code CheckoutUrl: url, } + case paymentPlatform.Balance: + // Process balance payment - validate user and process payment immediately if orderInfo.UserId == 0 { - l.Errorw("[CheckoutOrderLogic] user not found", logger.Field("userId", orderInfo.UserId)) + l.Errorw("[PurchaseCheckout] user not found", logger.Field("userId", orderInfo.UserId)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user not found") } - // find user + + // Retrieve user information for balance validation userInfo, err := l.svcCtx.UserModel.FindOne(l.ctx, orderInfo.UserId) if err != nil { - l.Errorw("[CheckoutOrderLogic] FindOne User error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] FindOne User error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "FindOne error: %s", err.Error()) } - // balance + // Process balance payment with gift amount priority logic if err = l.balancePayment(userInfo, orderInfo); err != nil { return nil, err } resp = &types.CheckoutOrderResponse{ - Type: "balance", + Type: "balance", // Payment completed immediately } default: - l.Errorw("[CheckoutOrderLogic] payment method not found", logger.Field("method", orderInfo.Method)) + l.Errorw("[PurchaseCheckout] payment method not found", logger.Field("method", orderInfo.Method)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "payment method not found") } return } -// alipay f2f payment +// alipayF2fPayment processes Alipay Face-to-Face payment by generating a QR code +// It handles currency conversion and creates a pre-payment trade for QR code scanning func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *order.Order) (string, error) { + // Parse Alipay F2F configuration from payment settings f2FConfig := payment.AlipayF2FConfig{} if err := json.Unmarshal([]byte(pay.Config), &f2FConfig); err != nil { - l.Errorw("[PurchaseCheckoutLogic] Unmarshal error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Unmarshal Alipay config error", logger.Field("error", err.Error())) return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error()) } + + // Build notification URL for payment status callbacks notifyUrl := "" if pay.Domain != "" { notifyUrl = pay.Domain + "/v1/notify/" + pay.Platform + "/" + pay.Token @@ -135,6 +154,8 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord } notifyUrl = "https://" + host + "/v1/notify/" + pay.Platform + "/" + pay.Token } + + // Initialize Alipay client with configuration client := alipay.NewClient(alipay.Config{ AppId: f2FConfig.AppId, PrivateKey: f2FConfig.PrivateKey, @@ -142,46 +163,53 @@ func (l *PurchaseCheckoutLogic) alipayF2fPayment(pay *payment.Payment, info *ord InvoiceName: f2FConfig.InvoiceName, NotifyURL: notifyUrl, }) - // Calculate the amount with exchange rate + + // Convert order amount to CNY using current exchange rate amount, err := l.queryExchangeRate("CNY", info.Amount) if err != nil { - l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error())) return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error()) } - convertAmount := int64(amount * 100) - // create payment + convertAmount := int64(amount * 100) // Convert to cents for API + + // Create pre-payment trade and generate QR code QRCode, err := client.PreCreateTrade(l.ctx, alipay.Order{ OrderNo: info.OrderNo, Amount: convertAmount, }) if err != nil { - l.Errorw("[CheckoutOrderLogic] PreCreateTrade error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] PreCreateTrade error", logger.Field("error", err.Error())) return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "PreCreateTrade error: %s", err.Error()) } return QRCode, nil } -// Stripe Payment +// stripePayment processes Stripe payment by creating a payment sheet +// It supports various payment methods including WeChat Pay and Alipay through Stripe func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, identifier string) (*types.StripePayment, error) { - // stripe WeChat pay or stripe alipay + // Parse Stripe configuration from payment settings stripeConfig := payment.StripeConfig{} if err := json.Unmarshal([]byte(config), &stripeConfig); err != nil { - l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Unmarshal Stripe config error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error()) } + + // Initialize Stripe client with API credentials client := stripe.NewClient(stripe.Config{ SecretKey: stripeConfig.SecretKey, PublicKey: stripeConfig.PublicKey, WebhookSecret: stripeConfig.WebhookSecret, }) - // Calculate the amount with exchange rate + + // Convert order amount to CNY using current exchange rate amount, err := l.queryExchangeRate("CNY", info.Amount) if err != nil { - l.Errorw("[CheckoutOrderLogic] queryExchangeRate error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] queryExchangeRate error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "queryExchangeRate error: %s", err.Error()) } - convertAmount := int64(amount * 100) - // create payment + convertAmount := int64(amount * 100) // Convert to cents for Stripe API + + // Create Stripe payment sheet for client-side processing result, err := client.CreatePaymentSheet(&stripe.Order{ OrderNo: info.OrderNo, Subscribe: strconv.FormatInt(info.SubscribeId, 10), @@ -193,37 +221,47 @@ func (l *PurchaseCheckoutLogic) stripePayment(config string, info *order.Order, Email: identifier, }) if err != nil { - l.Errorw("[CheckoutOrderLogic] CreatePaymentSheet error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] CreatePaymentSheet error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "CreatePaymentSheet error: %s", err.Error()) } - tradeNo := result.TradeNo + + // Prepare response data for client-side Stripe integration stripePayment := &types.StripePayment{ PublishableKey: stripeConfig.PublicKey, ClientSecret: result.ClientSecret, Method: stripeConfig.Payment, } - // save payment - info.TradeNo = tradeNo + + // Save Stripe trade number to order for tracking + info.TradeNo = result.TradeNo err = l.svcCtx.OrderModel.Update(l.ctx, info) if err != nil { - l.Errorw("[CheckoutOrderLogic] Update error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Update order error", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Update error: %s", err.Error()) } return stripePayment, nil } +// epayPayment processes EPay payment by generating a payment URL for redirect +// It handles currency conversion and creates a payment URL for external payment processing func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order.Order, returnUrl string) (string, error) { + // Parse EPay configuration from payment settings epayConfig := payment.EPayConfig{} if err := json.Unmarshal([]byte(config.Config), &epayConfig); err != nil { - l.Errorw("[CheckoutOrderLogic] Unmarshal error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Unmarshal EPay config error", logger.Field("error", err.Error())) return "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Unmarshal error: %s", err.Error()) } + + // Initialize EPay client with merchant credentials client := epay.NewClient(epayConfig.Pid, epayConfig.Url, epayConfig.Key) - // Calculate the amount with exchange rate + + // Convert order amount to CNY using current exchange rate amount, err := l.queryExchangeRate("CNY", info.Amount) if err != nil { return "", err } + + // Build notification URL for payment status callbacks notifyUrl := "" if config.Domain != "" { notifyUrl = config.Domain + "/v1/notify/" + config.Platform + "/" + config.Token @@ -234,7 +272,8 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order } notifyUrl = "https://" + host + "/v1/notify/" + config.Platform + "/" + config.Token } - // create payment + + // Create payment URL for user redirection url := client.CreatePayUrl(epay.Order{ Name: l.svcCtx.Config.Site.SiteName, Amount: amount, @@ -246,26 +285,34 @@ func (l *PurchaseCheckoutLogic) epayPayment(config *payment.Payment, info *order return url, nil } -// Query exchange rate +// queryExchangeRate converts the order amount from system currency to target currency +// It retrieves the current exchange rate and performs currency conversion if needed func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount float64, err error) { + // Convert cents to decimal amount amount = float64(src) / float64(100) - // query system currency + + // Retrieve system currency configuration currency, err := l.svcCtx.SystemModel.GetCurrencyConfig(l.ctx) if err != nil { - l.Errorw("[CheckoutOrderLogic] GetCurrencyConfig error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] GetCurrencyConfig error", logger.Field("error", err.Error())) return 0, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetCurrencyConfig error: %s", err.Error()) } + + // Parse currency configuration configs := struct { CurrencyUnit string CurrencySymbol string AccessKey string }{} tool.SystemConfigSliceReflectToStruct(currency, &configs) + + // Skip conversion if no exchange rate API key configured if configs.AccessKey == "" { return amount, nil } + + // Convert currency if system currency differs from target currency if configs.CurrencyUnit != to { - // query exchange rate result, err := exchangeRate.GetExchangeRete(configs.CurrencyUnit, to, configs.AccessKey, 1) if err != nil { return 0, err @@ -275,71 +322,123 @@ func (l *PurchaseCheckoutLogic) queryExchangeRate(to string, src int64) (amount return amount, nil } -// Balance payment +// balancePayment processes balance payment with gift amount priority logic +// It prioritizes using gift amount first, then regular balance, and creates proper audit logs func (l *PurchaseCheckoutLogic) balancePayment(u *user.User, o *order.Order) error { + if o.Amount == 0 { + // No payment required for zero-amount orders + return nil + } + var userInfo user.User err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { + // Retrieve latest user information with row-level locking err := db.Model(&user.User{}).Where("id = ?", u.Id).First(&userInfo).Error if err != nil { return err } - if o.GiftAmount != 0 { - if userInfo.GiftAmount < o.GiftAmount { - return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient gift balance") - } - // deduct gift amount - userInfo.GiftAmount -= o.GiftAmount + // Check if user has sufficient total balance (regular + gift) + totalAvailable := userInfo.Balance + userInfo.GiftAmount + if totalAvailable < o.Amount { + return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), + "Insufficient balance: required %d, available %d", o.Amount, totalAvailable) } - if userInfo.Balance < o.Amount { - return errors.Wrapf(xerr.NewErrCode(xerr.InsufficientBalance), "Insufficient balance") - } - // deduct balance - userInfo.Balance -= o.Amount + // Calculate payment distribution: prioritize gift amount first + var giftUsed, balanceUsed int64 + remainingAmount := o.Amount - userInfo.GiftAmount -= o.GiftAmount + if userInfo.GiftAmount >= remainingAmount { + // Gift amount covers the entire payment + giftUsed = remainingAmount + balanceUsed = 0 + } else { + // Use all available gift amount, then regular balance + giftUsed = userInfo.GiftAmount + balanceUsed = remainingAmount - giftUsed + } + + // Update user balances + userInfo.GiftAmount -= giftUsed + userInfo.Balance -= balanceUsed + + // Save updated user information err = l.svcCtx.UserModel.Update(l.ctx, &userInfo) if err != nil { return err } - // create balance log - balanceLog := &user.BalanceLog{ - Id: 0, - UserId: u.Id, - Amount: o.Amount, - Type: 3, - OrderId: o.Id, - Balance: userInfo.Balance, + + // Create gift amount log if gift amount was used + if giftUsed > 0 { + giftLog := &user.GiftAmountLog{ + UserId: u.Id, + UserSubscribeId: 0, // Will be updated when subscription is created + OrderNo: o.OrderNo, + Type: 2, // Type 2 represents gift amount decrease/usage + Amount: giftUsed, + Balance: userInfo.GiftAmount, + Remark: "Purchase payment", + } + err = db.Create(giftLog).Error + if err != nil { + return err + } } - err = db.Create(balanceLog).Error + + // Create balance log if regular balance was used + if balanceUsed > 0 { + balanceLog := &user.BalanceLog{ + UserId: u.Id, + Amount: balanceUsed, + Type: 3, // Type 3 represents payment deduction + OrderId: o.Id, + Balance: userInfo.Balance, + } + err = db.Create(balanceLog).Error + if err != nil { + return err + } + } + + // Store gift amount used in order for potential refund tracking + o.GiftAmount = giftUsed + err = l.svcCtx.OrderModel.Update(l.ctx, o) if err != nil { return err } - return db.Model(&order.Order{}).Where("id = ?", o.Id).Updates(map[string]interface{}{ - "status": 2, // 2 means paid - }).Error + + // Mark order as paid (status = 2) + return l.svcCtx.OrderModel.UpdateOrderStatus(l.ctx, o.OrderNo, 2, db) }) + if err != nil { - l.Errorw("[CheckoutOrderLogic] Transaction error", logger.Field("error", err.Error()), logger.Field("orderNo", o.OrderNo)) + l.Errorw("[PurchaseCheckout] Balance payment transaction error", + logger.Field("error", err.Error()), + logger.Field("orderNo", o.OrderNo), + logger.Field("userId", u.Id)) return err } - // create activity order task + + // Enqueue order activation task for immediate processing payload := queueType.ForthwithActivateOrderPayload{ OrderNo: o.OrderNo, } bytes, err := json.Marshal(payload) if err != nil { - l.Errorw("[CheckoutOrderLogic] Marshal error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Marshal activation payload error", logger.Field("error", err.Error())) return err } task := asynq.NewTask(queueType.ForthwithActivateOrder, bytes) _, err = l.svcCtx.Queue.EnqueueContext(l.ctx, task) if err != nil { - l.Errorw("[CheckoutOrderLogic] Enqueue error", logger.Field("error", err.Error())) + l.Errorw("[PurchaseCheckout] Enqueue activation task error", logger.Field("error", err.Error())) return err } - l.Logger.Info("[CheckoutOrderLogic] Enqueue success", logger.Field("orderNo", o.OrderNo)) + + l.Logger.Info("[PurchaseCheckout] Balance payment completed successfully", + logger.Field("orderNo", o.OrderNo), + logger.Field("userId", u.Id)) return nil } diff --git a/internal/logic/public/portal/purchaseLogic.go b/internal/logic/public/portal/purchaseLogic.go index 52fac66..5bc8786 100644 --- a/internal/logic/public/portal/purchaseLogic.go +++ b/internal/logic/public/portal/purchaseLogic.go @@ -93,7 +93,6 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types. } // Calculate the handling fee amount -= couponAmount - var deductionAmount int64 // find payment method paymentConfig, err := l.svcCtx.PaymentModel.FindOne(l.ctx, req.Payment) if err != nil { @@ -118,7 +117,7 @@ func (l *PurchaseLogic) Purchase(req *types.PortalPurchaseRequest) (resp *types. Price: price, Amount: amount, Discount: discountAmount, - GiftAmount: deductionAmount, + GiftAmount: 0, Coupon: req.Coupon, CouponDiscount: couponAmount, PaymentId: req.Payment, diff --git a/internal/logic/public/user/calculateRemainingAmount.go b/internal/logic/public/user/calculateRemainingAmount.go index 7e79016..d601c2e 100644 --- a/internal/logic/public/user/calculateRemainingAmount.go +++ b/internal/logic/public/user/calculateRemainingAmount.go @@ -38,6 +38,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u orderQuantity := orderDetails.Quantity // Calculate Order Amount orderAmount := orderDetails.Amount + orderDetails.GiftAmount + if len(orderDetails.SubOrders) > 0 { for _, subOrder := range orderDetails.SubOrders { if subOrder.Status == 2 || subOrder.Status == 5 { @@ -47,7 +48,7 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u } } // Calculate Remaining Amount - remainingAmount := deduction.CalculateRemainingAmount( + remainingAmount, err := deduction.CalculateRemainingAmount( deduction.Subscribe{ StartTime: userSubscribe.StartTime, ExpireTime: userSubscribe.ExpireTime, @@ -64,5 +65,8 @@ func CalculateRemainingAmount(ctx context.Context, svcCtx *svc.ServiceContext, u Quantity: orderQuantity, }, ) + if err != nil { + return 0, errors.Wrapf(xerr.NewErrCode(500), "CalculateRemainingAmount failed, userSubscribeId: %d, err: %v", userSubscribeId, err) + } return remainingAmount, nil } diff --git a/internal/logic/public/user/unsubscribeLogic.go b/internal/logic/public/user/unsubscribeLogic.go index e23eabc..e2c38b1 100644 --- a/internal/logic/public/user/unsubscribeLogic.go +++ b/internal/logic/public/user/unsubscribeLogic.go @@ -21,7 +21,7 @@ type UnsubscribeLogic struct { svcCtx *svc.ServiceContext } -// NewUnsubscribeLogic Unsubscribe +// NewUnsubscribeLogic creates a new instance of UnsubscribeLogic for handling subscription cancellation func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UnsubscribeLogic { return &UnsubscribeLogic{ Logger: logger.WithContext(ctx), @@ -30,39 +30,90 @@ func NewUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Unsub } } +// Unsubscribe handles the subscription cancellation process with proper refund distribution +// It prioritizes refunding to gift amount for balance-paid orders, then to regular balance func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error { u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) if !ok { logger.Error("current user is not found in context") return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + // Calculate the remaining amount to refund based on unused subscription time/traffic remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id) if err != nil { return err } - // update user subscribe + + // Process unsubscription in a database transaction to ensure data consistency err = l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { + // Find and update subscription status to cancelled (status = 4) var userSub user.Subscribe - if err := db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil { + if err = db.Model(&user.Subscribe{}).Where("id = ?", req.Id).First(&userSub).Error; err != nil { return err } - userSub.Status = 4 - if err := l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil { + userSub.Status = 4 // Set status to cancelled + if err = l.svcCtx.UserModel.UpdateSubscribe(l.ctx, &userSub); err != nil { return err } - balance := remainingAmount + u.Balance - // insert deduction log - balanceLog := user.BalanceLog{ - UserId: userSub.UserId, - OrderId: userSub.OrderId, - Amount: remainingAmount, - Type: 4, - Balance: balance, - } - if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil { + + // Query the original order information to determine refund strategy + orderInfo, err := l.svcCtx.OrderModel.FindOne(l.ctx, userSub.OrderId) + if err != nil { return err } - // update user balance + // Calculate refund distribution based on payment method and gift amount priority + var balance, gift int64 + if orderInfo.Method == "balance" { + // For balance-paid orders, prioritize refunding to gift amount first + if orderInfo.GiftAmount >= remainingAmount { + // Gift amount covers the entire refund - refund all to gift balance + gift = remainingAmount + balance = u.Balance // Regular balance remains unchanged + } else { + // Gift amount insufficient - refund to gift first, remainder to regular balance + gift = orderInfo.GiftAmount + balance = u.Balance + (remainingAmount - orderInfo.GiftAmount) + } + } else { + // For non-balance payment orders, refund entirely to regular balance + balance = remainingAmount + u.Balance + gift = 0 + } + + // Create balance log entry only if there's an actual regular balance refund + balanceRefundAmount := balance - u.Balance + if balanceRefundAmount > 0 { + balanceLog := user.BalanceLog{ + UserId: userSub.UserId, + OrderId: userSub.OrderId, + Amount: balanceRefundAmount, + Type: 4, // Type 4 represents refund transaction + Balance: balance, + } + if err := db.Model(&user.BalanceLog{}).Create(&balanceLog).Error; err != nil { + return err + } + } + + // Create gift amount log entry if there's a gift balance refund + if gift > 0 { + giftLog := user.GiftAmountLog{ + UserId: userSub.UserId, + UserSubscribeId: userSub.Id, + OrderNo: orderInfo.OrderNo, + Type: 1, // Type 1 represents gift amount increase + Amount: gift, + Balance: u.GiftAmount + gift, + Remark: "Unsubscribe refund", + } + if err := db.Model(&user.GiftAmountLog{}).Create(&giftLog).Error; err != nil { + return err + } + // Update user's gift amount + u.GiftAmount += gift + } + + // Update user's regular balance and save changes to database u.Balance = balance return l.svcCtx.UserModel.Update(l.ctx, u) }) diff --git a/pkg/deduction/deduction.go b/pkg/deduction/deduction.go index aa23074..46f32c8 100644 --- a/pkg/deduction/deduction.go +++ b/pkg/deduction/deduction.go @@ -1,90 +1,302 @@ +// Package deduction provides functionality for calculating remaining amounts +// in subscription billing systems, supporting various time units and traffic-based calculations. package deduction import ( + "errors" + "fmt" + "math" "time" "github.com/perfect-panel/server/pkg/tool" ) const ( - UnitTimeNoLimit = "NoLimit" - UnitTimeYear = "Year" - UnitTimeMonth = "Month" - UnitTimeDay = "Day" - UintTimeHour = "Hour" - UintTimeMinute = "Minute" + // Time unit constants for subscription billing + UnitTimeNoLimit = "NoLimit" // Unlimited time subscription + UnitTimeYear = "Year" // Annual subscription + UnitTimeMonth = "Month" // Monthly subscription + UnitTimeDay = "Day" // Daily subscription + UnitTimeHour = "Hour" // Hourly subscription + UnitTimeMinute = "Minute" // Per-minute subscription - ResetCycleNone = 0 - ResetCycle1st = 1 - ResetCycleMonthly = 2 - ResetCycleYear = 3 + // Reset cycle constants for traffic resets + ResetCycleNone = 0 // No reset cycle + ResetCycle1st = 1 // Reset on 1st of each month + ResetCycleMonthly = 2 // Reset monthly based on start date + ResetCycleYear = 3 // Reset yearly based on start date + + // Safety limits for overflow protection + maxInt64 = math.MaxInt64 + minInt64 = math.MinInt64 ) +// Error definitions for validation and calculation failures +var ( + ErrInvalidQuantity = errors.New("order quantity cannot be zero or negative") + ErrInvalidAmount = errors.New("order amount cannot be negative") + ErrInvalidTraffic = errors.New("traffic values cannot be negative") + ErrInvalidTimeRange = errors.New("expire time must be after start time") + ErrInvalidUnitTime = errors.New("invalid unit time") + ErrInvalidDeductionRatio = errors.New("deduction ratio must be between 0 and 100") + ErrOverflow = errors.New("calculation overflow") +) + +// Subscribe represents a subscription with time and traffic limits type Subscribe struct { - StartTime time.Time - ExpireTime time.Time - Traffic int64 - Download int64 - Upload int64 - UnitTime string - UnitPrice int64 - ResetCycle int64 - DeductionRatio int64 + StartTime time.Time // Subscription start time + ExpireTime time.Time // Subscription expiration time + Traffic int64 // Total traffic allowance in bytes + Download int64 // Downloaded traffic in bytes + Upload int64 // Uploaded traffic in bytes + UnitTime string // Time unit for billing (Year, Month, Day, etc.) + UnitPrice int64 // Price per unit time + ResetCycle int64 // Traffic reset cycle + DeductionRatio int64 // Deduction ratio for weighted calculations (0-100) } +// Order represents a purchase order for subscription calculation type Order struct { - Amount int64 - Quantity int64 + Amount int64 // Total order amount + Quantity int64 // Order quantity } -func CalculateRemainingAmount(sub Subscribe, order Order) int64 { - if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 { - return 0 +// Validate checks if the Subscribe struct contains valid data +func (s *Subscribe) Validate() error { + if s.Traffic < 0 || s.Download < 0 || s.Upload < 0 { + return ErrInvalidTraffic } - // 实际单价 - sub.UnitPrice = order.Amount / order.Quantity - now := time.Now() + + if s.Download+s.Upload > s.Traffic { + return fmt.Errorf("download + upload (%d) cannot exceed total traffic (%d)", s.Download+s.Upload, s.Traffic) + } + + if !s.ExpireTime.After(s.StartTime) { + return ErrInvalidTimeRange + } + + if s.DeductionRatio < 0 || s.DeductionRatio > 100 { + return ErrInvalidDeductionRatio + } + + validUnitTimes := []string{UnitTimeNoLimit, UnitTimeYear, UnitTimeMonth, UnitTimeDay, UnitTimeHour, UnitTimeMinute} + valid := false + for _, ut := range validUnitTimes { + if s.UnitTime == ut { + valid = true + break + } + } + if !valid { + return ErrInvalidUnitTime + } + + return nil +} + +// Validate checks if the Order struct contains valid data +func (o *Order) Validate() error { + if o.Quantity <= 0 { + return ErrInvalidQuantity + } + if o.Amount < 0 { + return ErrInvalidAmount + } + return nil +} + +// safeMultiply performs multiplication with overflow protection +func safeMultiply(a, b int64) (int64, error) { + if a == 0 || b == 0 { + return 0, nil + } + + if a > 0 && b > 0 { + if a > maxInt64/b { + return 0, ErrOverflow + } + } else if a < 0 && b < 0 { + if a < maxInt64/b { + return 0, ErrOverflow + } + } else { + if (a > 0 && b < minInt64/a) || (a < 0 && b > minInt64/a) { + return 0, ErrOverflow + } + } + + return a * b, nil +} + +// safeAdd performs addition with overflow protection +func safeAdd(a, b int64) (int64, error) { + if (b > 0 && a > maxInt64-b) || (b < 0 && a < minInt64-b) { + return 0, ErrOverflow + } + return a + b, nil +} + +// safeDivide performs division with zero-division protection +func safeDivide(a, b int64) (int64, error) { + if b == 0 { + return 0, errors.New("division by zero") + } + return a / b, nil +} + +// CalculateRemainingAmount calculates the remaining refund amount for a subscription +// based on unused time and traffic. Returns the amount and any calculation errors. +func CalculateRemainingAmount(sub Subscribe, order Order) (int64, error) { + if err := sub.Validate(); err != nil { + return 0, fmt.Errorf("invalid subscription: %w", err) + } + + if err := order.Validate(); err != nil { + return 0, fmt.Errorf("invalid order: %w", err) + } + + if sub.UnitTime == UnitTimeNoLimit && sub.ResetCycle != 0 { + return 0, nil + } + + unitPrice, err := safeDivide(order.Amount, order.Quantity) + if err != nil { + return 0, fmt.Errorf("failed to calculate unit price: %w", err) + } + sub.UnitPrice = unitPrice + + loc, err := time.LoadLocation(sub.StartTime.Location().String()) + if err != nil { + loc = time.UTC + } + now := time.Now().In(loc) + switch sub.UnitTime { case UnitTimeNoLimit: - usedTraffic := sub.Traffic - sub.Download - sub.Upload - unitPrice := float64(order.Amount) / float64(sub.Traffic) - return int64(float64(usedTraffic) * unitPrice) + return calculateNoLimitAmount(sub, order) case UnitTimeYear: remainingYears := tool.YearDiff(now, sub.ExpireTime) - remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub) - return int64(remainingYears)*sub.UnitPrice + remainingUnitTimeAmount + remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub) + if err != nil { + return 0, err + } + + yearAmount, err := safeMultiply(int64(remainingYears), sub.UnitPrice) + if err != nil { + return 0, fmt.Errorf("year calculation overflow: %w", err) + } + + total, err := safeAdd(yearAmount, remainingUnitTimeAmount) + if err != nil { + return 0, fmt.Errorf("total calculation overflow: %w", err) + } + + return total, nil case UnitTimeMonth: remainingMonths := tool.MonthDiff(now, sub.ExpireTime) - remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub) - return int64(remainingMonths)*sub.UnitPrice + remainingUnitTimeAmount + remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub) + if err != nil { + return 0, err + } + + monthAmount, err := safeMultiply(int64(remainingMonths), sub.UnitPrice) + if err != nil { + return 0, fmt.Errorf("month calculation overflow: %w", err) + } + + total, err := safeAdd(monthAmount, remainingUnitTimeAmount) + if err != nil { + return 0, fmt.Errorf("total calculation overflow: %w", err) + } + + return total, nil + case UnitTimeDay: remainingDays := tool.DayDiff(now, sub.ExpireTime) - remainingUnitTimeAmount := calculateRemainingUnitTimeAmount(sub) - return remainingDays*sub.UnitPrice + remainingUnitTimeAmount + remainingUnitTimeAmount, err := calculateRemainingUnitTimeAmount(sub) + if err != nil { + return 0, err + } + + dayAmount, err := safeMultiply(remainingDays, sub.UnitPrice) + if err != nil { + return 0, fmt.Errorf("day calculation overflow: %w", err) + } + + total, err := safeAdd(dayAmount, remainingUnitTimeAmount) + if err != nil { + return 0, fmt.Errorf("total calculation overflow: %w", err) + } + + return total, nil } - return 0 + return 0, nil } -func calculateRemainingUnitTimeAmount(sub Subscribe) int64 { +// calculateNoLimitAmount calculates refund amount for unlimited time subscriptions +// based on unused traffic only +func calculateNoLimitAmount(sub Subscribe, order Order) (int64, error) { + if sub.Traffic == 0 { + return 0, nil + } + + usedTraffic := sub.Traffic - sub.Download - sub.Upload + if usedTraffic < 0 { + usedTraffic = 0 + } + + unitPrice := float64(order.Amount) / float64(sub.Traffic) + result := float64(usedTraffic) * unitPrice + + if result > float64(maxInt64) || result < float64(minInt64) { + return 0, ErrOverflow + } + + return int64(result), nil +} + +// calculateRemainingUnitTimeAmount calculates the remaining amount based on +// both time and traffic usage, applying deduction ratios when specified +func calculateRemainingUnitTimeAmount(sub Subscribe) (int64, error) { now := time.Now() trafficWeight, timeWeight := calculateWeights(sub.DeductionRatio) remainingDays, totalDays := getRemainingAndTotalDays(sub, now) - remainingTraffic := sub.Traffic - sub.Download - sub.Upload - remainingTimeAmount := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays) - remainingTrafficAmount := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic) - if sub.Traffic == 0 { - return remainingTimeAmount + + if totalDays == 0 { + return 0, nil } + + remainingTraffic := sub.Traffic - sub.Download - sub.Upload + if remainingTraffic < 0 { + remainingTraffic = 0 + } + + remainingTimeAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingDays, totalDays) + if err != nil { + return 0, fmt.Errorf("time amount calculation failed: %w", err) + } + + if sub.Traffic == 0 { + return remainingTimeAmount, nil + } + + remainingTrafficAmount, err := calculateProportionalAmount(sub.UnitPrice, remainingTraffic, sub.Traffic) + if err != nil { + return 0, fmt.Errorf("traffic amount calculation failed: %w", err) + } + if sub.DeductionRatio != 0 { return calculateWeightedAmount(sub.UnitPrice, remainingTraffic, sub.Traffic, remainingDays, totalDays, trafficWeight, timeWeight) } - return min(remainingTimeAmount, remainingTrafficAmount) + return min(remainingTimeAmount, remainingTrafficAmount), nil } +// calculateWeights converts deduction ratio to traffic and time weights +// for weighted calculations func calculateWeights(deductionRatio int64) (float64, float64) { if deductionRatio == 0 { return 0, 0 @@ -94,20 +306,32 @@ func calculateWeights(deductionRatio int64) (float64, float64) { return trafficWeight, timeWeight } +// getRemainingAndTotalDays calculates remaining and total days based on +// the subscription's reset cycle configuration func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) { switch sub.ResetCycle { case ResetCycleNone: - remaining := sub.ExpireTime.Sub(now).Hours() / 24 total := sub.ExpireTime.Sub(sub.StartTime).Hours() / 24 + if remaining < 0 { + remaining = 0 + } + if total < 0 { + total = 0 + } return int64(remaining), int64(total) case ResetCycle1st: return tool.DaysToNextMonth(now), tool.GetLastDayOfMonth(now) case ResetCycleMonthly: - // -1 to include the current day - return tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1, tool.DaysToMonthDay(now, sub.StartTime.Day()) + remaining := tool.DaysToMonthDay(now, sub.StartTime.Day()) - 1 + total := tool.DaysToMonthDay(now, sub.StartTime.Day()) + if remaining < 0 { + remaining = 0 + } + return remaining, total + case ResetCycleYear: return tool.DaysToYearDay(now, int(sub.StartTime.Month()), sub.StartTime.Day()), tool.GetYearDays(now, int(sub.StartTime.Month()), sub.StartTime.Day()) @@ -115,13 +339,36 @@ func getRemainingAndTotalDays(sub Subscribe, now time.Time) (int64, int64) { return 0, 0 } -func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) int64 { +// calculateWeightedAmount applies weighted calculation combining both time and traffic +// remaining ratios based on the specified weights +func calculateWeightedAmount(unitPrice, remainingTraffic, totalTraffic, remainingDays, totalDays int64, trafficWeight, timeWeight float64) (int64, error) { + if totalDays == 0 || totalTraffic == 0 { + return 0, nil + } + remainingTimeRatio := float64(remainingDays) / float64(totalDays) remainingTrafficRatio := float64(remainingTraffic) / float64(totalTraffic) weightedRemainingRatio := (timeWeight * remainingTimeRatio) + (trafficWeight * remainingTrafficRatio) - return int64(float64(unitPrice) * weightedRemainingRatio) + + result := float64(unitPrice) * weightedRemainingRatio + if result > float64(maxInt64) || result < float64(minInt64) { + return 0, ErrOverflow + } + + return int64(result), nil } -func calculateProportionalAmount(unitPrice, remaining, total int64) int64 { - return int64(float64(unitPrice) * (float64(remaining) / float64(total))) +// calculateProportionalAmount calculates proportional amount based on +// remaining vs total ratio with overflow protection +func calculateProportionalAmount(unitPrice, remaining, total int64) (int64, error) { + if total == 0 { + return 0, nil + } + + result := float64(unitPrice) * (float64(remaining) / float64(total)) + if result > float64(maxInt64) || result < float64(minInt64) { + return 0, ErrOverflow + } + + return int64(result), nil } diff --git a/pkg/deduction/deduction_test.go b/pkg/deduction/deduction_test.go new file mode 100644 index 0000000..0e96555 --- /dev/null +++ b/pkg/deduction/deduction_test.go @@ -0,0 +1,665 @@ +package deduction + +import ( + "math" + "testing" + "time" +) + +func TestSubscribe_Validate(t *testing.T) { + tests := []struct { + name string + sub Subscribe + wantErr bool + errType error + }{ + { + name: "valid subscription", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: 100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 50, + }, + wantErr: false, + }, + { + name: "negative traffic", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: -1000, + Download: 100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 50, + }, + wantErr: true, + errType: ErrInvalidTraffic, + }, + { + name: "negative download", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: -100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 50, + }, + wantErr: true, + errType: ErrInvalidTraffic, + }, + { + name: "download + upload exceeds traffic", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: 600, + Upload: 500, + UnitTime: UnitTimeMonth, + DeductionRatio: 50, + }, + wantErr: true, + }, + { + name: "expire time before start time", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(-24 * time.Hour), + Traffic: 1000, + Download: 100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 50, + }, + wantErr: true, + errType: ErrInvalidTimeRange, + }, + { + name: "invalid deduction ratio - negative", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: 100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: -10, + }, + wantErr: true, + errType: ErrInvalidDeductionRatio, + }, + { + name: "invalid deduction ratio - over 100", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: 100, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 150, + }, + wantErr: true, + errType: ErrInvalidDeductionRatio, + }, + { + name: "invalid unit time", + sub: Subscribe{ + StartTime: time.Now(), + ExpireTime: time.Now().Add(24 * time.Hour), + Traffic: 1000, + Download: 100, + Upload: 200, + UnitTime: "InvalidUnit", + DeductionRatio: 50, + }, + wantErr: true, + errType: ErrInvalidUnitTime, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.sub.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Subscribe.Validate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.errType != nil && err != tt.errType { + t.Errorf("Subscribe.Validate() error = %v, want %v", err, tt.errType) + } + }) + } +} + +func TestOrder_Validate(t *testing.T) { + tests := []struct { + name string + order Order + wantErr bool + errType error + }{ + { + name: "valid order", + order: Order{Amount: 1000, Quantity: 2}, + wantErr: false, + }, + { + name: "zero quantity", + order: Order{Amount: 1000, Quantity: 0}, + wantErr: true, + errType: ErrInvalidQuantity, + }, + { + name: "negative quantity", + order: Order{Amount: 1000, Quantity: -1}, + wantErr: true, + errType: ErrInvalidQuantity, + }, + { + name: "negative amount", + order: Order{Amount: -1000, Quantity: 2}, + wantErr: true, + errType: ErrInvalidAmount, + }, + { + name: "zero amount is valid", + order: Order{Amount: 0, Quantity: 1}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.order.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Order.Validate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.errType != nil && err != tt.errType { + t.Errorf("Order.Validate() error = %v, want %v", err, tt.errType) + } + }) + } +} + +func TestSafeMultiply(t *testing.T) { + tests := []struct { + name string + a, b int64 + want int64 + wantErr bool + }{ + { + name: "normal multiplication", + a: 10, + b: 20, + want: 200, + wantErr: false, + }, + { + name: "zero multiplication", + a: 10, + b: 0, + want: 0, + wantErr: false, + }, + { + name: "negative multiplication", + a: -10, + b: 20, + want: -200, + wantErr: false, + }, + { + name: "overflow case", + a: math.MaxInt64, + b: 2, + want: 0, + wantErr: true, + }, + { + name: "large numbers no overflow", + a: 1000000, + b: 1000000, + want: 1000000000000, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safeMultiply(tt.a, tt.b) + if (err != nil) != tt.wantErr { + t.Errorf("safeMultiply() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("safeMultiply() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSafeAdd(t *testing.T) { + tests := []struct { + name string + a, b int64 + want int64 + wantErr bool + }{ + { + name: "normal addition", + a: 10, + b: 20, + want: 30, + wantErr: false, + }, + { + name: "negative addition", + a: -10, + b: 5, + want: -5, + wantErr: false, + }, + { + name: "overflow case", + a: math.MaxInt64, + b: 1, + want: 0, + wantErr: true, + }, + { + name: "underflow case", + a: math.MinInt64, + b: -1, + want: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safeAdd(tt.a, tt.b) + if (err != nil) != tt.wantErr { + t.Errorf("safeAdd() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("safeAdd() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSafeDivide(t *testing.T) { + tests := []struct { + name string + a, b int64 + want int64 + wantErr bool + }{ + { + name: "normal division", + a: 20, + b: 10, + want: 2, + wantErr: false, + }, + { + name: "division by zero", + a: 20, + b: 0, + want: 0, + wantErr: true, + }, + { + name: "negative division", + a: -20, + b: 10, + want: -2, + wantErr: false, + }, + { + name: "zero dividend", + a: 0, + b: 10, + want: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := safeDivide(tt.a, tt.b) + if (err != nil) != tt.wantErr { + t.Errorf("safeDivide() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("safeDivide() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCalculateWeights(t *testing.T) { + tests := []struct { + name string + deductionRatio int64 + wantTrafficWeight float64 + wantTimeWeight float64 + }{ + { + name: "zero ratio", + deductionRatio: 0, + wantTrafficWeight: 0, + wantTimeWeight: 0, + }, + { + name: "50% ratio", + deductionRatio: 50, + wantTrafficWeight: 0.5, + wantTimeWeight: 0.5, + }, + { + name: "75% ratio", + deductionRatio: 75, + wantTrafficWeight: 0.75, + wantTimeWeight: 0.25, + }, + { + name: "100% ratio", + deductionRatio: 100, + wantTrafficWeight: 1.0, + wantTimeWeight: 0.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTrafficWeight, gotTimeWeight := calculateWeights(tt.deductionRatio) + if gotTrafficWeight != tt.wantTrafficWeight { + t.Errorf("calculateWeights() trafficWeight = %v, want %v", gotTrafficWeight, tt.wantTrafficWeight) + } + if gotTimeWeight != tt.wantTimeWeight { + t.Errorf("calculateWeights() timeWeight = %v, want %v", gotTimeWeight, tt.wantTimeWeight) + } + }) + } +} + +func TestCalculateProportionalAmount(t *testing.T) { + tests := []struct { + name string + unitPrice int64 + remaining int64 + total int64 + want int64 + wantErr bool + }{ + { + name: "normal calculation", + unitPrice: 100, + remaining: 50, + total: 100, + want: 50, + wantErr: false, + }, + { + name: "zero total", + unitPrice: 100, + remaining: 50, + total: 0, + want: 0, + wantErr: false, + }, + { + name: "zero remaining", + unitPrice: 100, + remaining: 0, + total: 100, + want: 0, + wantErr: false, + }, + { + name: "quarter remaining", + unitPrice: 200, + remaining: 25, + total: 100, + want: 50, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := calculateProportionalAmount(tt.unitPrice, tt.remaining, tt.total) + if (err != nil) != tt.wantErr { + t.Errorf("calculateProportionalAmount() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("calculateProportionalAmount() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCalculateNoLimitAmount(t *testing.T) { + tests := []struct { + name string + sub Subscribe + order Order + want int64 + wantErr bool + }{ + { + name: "normal no limit calculation", + sub: Subscribe{ + Traffic: 1000, + Download: 300, + Upload: 200, + }, + order: Order{ + Amount: 1000, + }, + want: 500, // (1000 - 300 - 200) / 1000 * 1000 = 500 + wantErr: false, + }, + { + name: "zero traffic", + sub: Subscribe{ + Traffic: 0, + Download: 0, + Upload: 0, + }, + order: Order{ + Amount: 1000, + }, + want: 0, + wantErr: false, + }, + { + name: "overused traffic", + sub: Subscribe{ + Traffic: 1000, + Download: 600, + Upload: 500, + }, + order: Order{ + Amount: 1000, + }, + want: 0, // usedTraffic would be negative, clamped to 0 + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := calculateNoLimitAmount(tt.sub, tt.order) + if (err != nil) != tt.wantErr { + t.Errorf("calculateNoLimitAmount() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("calculateNoLimitAmount() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCalculateRemainingAmount(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + sub Subscribe + order Order + wantErr bool + }{ + { + name: "valid no limit subscription", + sub: Subscribe{ + StartTime: now.Add(-24 * time.Hour), + ExpireTime: now.Add(24 * time.Hour), + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeNoLimit, + ResetCycle: ResetCycleNone, + DeductionRatio: 0, + }, + order: Order{ + Amount: 1000, + Quantity: 1, + }, + wantErr: false, + }, + { + name: "invalid subscription", + sub: Subscribe{ + StartTime: now, + ExpireTime: now.Add(-24 * time.Hour), // Invalid: expire before start + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 0, + }, + order: Order{ + Amount: 1000, + Quantity: 1, + }, + wantErr: true, + }, + { + name: "invalid order", + sub: Subscribe{ + StartTime: now.Add(-24 * time.Hour), + ExpireTime: now.Add(24 * time.Hour), + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeMonth, + DeductionRatio: 0, + }, + order: Order{ + Amount: 1000, + Quantity: 0, // Invalid: zero quantity + }, + wantErr: true, + }, + { + name: "no limit with reset cycle", + sub: Subscribe{ + StartTime: now.Add(-24 * time.Hour), + ExpireTime: now.Add(24 * time.Hour), + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeNoLimit, + ResetCycle: ResetCycleMonthly, // Should return 0 + DeductionRatio: 0, + }, + order: Order{ + Amount: 1000, + Quantity: 1, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CalculateRemainingAmount(tt.sub, tt.order) + if (err != nil) != tt.wantErr { + t.Errorf("CalculateRemainingAmount() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCalculateRemainingAmount_NoLimitWithResetCycle(t *testing.T) { + now := time.Now() + sub := Subscribe{ + StartTime: now.Add(-24 * time.Hour), + ExpireTime: now.Add(24 * time.Hour), + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeNoLimit, + ResetCycle: ResetCycleMonthly, + DeductionRatio: 0, + } + order := Order{ + Amount: 1000, + Quantity: 1, + } + + got, err := CalculateRemainingAmount(sub, order) + if err != nil { + t.Errorf("CalculateRemainingAmount() error = %v", err) + return + } + if got != 0 { + t.Errorf("CalculateRemainingAmount() = %v, want 0", got) + } +} + +// Benchmark tests +func BenchmarkCalculateRemainingAmount(b *testing.B) { + now := time.Now() + sub := Subscribe{ + StartTime: now.Add(-24 * time.Hour), + ExpireTime: now.Add(24 * time.Hour), + Traffic: 1000, + Download: 300, + Upload: 200, + UnitTime: UnitTimeMonth, + ResetCycle: ResetCycleNone, + DeductionRatio: 50, + } + order := Order{ + Amount: 1000, + Quantity: 1, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = CalculateRemainingAmount(sub, order) + } +} + +func BenchmarkSafeMultiply(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = safeMultiply(12345, 67890) + } +}