From c5b05071878de4c77d944f34ad6fcc614c4fc08f Mon Sep 17 00:00:00 2001 From: Chang lue Tsen Date: Wed, 30 Apr 2025 15:38:01 +0900 Subject: [PATCH] fix(subscribe): enhance server tag handling in subscription logic --- internal/logic/subscribe/subscribeLogic.go | 21 +++++++++++---------- pkg/adapter/adapter.go | 17 ++++++++++++++--- pkg/adapter/uilts.go | 12 ++++++++++++ 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/internal/logic/subscribe/subscribeLogic.go b/internal/logic/subscribe/subscribeLogic.go index e723f99..829de79 100644 --- a/internal/logic/subscribe/subscribeLogic.go +++ b/internal/logic/subscribe/subscribeLogic.go @@ -177,27 +177,28 @@ func (l *SubscribeLogic) getRules() ([]*server.RuleGroup, error) { func (l *SubscribeLogic) buildClientConfig(req *types.SubscribeRequest, userSub *user.Subscribe, servers []*server.Server, rules []*server.RuleGroup) ([]byte, string, error) { tags := make(map[string][]*server.Server) - groups, err := l.svc.ServerModel.QueryAllGroup(l.ctx) + serverTags, err := l.svc.ServerModel.FindServerTags(l.ctx) if err != nil { - l.Errorw("[Generate Subscribe]find group error: %v", logger.Field("error", err.Error())) - return nil, "", errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find group error: %v", err.Error()) + l.Errorw("[Generate Subscribe]find server tags error: %v", logger.Field("error", err.Error())) + return nil, "", errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server tags error: %v", err.Error()) } - for _, group := range groups { - total, servers, err := l.svc.ServerModel.FindServerListByFilter(l.ctx, &server.ServerFilter{ - Tags: []string{group.Name}, - }) + // Deduplicate tags + serverTags = tool.RemoveDuplicateElements(serverTags...) + for _, tag := range serverTags { + s, err := l.svc.ServerModel.FindServersByTag(l.ctx.Request.Context(), tag) if err != nil { + l.Errorw("[Generate Subscribe]find servers by tag error: %v", logger.Field("error", err.Error())) continue } - if total > 0 { - tags[group.Name] = servers + if len(s) > 0 { + tags[tag] = s } } proxyManager := adapter.NewAdapter(&adapter.Config{ Nodes: servers, Rules: rules, - Tags: make(map[string][]*server.Server), + Tags: tags, }) clientType := l.getClientType(req) var resp []byte diff --git a/pkg/adapter/adapter.go b/pkg/adapter/adapter.go index 8feb1fa..daf60ff 100644 --- a/pkg/adapter/adapter.go +++ b/pkg/adapter/adapter.go @@ -12,25 +12,36 @@ import ( "github.com/perfect-panel/server/pkg/adapter/surfboard" ) +type Config struct { + Nodes []*server.Server + Rules []*server.RuleGroup + Tags map[string][]*server.Server +} + type Adapter struct { proxy.Adapter } -func NewAdapter(nodes []*server.Server, rules []*server.RuleGroup) *Adapter { +func NewAdapter(cfg *Config) *Adapter { // 转换服务器列表 - proxies := adapterProxies(nodes) + proxies := adapterProxies(cfg.Nodes) // 生成代理组 proxyGroup, region := generateProxyGroup(proxies) + // 转换规则组 - g, r := adapterRules(rules) + g, r := adapterRules(cfg.Rules) + // 加入兜底节点 for i, group := range g { if len(group.Proxies) == 0 { g[i].Proxies = append([]string{"DIRECT"}, region...) } } + // 合并代理组 proxyGroup = RemoveEmptyGroup(append(proxyGroup, g...)) + // 处理标签 + proxyGroup = adapterTags(cfg.Tags, proxyGroup) return &Adapter{ Adapter: proxy.Adapter{ Proxies: proxies, diff --git a/pkg/adapter/uilts.go b/pkg/adapter/uilts.go index f6dad92..a9858ed 100644 --- a/pkg/adapter/uilts.go +++ b/pkg/adapter/uilts.go @@ -109,6 +109,18 @@ func adapterRules(groups []*server.RuleGroup) (proxyGroup []proxy.Group, rules [ return } +func adapterTags(tags map[string][]*server.Server, group []proxy.Group) (proxyGroup []proxy.Group) { + for tag, servers := range tags { + proxies := adapterProxies(servers) + if len(proxies) != 0 { + for _, p := range proxies { + group = addProxyToGroup(p.Name, tag, group) + } + } + } + return group +} + func generateProxyGroup(servers []proxy.Proxy) (proxyGroup []proxy.Group, region []string) { // 设置手动选择分组 proxyGroup = append(proxyGroup, []proxy.Group{