hi-server/internal/logic/admin/group/updateNodeGroupLogic.go
EUForest 06a2425474 feat(subscribe): add traffic limit rules and user traffic stats
- Add subscribe traffic_limit schema and migration\n- Support traffic_limit in admin create/update and list/details\n- Apply traffic_limit when building server user list speed limits\n- Add public user traffic stats API
2026-03-14 12:41:52 +08:00

186 lines
5.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package group
import (
"context"
"errors"
"time"
"github.com/perfect-panel/server/internal/model/group"
"github.com/perfect-panel/server/internal/model/subscribe"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/logger"
"gorm.io/gorm"
)
// UpdateNodeGroupLogic handles the admin request that updates an existing
// node group. Instances are created per request by NewUpdateNodeGroupLogic.
type UpdateNodeGroupLogic struct {
	// Logger is the logger bound to the request context.
	logger.Logger
	// ctx is the context of the request being served.
	ctx context.Context
	// svcCtx holds shared service dependencies (e.g. the DB handle used here).
	svcCtx *svc.ServiceContext
}
// NewUpdateNodeGroupLogic returns a handler for one update-node-group request.
// The embedded logger is derived from ctx so log lines carry request context.
func NewUpdateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateNodeGroupLogic {
	l := new(UpdateNodeGroupLogic)
	l.ctx = ctx
	l.svcCtx = svcCtx
	l.Logger = logger.WithContext(ctx)
	return l
}
// UpdateNodeGroup applies a partial update to the node group identified by
// req.Id.
//
// Pointer fields in req are written when non-nil. Name, Description and Sort
// are written only when they hold a non-zero value, so they cannot be cleared
// through this call. Before anything is written it enforces that:
//   - the node group exists ("node group not found" otherwise);
//   - when req marks the group as the expired group, no other group already is
//     one and no subscribe product references this group via node_group_id;
//   - the resulting traffic range passes validateTrafficRange.
//
// Marking a group as expired also forces for_calculation to false.
// All queries run with the request context so cancellation reaches the DB.
func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest) error {
	// A WithContext session is a fresh gorm session and is safe to reuse
	// for several independent queries.
	db := l.svcCtx.DB.WithContext(l.ctx)

	// Make sure the node group exists before validating anything else.
	var nodeGroup group.NodeGroup
	if err := db.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("node group not found")
		}
		l.Errorf("failed to find node group: %v", err)
		return err
	}

	if req.IsExpiredGroup != nil && *req.IsExpiredGroup {
		if err := l.validateExpiredGroupChange(db, req); err != nil {
			return err
		}
	}

	updates, newMinTraffic, newMaxTraffic := l.buildNodeGroupUpdates(req, &nodeGroup)

	// Validate the range the group will have after the update, combining
	// request values with the stored ones.
	if err := l.validateTrafficRange(int(req.Id), newMinTraffic, newMaxTraffic); err != nil {
		return err
	}

	if err := db.Model(&nodeGroup).Updates(updates).Error; err != nil {
		l.Errorf("failed to update node group: %v", err)
		return err
	}

	l.Infof("updated node group: id=%d", req.Id)
	return nil
}

// validateExpiredGroupChange checks the preconditions for turning node group
// req.Id into the expired node group. There can be only one expired group in
// the system, and a group used as node_group_id by a subscribe product cannot
// become the expired group.
func (l *UpdateNodeGroupLogic) validateExpiredGroupChange(db *gorm.DB, req *types.UpdateNodeGroupRequest) error {
	var expiredCount int64
	if err := db.Model(&group.NodeGroup{}).
		Where("is_expired_group = ? AND id != ?", true, req.Id).
		Count(&expiredCount).Error; err != nil {
		l.Errorf("failed to check expired group count: %v", err)
		return err
	}
	if expiredCount > 0 {
		return errors.New("system already has an expired node group, cannot create multiple")
	}

	var subscribeCount int64
	if err := db.Model(&subscribe.Subscribe{}).
		Where("node_group_id = ?", req.Id).
		Count(&subscribeCount).Error; err != nil {
		l.Errorf("failed to check subscribe usage: %v", err)
		return err
	}
	if subscribeCount > 0 {
		return errors.New("this node group is used as default node group in subscription products, cannot set as expired group")
	}
	return nil
}

