package group

import (
	"context"
	"errors"
	"time"

	"github.com/perfect-panel/server/internal/model/group"
	"github.com/perfect-panel/server/internal/svc"
	"github.com/perfect-panel/server/internal/types"
	"github.com/perfect-panel/server/pkg/logger"
	"gorm.io/gorm"
)

// UpdateNodeGroupLogic handles partial updates of an existing node group.
type UpdateNodeGroupLogic struct {
	logger.Logger
	ctx    context.Context
	svcCtx *svc.ServiceContext
}

// NewUpdateNodeGroupLogic returns an UpdateNodeGroupLogic bound to ctx, with an
// embedded logger derived from the same context.
func NewUpdateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateNodeGroupLogic {
	return &UpdateNodeGroupLogic{
		Logger: logger.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
}

// UpdateNodeGroup applies the provided fields of req to the node group with
// id req.Id.
//
// Scalar fields are treated as "not provided" when zero-valued (empty Name or
// Description, Sort == 0) and are then left unchanged. Pointer fields
// (ForCalculation, MinTrafficGB, MaxTrafficGB) are applied whenever non-nil.
// The traffic range as it will look after the update (request values merged
// over the stored ones) is validated before anything is written. updated_at
// is always refreshed.
//
// It returns an error if the group does not exist, the resulting traffic
// range is invalid, or a database operation fails.
func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest) error {
	// Bind queries to the request context so cancellation/deadlines propagate.
	// WithContext starts a new gorm session, so db is safe to reuse below.
	db := l.svcCtx.DB.WithContext(l.ctx)

	// Ensure the node group exists.
	var nodeGroup group.NodeGroup
	if err := db.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("node group not found")
		}
		l.Errorf("failed to find node group: %v", err)
		return err
	}

	// Build the set of columns to update.
	updates := map[string]interface{}{
		"updated_at": time.Now(),
	}
	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.Description != "" {
		updates["description"] = req.Description
	}
	if req.Sort != 0 {
		updates["sort"] = req.Sort
	}
	if req.ForCalculation != nil {
		updates["for_calculation"] = *req.ForCalculation
	}

	// Merge the requested traffic bounds over the stored ones so validation
	// sees the range exactly as it will be persisted.
	newMinTraffic := nodeGroup.MinTrafficGB
	newMaxTraffic := nodeGroup.MaxTrafficGB
	if req.MinTrafficGB != nil {
		newMinTraffic = req.MinTrafficGB
		updates["min_traffic_gb"] = *req.MinTrafficGB
	}
	if req.MaxTrafficGB != nil {
		newMaxTraffic = req.MaxTrafficGB
		updates["max_traffic_gb"] = *req.MaxTrafficGB
	}

	if err := l.validateTrafficRange(int(req.Id), newMinTraffic, newMaxTraffic); err != nil {
		return err
	}

	// NOTE(review): validation and this update are not in one transaction, so
	// two concurrent updates could still end up with overlapping ranges.
	if err := db.Model(&nodeGroup).Updates(updates).Error; err != nil {
		l.Errorf("failed to update node group: %v", err)
		return err
	}

	l.Infof("updated node group: id=%d", req.Id)
	return nil
}

// validateTrafficRange checks the traffic range (in GB) that node group
// currentNodeGroupID would have after an update.
//
// nil bounds are treated as 0. A range whose bounds are both 0 means the group
// does not take part in traffic-based grouping, and it is accepted without
// further checks. Otherwise the range must satisfy all of the following:
//   - neither bound is negative;
//   - newMin <= newMax;
//   - it does not overlap the range of any other node group that has a
//     non-zero range.
//
// Ranges are compared as half-open intervals [min, max). Ranges that only
// share an endpoint (e.g. [0,100) and [100,200)) are therefore not considered
// overlapping.
//
// TODO(review): the original comment also required "no gaps" between ranges;
// that contiguity check is not implemented.
func (l *UpdateNodeGroupLogic) validateTrafficRange(currentNodeGroupID int, newMin, newMax *int64) error {
	deref := func(p *int64) int64 {
		if p == nil {
			return 0
		}
		return *p
	}

	minVal, maxVal := deref(newMin), deref(newMax)

	if minVal < 0 || maxVal < 0 {
		return errors.New("traffic range cannot be negative")
	}
	if minVal > maxVal {
		return errors.New("minimum traffic cannot exceed maximum traffic")
	}
	// Both bounds 0: the group does not participate in traffic grouping.
	if minVal == 0 && maxVal == 0 {
		return nil
	}

	// Load every other node group that has a traffic range configured.
	var otherGroups []group.NodeGroup
	if err := l.svcCtx.DB.WithContext(l.ctx).
		Where("id != ?", currentNodeGroupID).
		Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)").
		Find(&otherGroups).Error; err != nil {
		l.Errorf("failed to query other node groups: %v", err)
		return err
	}

	for _, other := range otherGroups {
		otherMin, otherMax := deref(other.MinTrafficGB), deref(other.MaxTrafficGB)
		// Defensive: skip groups whose range is unset (e.g. NULL columns).
		if otherMin == 0 && otherMax == 0 {
			continue
		}
		// Half-open intervals [minVal,maxVal) and [otherMin,otherMax)
		// intersect iff each one starts before the other ends.
		if maxVal > otherMin && minVal < otherMax {
			return errors.New("traffic range overlaps with another node group")
		}
	}
	return nil
}