diff --git a/internal/logic/server/serverPushUserTrafficLogic.go b/internal/logic/server/serverPushUserTrafficLogic.go index d94e827..c6ab4e6 100644 --- a/internal/logic/server/serverPushUserTrafficLogic.go +++ b/internal/logic/server/serverPushUserTrafficLogic.go @@ -41,6 +41,7 @@ func (l *ServerPushUserTrafficLogic) ServerPushUserTraffic(req *types.ServerPush // Create traffic task var request task.TrafficStatistics request.ServerId = serverInfo.Id + request.Protocol = req.Protocol tool.DeepCopy(&request.Logs, req.Traffic) // Push traffic task diff --git a/queue/logic/traffic/trafficStatisticsLogic.go b/queue/logic/traffic/trafficStatisticsLogic.go index 5be16b2..ed98cd1 100644 --- a/queue/logic/traffic/trafficStatisticsLogic.go +++ b/queue/logic/traffic/trafficStatisticsLogic.go @@ -3,8 +3,10 @@ package traffic import ( "context" "encoding/json" + "strings" "time" + "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/pkg/logger" "github.com/hibiken/asynq" @@ -46,29 +48,38 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta ) return nil } - var serverRatio float32 = 1.0 - if serverInfo.Ratio > 0 { - serverRatio = serverInfo.Ratio + // query protocol ratio + // default ratio is 1.0 + + protocols, err := serverInfo.UnmarshalProtocols() + if err != nil { + logger.Errorf("[TrafficStatistics] Unmarshal protocols failed: %s", err.Error()) + return nil + } + var protocol *node.Protocol + + var ratio float32 = 1.0 + + for _, p := range protocols { + if strings.ToLower(p.Type) == strings.ToLower(payload.Protocol) { + protocol = &p + break + } + } + + if protocol == nil { + logger.WithContext(ctx).Error("[TrafficStatistics] Protocol not found: %s", payload.Protocol) + return nil + } + + // use protocol ratio if it's greater than 0 + if protocol.Ratio > 0 { + ratio = float32(protocol.Ratio) } now := time.Now() realTimeMultiplier := l.svc.NodeMultiplierManager.GetMultiplier(now) for _, log := range payload.Logs { - if log.Upload == 0 && log.Download == 0 { - continue - } - // update user subscribe with log - d := int64(float32(log.Download) * serverRatio * realTimeMultiplier) - u := int64(float32(log.Upload) * serverRatio * realTimeMultiplier) - if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, log.SID, d, u); err != nil { - logger.WithContext(ctx).Error("[TrafficStatistics] Update user subscribe with log failed", - logger.Field("sid", log.SID), - logger.Field("download", float32(log.Download)*serverRatio), - logger.Field("upload", float32(log.Upload)*serverRatio), - logger.Field("error", err.Error()), - ) - continue - } // query user Subscribe Info sub, err := l.svc.UserModel.FindOneSubscribe(ctx, log.SID) if err != nil { @@ -79,8 +90,25 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta continue } + if log.Download+log.Upload <= l.svc.Config.Node.TrafficReportThreshold { + // no traffic, skip + continue + } + // update user subscribe with log + d := int64(float32(log.Download) * ratio * realTimeMultiplier) + u := int64(float32(log.Upload) * ratio * realTimeMultiplier) + if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, sub.Id, d, u); err != nil { + logger.WithContext(ctx).Error("[TrafficStatistics] Update user subscribe with log failed", + logger.Field("sid", log.SID), + logger.Field("download", float32(log.Download)*ratio), + logger.Field("upload", float32(log.Upload)*ratio), + logger.Field("error", err.Error()), + ) + continue + } + // create log log - if err := l.svc.TrafficLogModel.Insert(ctx, &traffic.TrafficLog{ + if err = l.svc.TrafficLogModel.Insert(ctx, &traffic.TrafficLog{ ServerId: payload.ServerId, SubscribeId: log.SID, UserId: sub.UserId, @@ -90,8 +118,8 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta }); err != nil { logger.WithContext(ctx).Error("[TrafficStatistics] Create log log failed", logger.Field("uid", log.SID), - logger.Field("download", float32(log.Download)*serverRatio), - logger.Field("upload", float32(log.Upload)*serverRatio), + logger.Field("download", float32(log.Download)*ratio), + logger.Field("upload", float32(log.Upload)*ratio), logger.Field("error", err.Error()), ) } diff --git a/queue/types/server.go b/queue/types/server.go index 75ec89c..3202ed1 100644 --- a/queue/types/server.go +++ b/queue/types/server.go @@ -10,6 +10,7 @@ type UserTraffic struct { type TrafficStatistics struct { ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` Logs []UserTraffic `json:"logs"` }