diff --git a/internal/logic/admin/server/createRuleGroupLogic.go b/internal/logic/admin/server/createRuleGroupLogic.go index 9fb5125..619fbe7 100644 --- a/internal/logic/admin/server/createRuleGroupLogic.go +++ b/internal/logic/admin/server/createRuleGroupLogic.go @@ -53,8 +53,7 @@ func (l *CreateRuleGroupLogic) CreateRuleGroup(req *types.CreateRuleGroupRequest if err != nil { return err } - - err = l.svcCtx.ServerModel.InsertRuleGroup(l.ctx, &server.RuleGroup{ + info := &server.RuleGroup{ Name: req.Name, Icon: req.Icon, Type: req.Type, @@ -62,10 +61,18 @@ func (l *CreateRuleGroupLogic) CreateRuleGroup(req *types.CreateRuleGroupRequest Rules: strings.Join(rs, "\n"), Default: req.Default, Enable: req.Enable, - }) + } + err = l.svcCtx.ServerModel.InsertRuleGroup(l.ctx, info) if err != nil { l.Errorw("[CreateRuleGroup] Insert Database Error: ", logger.Field("error", err.Error())) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create server rule group error: %v", err) } + if req.Default { + if err = l.svcCtx.ServerModel.SetDefaultRuleGroup(l.ctx, info.Id); err != nil { + l.Errorw("[CreateRuleGroup] Set Default Rule Group Error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "set default rule group error: %v", err) + } + } + return nil } diff --git a/internal/logic/admin/server/updateRuleGroupLogic.go b/internal/logic/admin/server/updateRuleGroupLogic.go index 9d36ce8..500ba02 100644 --- a/internal/logic/admin/server/updateRuleGroupLogic.go +++ b/internal/logic/admin/server/updateRuleGroupLogic.go @@ -48,5 +48,11 @@ func (l *UpdateRuleGroupLogic) UpdateRuleGroup(req *types.UpdateRuleGroupRequest if err != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), err.Error()) } + if req.Default { + if err = l.svcCtx.ServerModel.SetDefaultRuleGroup(l.ctx, req.Id); err != nil { + l.Errorf("SetDefaultRuleGroup error: %v", err.Error()) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), err.Error()) + } + } return nil } diff --git a/internal/model/server/model.go b/internal/model/server/model.go index 1af139a..4b51c9c 100644 --- a/internal/model/server/model.go +++ b/internal/model/server/model.go @@ -32,6 +32,8 @@ type customServerLogicModel interface { QueryAllRuleGroup(ctx context.Context) ([]*RuleGroup, error) FindServersByTag(ctx context.Context, tag string) ([]*Server, error) FindServerTags(ctx context.Context) ([]string, error) + + SetDefaultRuleGroup(ctx context.Context, id int64) error } var ( @@ -275,3 +277,16 @@ func (m *customServerModel) FindServersByTag(ctx context.Context, tag string) ([ }) return data, err } + +// SetDefaultRuleGroup sets the default rule group. + +func (m *customServerModel) SetDefaultRuleGroup(ctx context.Context, id int64) error { + return m.ExecCtx(ctx, func(conn *gorm.DB) error { + // Reset all groups to not default + if err := conn.Model(&RuleGroup{}).Update("default", false).Error; err != nil { + return err + } + // Set the specified group as default + return conn.Model(&RuleGroup{}).Where("id = ?", id).Update("default", true).Error + }, cacheServerRuleGroupAllKeys, fmt.Sprintf("cache:serverRuleGroup:%v", id)) +}