// buildNodeGroupUpdates turns req into a column->value map for gorm's Updates.
// It also returns the traffic range bounds the group will have after the
// update: the request value when provided, otherwise the stored value.
func (l *UpdateNodeGroupLogic) buildNodeGroupUpdates(req *types.UpdateNodeGroupRequest, current *group.NodeGroup) (map[string]interface{}, *int64, *int64) {
	updates := map[string]interface{}{
		"updated_at": time.Now(),
	}

	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.Description != "" {
		updates["description"] = req.Description
	}
	if req.Sort != 0 {
		updates["sort"] = req.Sort
	}
	if req.ForCalculation != nil {
		updates["for_calculation"] = *req.ForCalculation
	}
	if req.IsExpiredGroup != nil {
		updates["is_expired_group"] = *req.IsExpiredGroup
		// The expired group never takes part in group calculation; this must
		// override any for_calculation value set above.
		if *req.IsExpiredGroup {
			updates["for_calculation"] = false
		}
	}
	if req.ExpiredDaysLimit != nil {
		updates["expired_days_limit"] = *req.ExpiredDaysLimit
	}
	if req.MaxTrafficGBExpired != nil {
		updates["max_traffic_gb_expired"] = *req.MaxTrafficGBExpired
	}
	if req.SpeedLimit != nil {
		updates["speed_limit"] = *req.SpeedLimit
	}

	newMin := current.MinTrafficGB
	newMax := current.MaxTrafficGB
	if req.MinTrafficGB != nil {
		newMin = req.MinTrafficGB
		updates["min_traffic_gb"] = *req.MinTrafficGB
	}
	if req.MaxTrafficGB != nil {
		newMax = req.MaxTrafficGB
		updates["max_traffic_gb"] = *req.MaxTrafficGB
	}
	return updates, newMin, newMax
}
// validateTrafficRange checks the traffic range (in GB) that node group
// currentNodeGroupId will have after an update.
//
// A nil bound is treated as 0. A range whose bounds are both 0 means the group
// does not take part in traffic-based grouping and is accepted as-is.
// Otherwise the range must:
//   - have no negative bound;
//   - satisfy min <= max;
//   - not overlap the range of any other node group that has a non-zero bound.
//
// Two ranges are treated as disjoint when one ends at or before the start of
// the other (newMax <= otherMin || newMin >= otherMax). This means ranges that
// only share an endpoint, such as [0,100] and [100,200], are allowed.
// NOTE: gaps between ranges are NOT detected here.
func (l *UpdateNodeGroupLogic) validateTrafficRange(currentNodeGroupId int, newMin, newMax *int64) error {
	var minVal, maxVal int64
	if newMin != nil {
		minVal = *newMin
	}
	if newMax != nil {
		maxVal = *newMax
	}

	// Negative bounds are meaningless. They would also escape the overlap
	// check, because the query below only considers positive bounds.
	if minVal < 0 || maxVal < 0 {
		return errors.New("traffic range cannot be negative")
	}
	if minVal > maxVal {
		return errors.New("minimum traffic cannot exceed maximum traffic")
	}
	// 0/0 means "not part of traffic grouping": nothing to compare against.
	if minVal == 0 && maxVal == 0 {
		return nil
	}

	// Load every other node group that has a traffic range configured.
	var otherGroups []group.NodeGroup
	if err := l.svcCtx.DB.WithContext(l.ctx).
		Where("id != ?", currentNodeGroupId).
		Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)").
		Find(&otherGroups).Error; err != nil {
		l.Errorf("failed to query other node groups: %v", err)
		return err
	}

	for i := range otherGroups {
		var otherMin, otherMax int64
		if otherGroups[i].MinTrafficGB != nil {
			otherMin = *otherGroups[i].MinTrafficGB
		}
		if otherGroups[i].MaxTrafficGB != nil {
			otherMax = *otherGroups[i].MaxTrafficGB
		}
		// A group without a configured range cannot overlap.
		if otherMin == 0 && otherMax == 0 {
			continue
		}
		disjoint := maxVal <= otherMin || minVal >= otherMax
		if !disjoint {
			return errors.New("traffic range overlaps with another node group")
		}
	}
	return nil
}