server/queue/logic/task/quotaLogic.go

451 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package task
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/hibiken/asynq"
"github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/model/order"
"github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/model/task"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"gorm.io/gorm"
)
// Billing time units compared against a subscription plan's UnitTime value.
// calculateOrderAmount switches on them to turn the plan's UnitPrice into an
// amount for the period between a subscription's start and expiry.
const (
UnitTimeNoLimit = "NoLimit" // No time limit: the unit price is used as-is
UnitTimeYear = "Year" // Priced per year (calculated on a 365-day basis)
UnitTimeMonth = "Month" // Priced per month (calculated on a 30-day basis)
UnitTimeDay = "Day" // Priced per day
UnitTimeHour = "Hour" // Priced per hour
UnitTimeMinute = "Minute" // Priced per minute
)
// QuotaTaskLogic handles queued quota tasks. Its ProcessTask method applies a
// task's quota content (extra days, traffic reset, gift amount) to the user
// subscriptions listed in the task's scope.
type QuotaTaskLogic struct {
svcCtx *svc.ServiceContext // service context; its DB and UserModel are used by the methods below
}
// ErrorInfo records one problem encountered while processing a user
// subscription. The collected entries are JSON-encoded into the task's Errors
// field by processSubscribes.
type ErrorInfo struct {
UserSubscribeId int64 `json:"user_subscribe_id"` // ID of the affected user subscription; 0 when the subscription itself was nil
Error string `json:"error"` // description of what went wrong
}
// NewQuotaTaskLogic returns a QuotaTaskLogic bound to the given service
// context.
func NewQuotaTaskLogic(svcCtx *svc.ServiceContext) *QuotaTaskLogic {
	logic := new(QuotaTaskLogic)
	logic.svcCtx = svcCtx
	return logic
}
// ProcessTask is the asynq handler for a quota task.
//
// It reads the task ID from the payload and loads the task row. When the
// task's status is not 0 it logs that the task was already processed and
// returns nil. Otherwise it decodes the task's scope and content, loads the
// targeted user subscriptions and applies the content to them through
// processSubscribes. Afterwards the affected caches are cleared; cache
// failures are logged only and never fail the task.
func (l *QuotaTaskLogic) ProcessTask(ctx context.Context, t *asynq.Task) error {
	id, err := l.parseTaskID(ctx, t.Payload())
	if err != nil {
		return err
	}

	info, err := l.getTaskInfo(ctx, id)
	if err != nil {
		return err
	}
	if info.Status != 0 {
		logger.WithContext(ctx).Info("[QuotaTaskLogic.ProcessTask] task already processed",
			logger.Field("taskID", id),
			logger.Field("status", info.Status),
		)
		return nil
	}

	scope, content, err := l.parseTaskData(ctx, info)
	if err != nil {
		return err
	}

	subs, err := l.getSubscribes(ctx, scope.Objects)
	if err != nil {
		return err
	}

	if applyErr := l.processSubscribes(ctx, subs, content, info); applyErr != nil {
		return applyErr
	}

	// Clear the user cache, but only when a gift amount was part of the task.
	if content.GiftValue != 0 {
		var ownerIDs []int64
		for i := range subs {
			ownerIDs = append(ownerIDs, subs[i].UserId)
		}
		ownerIDs = tool.RemoveDuplicateElements(ownerIDs...)

		var owners []*user.User
		lookup := l.svcCtx.DB.WithContext(ctx).Model(&user.User{}).Where("id IN ?", ownerIDs).Find(&owners)
		if lookup.Error != nil {
			logger.WithContext(ctx).Error("[QuotaTaskLogic.ProcessTask] find users error",
				logger.Field("error", lookup.Error.Error()),
				logger.Field("userIDs", ownerIDs))
		}

		// The cache clear runs even when the lookup above failed; it then
		// receives whatever Find left in owners.
		if cacheErr := l.svcCtx.UserModel.ClearUserCache(ctx, owners...); cacheErr != nil {
			logger.WithContext(ctx).Error("[QuotaTaskLogic.ProcessTask] clear user cache error",
				logger.Field("error", cacheErr.Error()),
				logger.Field("userIDs", ownerIDs))
		}
	}

	// Clear the user-subscription cache.
	if cacheErr := l.svcCtx.UserModel.ClearSubscribeCache(ctx, subs...); cacheErr != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.ProcessTask] clear subscribe cache error",
			logger.Field("error", cacheErr.Error()))
	}
	return nil
}
// parseTaskID decodes the task ID from an asynq payload, which must hold the
// ID written as a base-10 integer. An empty or non-numeric payload is logged
// and reported as asynq.SkipRetry.
func (l *QuotaTaskLogic) parseTaskID(ctx context.Context, payload []byte) (int64, error) {
	raw := string(payload)
	if raw == "" {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.parseTaskID] empty payload")
		return 0, asynq.SkipRetry
	}

	id, convErr := strconv.ParseInt(raw, 10, 64)
	if convErr == nil {
		return id, nil
	}

	logger.WithContext(ctx).Error("[QuotaTaskLogic.parseTaskID] invalid task ID",
		logger.Field("error", convErr.Error()),
		logger.Field("payload", raw),
	)
	return 0, asynq.SkipRetry
}
// getTaskInfo loads the task row with the given ID. Any lookup error is
// logged and reported as asynq.SkipRetry.
func (l *QuotaTaskLogic) getTaskInfo(ctx context.Context, taskID int64) (*task.Task, error) {
	var row task.Task
	lookup := l.svcCtx.DB.WithContext(ctx).Model(&task.Task{}).Where("id = ?", taskID).First(&row)
	if lookup.Error == nil {
		return &row, nil
	}

	logger.WithContext(ctx).Error("[QuotaTaskLogic.getTaskInfo] find task error",
		logger.Field("error", lookup.Error.Error()),
		logger.Field("taskID", taskID),
	)
	return nil, asynq.SkipRetry
}
// parseTaskData decodes the Scope and Content fields of taskInfo into their
// typed forms. A decoding failure of either field is logged and reported as
// asynq.SkipRetry.
func (l *QuotaTaskLogic) parseTaskData(ctx context.Context, taskInfo *task.Task) (task.QuotaScope, task.QuotaContent, error) {
	var (
		scope   task.QuotaScope
		content task.QuotaContent
	)

	if scopeErr := scope.Unmarshal([]byte(taskInfo.Scope)); scopeErr != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.parseTaskData] unmarshal scope error",
			logger.Field("error", scopeErr.Error()),
		)
		return scope, task.QuotaContent{}, asynq.SkipRetry
	}

	contentErr := content.Unmarshal([]byte(taskInfo.Content))
	if contentErr == nil {
		return scope, content, nil
	}

	logger.WithContext(ctx).Error("[QuotaTaskLogic.parseTaskData] unmarshal content error",
		logger.Field("error", contentErr.Error()),
	)
	return scope, content, asynq.SkipRetry
}
// getSubscribes loads the user subscriptions whose IDs are listed in
// subscriberIDs. A query error is logged and reported as asynq.SkipRetry.
func (l *QuotaTaskLogic) getSubscribes(ctx context.Context, subscriberIDs []int64) ([]*user.Subscribe, error) {
	var found []*user.Subscribe

	query := l.svcCtx.DB.WithContext(ctx).Model(&user.Subscribe{}).Where("id IN ?", subscriberIDs)
	if findErr := query.Find(&found).Error; findErr != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.getSubscribes] find subscribes error",
			logger.Field("error", findErr.Error()),
			logger.Field("subscribers", subscriberIDs),
		)
		return nil, asynq.SkipRetry
	}
	return found, nil
}
// processSubscribes applies content to every subscription inside one database
// transaction and then stores the outcome on the task row.
//
// Problems reported for individual subscriptions are collected, JSON-encoded
// into taskInfo.Errors, and do not abort the run. The task status is set to 2
// (completed) unless every subscription reported at least one problem, in
// which case it is set to 3 (failed). taskInfo.Current is set to the number of
// subscriptions handled.
//
// A non-nil error is returned, and the transaction rolled back, when the
// transaction cannot be started, processSubscription returns an error, the
// collected errors cannot be encoded, the task row cannot be saved, or a panic
// occurs while the transaction is open. A commit failure is returned as well.
func (l *QuotaTaskLogic) processSubscribes(ctx context.Context, subscribes []*user.Subscribe, content task.QuotaContent, taskInfo *task.Task) (err error) {
	tx := l.svcCtx.DB.WithContext(ctx).Begin()
	if tx.Error != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] begin transaction error",
			logger.Field("error", tx.Error.Error()),
		)
		return tx.Error
	}
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] transaction panic",
				logger.Field("panic", r),
			)
			// Surface the panic through the named result; otherwise the
			// caller would see a nil error and treat the task as successful.
			err = fmt.Errorf("transaction panic: %v", r)
		}
	}()

	var errInfos []ErrorInfo
	failed := 0 // subscriptions that reported at least one problem
	now := time.Now()
	for _, sub := range subscribes {
		before := len(errInfos)
		if err = l.processSubscription(tx, sub, content, now, &errInfos); err != nil {
			tx.Rollback()
			return err
		}
		// One subscription may append several entries, so count
		// subscriptions rather than entries.
		if len(errInfos) > before {
			failed++
		}
	}

	// Decide the task status from the collected problems.
	status := int8(2) // Completed
	if failed > 0 {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] some subscriptions failed",
			logger.Field("total", len(subscribes)),
			logger.Field("failed", failed),
		)
		// Mark the task as failed only when every subscription had a problem.
		if failed == len(subscribes) {
			status = 3 // Failed
		}
		encoded, marshalErr := json.Marshal(errInfos)
		if marshalErr != nil {
			logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] marshal errors failed",
				logger.Field("error", marshalErr.Error()),
			)
			tx.Rollback()
			return marshalErr
		}
		taskInfo.Errors = string(encoded)
	}

	taskInfo.Current = uint64(len(subscribes))
	taskInfo.Status = status
	if err = tx.Where("id = ?", taskInfo.Id).Save(taskInfo).Error; err != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] update task status error",
			logger.Field("error", err.Error()),
			logger.Field("taskID", taskInfo.Id),
		)
		tx.Rollback()
		return err
	}

	if err = tx.Commit().Error; err != nil {
		logger.WithContext(ctx).Error("[QuotaTaskLogic.processSubscribes] commit transaction error",
			logger.Field("error", err.Error()),
		)
		return err
	}
	return nil
}
// processSubscription applies content to one user subscription using tx.
//
// In order it extends the expiry (when content.Days is non-zero), resets the
// traffic counters (when content.ResetTraffic is set), grants the gift (when
// content.GiftValue is non-zero), and saves the subscription if the first two
// steps changed it.
//
// Per-subscription problems (nil subscription, reset-log failure, save
// failure) are appended to errors and nil is returned so that processing can
// continue. Only an error returned by processGift is propagated.
func (l *QuotaTaskLogic) processSubscription(tx *gorm.DB, sub *user.Subscribe, content task.QuotaContent, now time.Time, errors *[]ErrorInfo) error {
	// Guard against a missing subscription.
	if sub == nil {
		*errors = append(*errors, ErrorInfo{Error: "subscription is nil"})
		return nil
	}

	dirty := false

	// Extend the expiry whenever Days is non-zero, regardless of whether
	// ExpireTime is unset.
	if content.Days != 0 {
		// Start from now when there is no expiry (Unix epoch) or it already
		// passed; otherwise extend the existing expiry.
		base := sub.ExpireTime
		if base.Unix() == 0 || base.Before(now) {
			base = now
		}
		sub.ExpireTime = base.AddDate(0, 0, int(content.Days))

		// A subscription that now expires in the future becomes active.
		if sub.Status != 1 && sub.ExpireTime.After(now) {
			sub.Status = 1 // Active
		}
		dirty = true
	}

	// Reset the traffic counters.
	if content.ResetTraffic {
		sub.Download, sub.Upload = 0, 0
		dirty = true
		// A failed log write is recorded but does not stop the main flow.
		if logErr := l.createResetTrafficLog(tx, sub.Id, sub.UserId, now); logErr != nil {
			*errors = append(*errors, ErrorInfo{
				UserSubscribeId: sub.Id,
				Error:           "create reset traffic log error: " + logErr.Error(),
			})
		}
	}

	// Grant the gift amount.
	if content.GiftValue != 0 {
		if giftErr := l.processGift(tx, sub, content, now, errors); giftErr != nil {
			return giftErr
		}
	}

	// Persist the subscription only when something changed.
	if !dirty {
		return nil
	}
	if saveErr := tx.Where("id = ?", sub.Id).Save(sub).Error; saveErr != nil {
		*errors = append(*errors, ErrorInfo{
			UserSubscribeId: sub.Id,
			Error:           "update subscription error: " + saveErr.Error(),
		})
	}
	return nil
}
// processGift credits the gift described by content to the owner of sub.
//
// GiftType 1 grants content.GiftValue as a fixed amount. GiftType 2 grants
// content.GiftValue percent of the amount returned by calculateOrderAmount.
// Nothing is written when the resulting amount is not positive.
//
// Every problem (unknown gift type, user lookup, amount calculation, balance
// update, gift log) is appended to errors as a single entry for sub, and nil
// is returned so the caller can continue with other subscriptions. The
// function currently never returns a non-nil error.
func (l *QuotaTaskLogic) processGift(tx *gorm.DB, sub *user.Subscribe, content task.QuotaContent, now time.Time, errors *[]ErrorInfo) error {
	record := func(msg string) {
		*errors = append(*errors, ErrorInfo{UserSubscribeId: sub.Id, Error: msg})
	}

	// Validate the gift type.
	if content.GiftType != 1 && content.GiftType != 2 {
		record(fmt.Sprintf("invalid gift type: %d", content.GiftType))
		return nil
	}

	var userInfo user.User
	if err := tx.Model(&user.User{}).Where("id = ?", sub.UserId).First(&userInfo).Error; err != nil {
		record("find user error: " + err.Error())
		return nil
	}

	var giftAmount int64
	switch content.GiftType {
	case 1:
		giftAmount = int64(content.GiftValue)
	case 2:
		orderAmount, err := l.calculateOrderAmount(tx, sub, now)
		if err != nil {
			record(err.Error())
			return nil
		}
		if orderAmount > 0 {
			// Integer arithmetic: going through float64 rounds some exact
			// results down (e.g. 29% of 100 became 28).
			giftAmount = orderAmount * int64(content.GiftValue) / 100
		}
	}
	if giftAmount <= 0 {
		return nil
	}

	// Increment in SQL instead of writing back a previously read value, so a
	// concurrent change to gift_amount is not overwritten.
	if err := tx.Model(&user.User{}).Where("id = ?", sub.UserId).
		Update("gift_amount", gorm.Expr("gift_amount + ?", giftAmount)).Error; err != nil {
		record("update user gift amount error: " + err.Error())
		return nil
	}

	// Balance reported in the log: the value read above plus this gift.
	balance := userInfo.GiftAmount + giftAmount
	if err := l.createGiftLog(tx, sub.Id, userInfo.Id, giftAmount, balance, now); err != nil {
		msg := "create gift log error: " + err.Error()
		// Take the credited amount back, since the gift could not be logged.
		if undoErr := tx.Model(&user.User{}).Where("id = ?", sub.UserId).
			Update("gift_amount", gorm.Expr("gift_amount - ?", giftAmount)).Error; undoErr != nil {
			msg += "; revert gift amount error: " + undoErr.Error()
		}
		record(msg)
		return nil
	}
	return nil
}
// getStartTime returns the subscription's StartTime, falling back to now when
// StartTime is the Unix epoch (Unix() == 0).
func (l *QuotaTaskLogic) getStartTime(sub *user.Subscribe, now time.Time) time.Time {
	if started := sub.StartTime; started.Unix() != 0 {
		return started
	}
	return now
}
// calculateOrderAmount returns the amount a percentage gift is based on.
//
// When the subscription references an order (OrderId != 0) the result is that
// order's Amount plus GiftAmount. Otherwise it is derived from the plan's
// UnitPrice and the time between the start time (see getStartTime) and the
// subscription's ExpireTime, according to the plan's UnitTime:
//
//   - Year / Month: UnitPrice scaled by elapsed days over 365 / 30
//   - Day / Hour / Minute: UnitPrice times the elapsed days / hours / minutes
//   - NoLimit, or any unknown unit: UnitPrice
//
// If ExpireTime lies before the start time, UnitPrice is returned unchanged.
// Lookup failures are returned wrapped, with the same message text as before.
func (l *QuotaTaskLogic) calculateOrderAmount(tx *gorm.DB, sub *user.Subscribe, now time.Time) (int64, error) {
	if sub.OrderId != 0 {
		var orderInfo *order.Order
		if err := tx.Model(&order.Order{}).Where("id = ?", sub.OrderId).First(&orderInfo).Error; err != nil {
			return 0, fmt.Errorf("find order error: %w", err)
		}
		return orderInfo.Amount + orderInfo.GiftAmount, nil
	}

	var subInfo *subscribe.Subscribe
	if err := tx.Model(&subscribe.Subscribe{}).Where("id = ?", sub.SubscribeId).First(&subInfo).Error; err != nil {
		return 0, fmt.Errorf("find subscribe error: %w", err)
	}

	startTime := l.getStartTime(sub, now)
	if sub.ExpireTime.Before(startTime) {
		return subInfo.UnitPrice, nil
	}

	switch subInfo.UnitTime {
	case UnitTimeNoLimit:
		return subInfo.UnitPrice, nil
	case UnitTimeYear:
		// Multiply before dividing: dividing first truncates the daily price
		// (e.g. 1000/365 == 2) and the error is then multiplied by the days.
		days := tool.DayDiff(startTime, sub.ExpireTime)
		return subInfo.UnitPrice * days / 365, nil
	case UnitTimeMonth:
		days := tool.DayDiff(startTime, sub.ExpireTime)
		return subInfo.UnitPrice * days / 30, nil
	case UnitTimeDay:
		days := tool.DayDiff(startTime, sub.ExpireTime)
		return subInfo.UnitPrice * days, nil
	case UnitTimeHour:
		hours := int64(tool.HourDiff(startTime, sub.ExpireTime))
		return subInfo.UnitPrice * hours, nil
	case UnitTimeMinute:
		// NOTE(review): minutes are derived from tool.HourDiff, so a partial
		// hour presumably contributes nothing; confirm this is intended.
		minutes := tool.HourDiff(startTime, sub.ExpireTime) * 60
		return subInfo.UnitPrice * minutes, nil
	default:
		return subInfo.UnitPrice, nil
	}
}
// createGiftLog inserts a system-log row of type log.TypeGift through tx.
//
// The row's content is the marshalled log.Gift record (increase type, empty
// order number, the given subscription ID, amount and balance, remark
// "Quota task gift", timestamp in milliseconds). The row is keyed by userId
// and dated with now. It returns the marshal error or the insert error.
func (l *QuotaTaskLogic) createGiftLog(tx *gorm.DB, subscribeId, userId, amount, balance int64, now time.Time) error {
	record := log.Gift{
		Type:        log.GiftTypeIncrease,
		OrderNo:     "",
		SubscribeId: subscribeId,
		Amount:      amount,
		Balance:     balance,
		Remark:      "Quota task gift",
		Timestamp:   now.UnixMilli(),
	}
	payload, err := record.Marshal()
	if err != nil {
		return fmt.Errorf("marshal gift log error: %v", err)
	}

	entry := log.SystemLog{
		Type:     log.TypeGift.Uint8(),
		Content:  string(payload),
		ObjectID: userId,
		Date:     now.Format(time.DateOnly),
	}
	return tx.Model(&log.SystemLog{}).Create(&entry).Error
}
// createResetTrafficLog inserts a system-log row of type
// log.TypeResetSubscribe through tx.
//
// The row's content is the marshalled log.ResetSubscribe record (quota type,
// the given user ID, empty order number, timestamp in milliseconds). The row
// is keyed by subscribeId and dated with now. It returns the marshal error or
// the insert error.
func (l *QuotaTaskLogic) createResetTrafficLog(tx *gorm.DB, subscribeId, userId int64, now time.Time) error {
	record := log.ResetSubscribe{
		Type:      log.ResetSubscribeTypeQuota,
		UserId:    userId,
		OrderNo:   "",
		Timestamp: now.UnixMilli(),
	}
	payload, err := record.Marshal()
	if err != nil {
		return fmt.Errorf("marshal traffic log error: %v", err)
	}

	entry := log.SystemLog{
		Type:     log.TypeResetSubscribe.Uint8(),
		Content:  string(payload),
		ObjectID: subscribeId,
		Date:     now.Format(time.DateOnly),
	}
	return tx.Model(&log.SystemLog{}).Create(&entry).Error
}