fix(subscribe): fix user subscription node retrieval logic to support directly assigned nodes

This commit is contained in:
EUForest 2026-03-10 18:29:19 +08:00
parent 884310d951
commit 17163486f6
2 changed files with 405 additions and 252 deletions

View File

@ -70,10 +70,12 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
Id int64 Id int64
NodeGroupId int64 NodeGroupId int64
NodeGroupIds string // JSON string NodeGroupIds string // JSON string
Nodes string // JSON string - 直接分配的节点ID
NodeTags string // 节点标签
} }
var subscribeInfos []SubscribeInfo var subscribeInfos []SubscribeInfo
err = l.svcCtx.DB.Table("subscribe"). err = l.svcCtx.DB.Table("subscribe").
Select("id, node_group_id, node_group_ids"). Select("id, node_group_id, node_group_ids, nodes, node_tags").
Where("id IN ?", subscribeIds). Where("id IN ?", subscribeIds).
Find(&subscribeInfos).Error Find(&subscribeInfos).Error
if err != nil { if err != nil {
@ -124,6 +126,28 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
logger.Infof("[PreviewUserNodes] collected node_group_ids with priority: %v", allNodeGroupIds) logger.Infof("[PreviewUserNodes] collected node_group_ids with priority: %v", allNodeGroupIds)
// 3. 收集所有订阅中直接分配的节点ID
var allDirectNodeIds []int64
for _, subInfo := range subscribeInfos {
if subInfo.Nodes != "" && subInfo.Nodes != "null" {
// nodes 是逗号分隔的字符串,如 "1,2,3"
nodeIdStrs := strings.Split(subInfo.Nodes, ",")
for _, idStr := range nodeIdStrs {
idStr = strings.TrimSpace(idStr)
if idStr != "" {
var nodeId int64
if _, err := fmt.Sscanf(idStr, "%d", &nodeId); err == nil {
allDirectNodeIds = append(allDirectNodeIds, nodeId)
}
}
}
logger.Debugf("[PreviewUserNodes] subscribe_id=%d has direct nodes: %s", subInfo.Id, subInfo.Nodes)
}
}
// 去重
allDirectNodeIds = removeDuplicateInt64(allDirectNodeIds)
logger.Infof("[PreviewUserNodes] collected direct node_ids: %v", allDirectNodeIds)
// 4. 判断分组功能是否启用 // 4. 判断分组功能是否启用
var groupEnabled string var groupEnabled string
l.svcCtx.DB.Table("system"). l.svcCtx.DB.Table("system").
@ -141,8 +165,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
// === 启用分组功能:通过用户订阅的 node_group_id 查询节点 === // === 启用分组功能:通过用户订阅的 node_group_id 查询节点 ===
logger.Infof("[PreviewUserNodes] using group-based node filtering") logger.Infof("[PreviewUserNodes] using group-based node filtering")
if len(allNodeGroupIds) == 0 { if len(allNodeGroupIds) == 0 && len(allDirectNodeIds) == 0 {
logger.Infof("[PreviewUserNodes] no node groups found in user subscribes") logger.Infof("[PreviewUserNodes] no node groups and no direct nodes found in user subscribes")
resp = &types.PreviewUserNodesResponse{ resp = &types.PreviewUserNodesResponse{
UserId: req.UserId, UserId: req.UserId,
NodeGroups: []types.NodeGroupItem{}, NodeGroups: []types.NodeGroupItem{},
@ -150,7 +174,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
return resp, nil return resp, nil
} }
// 5. 查询所有启用的节点 // 5. 查询所有启用的节点(只有当有节点组时才查询)
if len(allNodeGroupIds) > 0 {
var dbNodes []node.Node var dbNodes []node.Node
err = l.svcCtx.DB.Table("nodes"). err = l.svcCtx.DB.Table("nodes").
Where("enabled = ?", true). Where("enabled = ?", true).
@ -180,37 +205,17 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
} }
logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes)) logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes))
}
} else { } else {
// === 未启用分组功能:通过订阅的 node_tags 查询节点 === // === 未启用分组功能:通过订阅的 node_tags 查询节点 ===
logger.Infof("[PreviewUserNodes] using tag-based node filtering") logger.Infof("[PreviewUserNodes] using tag-based node filtering")
// 5. 获取所有订阅的 subscribeId 列表 // 从已查询的 subscribeInfos 中获取 node_tags
subscribeIds := make([]int64, len(userSubscribes))
for i, us := range userSubscribes {
subscribeIds[i] = us.SubscribeId
}
// 6. 查询这些订阅的 node_tags
type SubscribeNodeTags struct {
Id int64
NodeTags string
}
var subscribeNodeTagsList []SubscribeNodeTags
err = l.svcCtx.DB.Table("subscribe").
Where("id IN ?", subscribeIds).
Select("id, node_tags").
Find(&subscribeNodeTagsList).Error
if err != nil {
logger.Errorf("[PreviewUserNodes] failed to get subscribe node tags: %v", err)
return nil, err
}
// 7. 合并所有标签
var allTags []string var allTags []string
for _, snt := range subscribeNodeTagsList { for _, subInfo := range subscribeInfos {
if snt.NodeTags != "" { if subInfo.NodeTags != "" {
tags := strings.Split(snt.NodeTags, ",") tags := strings.Split(subInfo.NodeTags, ",")
allTags = append(allTags, tags...) allTags = append(allTags, tags...)
} }
} }
@ -221,8 +226,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
logger.Infof("[PreviewUserNodes] merged tags from subscribes: %v", allTags) logger.Infof("[PreviewUserNodes] merged tags from subscribes: %v", allTags)
if len(allTags) == 0 { if len(allTags) == 0 && len(allDirectNodeIds) == 0 {
logger.Infof("[PreviewUserNodes] no tags found in subscribes") logger.Infof("[PreviewUserNodes] no tags and no direct nodes found in subscribes")
resp = &types.PreviewUserNodesResponse{ resp = &types.PreviewUserNodesResponse{
UserId: req.UserId, UserId: req.UserId,
NodeGroups: []types.NodeGroupItem{}, NodeGroups: []types.NodeGroupItem{},
@ -230,7 +235,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
return resp, nil return resp, nil
} }
// 8. 查询所有启用的节点 // 8. 查询所有启用的节点(只有当有 tags 时才查询)
if len(allTags) > 0 {
var dbNodes []node.Node var dbNodes []node.Node
err = l.svcCtx.DB.Table("nodes"). err = l.svcCtx.DB.Table("nodes").
Where("enabled = ?", true). Where("enabled = ?", true).
@ -257,8 +263,14 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes)) logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes))
} }
}
// 10. 转换为 types.Node 并按节点组分组 // 10. 根据是否启用分组功能,选择不同的分组方式
nodeGroupItems := make([]types.NodeGroupItem, 0)
if isGroupEnabled {
// === 启用分组:按节点组分组 ===
// 转换为 types.Node 并按节点组分组
type NodeWithGroup struct { type NodeWithGroup struct {
Node node.Node Node node.Node
NodeGroupIds []int64 NodeGroupIds []int64
@ -268,11 +280,11 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
for _, n := range filteredNodes { for _, n := range filteredNodes {
nodesWithGroup = append(nodesWithGroup, NodeWithGroup{ nodesWithGroup = append(nodesWithGroup, NodeWithGroup{
Node: n, Node: n,
NodeGroupIds: []int64(n.NodeGroupIds), NodeGroupIds: n.NodeGroupIds,
}) })
} }
// 11. 按节点组分组节点 // 按节点组分组节点
type NodeGroupMap struct { type NodeGroupMap struct {
Id int64 Id int64
Nodes []types.Node Nodes []types.Node
@ -285,8 +297,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
allGroupIds := make([]int64, 0) allGroupIds := make([]int64, 0)
for _, ng := range nodesWithGroup { for _, ng := range nodesWithGroup {
if len(ng.NodeGroupIds) > 0 { if len(ng.NodeGroupIds) > 0 {
// 如果节点属于节点组,按第一个节点组分组(或者可以按所有节点组) // 如果节点属于节点组,按第一个节点组分组
// 这里使用节点的第一个节点组
firstGroupId := ng.NodeGroupIds[0] firstGroupId := ng.NodeGroupIds[0]
if _, exists := groupMap[firstGroupId]; !exists { if _, exists := groupMap[firstGroupId]; !exists {
groupMap[firstGroupId] = &NodeGroupMap{ groupMap[firstGroupId] = &NodeGroupMap{
@ -349,9 +360,9 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
} }
} }
// 12. 查询节点组信息并构建响应 // 查询节点组信息并构建响应
nodeGroupInfoMap := make(map[int64]string) nodeGroupInfoMap := make(map[int64]string)
validGroupIds := make([]int64, 0) // 存储在数据库中实际存在的节点组ID validGroupIds := make([]int64, 0)
if len(allGroupIds) > 0 { if len(allGroupIds) > 0 {
type NodeGroupInfo struct { type NodeGroupInfo struct {
@ -377,7 +388,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name) logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name)
} }
// 记录无效的节点组ID节点有这个ID但数据库中不存在 // 记录无效的节点组ID
for _, requestedId := range allGroupIds { for _, requestedId := range allGroupIds {
found := false found := false
for _, validId := range validGroupIds { for _, validId := range validGroupIds {
@ -392,9 +403,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
} }
} }
// 13. 构建响应根据有效节点组ID重新分组节点 // 构建响应根据有效节点组ID重新分组节点
nodeGroupItems := make([]types.NodeGroupItem, 0) publicNodes := make([]types.Node, 0)
publicNodes := make([]types.Node, 0) // 公共节点(包括无效节点组和无节点组的节点)
// 遍历所有分组,重新分类节点 // 遍历所有分组,重新分类节点
for groupId, gm := range groupMap { for groupId, gm := range groupMap {
@ -404,7 +414,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
continue continue
} }
// 检查这个节点组ID是否有效(在数据库中存在) // 检查这个节点组ID是否有效
isValid := false isValid := false
for _, validId := range validGroupIds { for _, validId := range validGroupIds {
if groupId == validId { if groupId == validId {
@ -432,7 +442,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
} }
} }
// 最后添加公共节点组(如果有) // 添加公共节点组(如果有)
if len(publicNodes) > 0 { if len(publicNodes) > 0 {
nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{
Id: 0, Id: 0,
@ -442,6 +452,104 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ
logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes)) logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes))
} }
} else {
// === 未启用分组:按 tag 分组 ===
// 按 tag 分组节点
tagGroupMap := make(map[string][]types.Node)
for _, n := range filteredNodes {
tags := []string{}
if n.Tags != "" {
tags = strings.Split(n.Tags, ",")
}
// 转换节点
node := types.Node{
Id: n.Id,
Name: n.Name,
Tags: tags,
Port: n.Port,
Address: n.Address,
ServerId: n.ServerId,
Protocol: n.Protocol,
Enabled: n.Enabled,
Sort: n.Sort,
NodeGroupIds: []int64(n.NodeGroupIds),
CreatedAt: n.CreatedAt.Unix(),
UpdatedAt: n.UpdatedAt.Unix(),
}
// 将节点添加到每个匹配的 tag 分组中
if len(tags) > 0 {
for _, tag := range tags {
tag = strings.TrimSpace(tag)
if tag != "" {
tagGroupMap[tag] = append(tagGroupMap[tag], node)
}
}
} else {
// 没有 tag 的节点放入特殊分组
tagGroupMap[""] = append(tagGroupMap[""], node)
}
}
// 构建响应:按 tag 分组
for tag, nodes := range tagGroupMap {
nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{
Id: 0, // tag 分组使用 ID 0
Name: tag,
Nodes: nodes,
})
logger.Infof("[PreviewUserNodes] adding tag group: tag=%s, nodes=%d", tag, len(nodes))
}
}
// 添加套餐节点组(直接分配的节点)
if len(allDirectNodeIds) > 0 {
// 查询直接分配的节点详情
var directNodes []node.Node
err = l.svcCtx.DB.Table("nodes").
Where("id IN ? AND enabled = ?", allDirectNodeIds, true).
Find(&directNodes).Error
if err != nil {
logger.Errorf("[PreviewUserNodes] failed to get direct nodes: %v", err)
return nil, err
}
if len(directNodes) > 0 {
// 转换为 types.Node
directNodeItems := make([]types.Node, 0, len(directNodes))
for _, n := range directNodes {
tags := []string{}
if n.Tags != "" {
tags = strings.Split(n.Tags, ",")
}
directNodeItems = append(directNodeItems, types.Node{
Id: n.Id,
Name: n.Name,
Tags: tags,
Port: n.Port,
Address: n.Address,
ServerId: n.ServerId,
Protocol: n.Protocol,
Enabled: n.Enabled,
Sort: n.Sort,
NodeGroupIds: []int64(n.NodeGroupIds),
CreatedAt: n.CreatedAt.Unix(),
UpdatedAt: n.UpdatedAt.Unix(),
})
}
// 添加套餐节点组使用特殊ID -1Name 为空字符串,前端根据 ID -1 进行国际化)
nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{
Id: -1,
Name: "", // 空字符串,前端根据 ID -1 识别并国际化
Nodes: directNodeItems,
})
logger.Infof("[PreviewUserNodes] adding subscription nodes group: nodes=%d", len(directNodeItems))
}
}
// 14. 返回结果 // 14. 返回结果
resp = &types.PreviewUserNodesResponse{ resp = &types.PreviewUserNodesResponse{
UserId: req.UserId, UserId: req.UserId,

View File

@ -177,19 +177,27 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscrib
// 按优先级获取 node_group_iduser_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] // 按优先级获取 node_group_iduser_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0]
nodeGroupId := int64(0) nodeGroupId := int64(0)
source := "" source := ""
var directNodeIds []int64
// 优先级1: user_subscribe.node_group_id // 优先级1: user_subscribe.node_group_id
if userSub.NodeGroupId != 0 { if userSub.NodeGroupId != 0 {
nodeGroupId = userSub.NodeGroupId nodeGroupId = userSub.NodeGroupId
source = "user_subscribe.node_group_id" source = "user_subscribe.node_group_id"
} else { }
// 优先级2 & 3: 从 subscribe 表获取
// 获取 subscribe 详情(用于获取 node_group_id 和直接分配的节点)
subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId)
if err != nil { if err != nil {
l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error())) l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error()))
return nil, err return nil, err
} }
// 获取直接分配的节点ID
directNodeIds = tool.StringToInt64Slice(subDetails.Nodes)
l.Debugf("[GetNodesByGroup] direct nodes: %v", directNodeIds)
// 如果 user_subscribe 没有 node_group_id从 subscribe 获取
if nodeGroupId == 0 {
// 优先级2: subscribe.node_group_id // 优先级2: subscribe.node_group_id
if subDetails.NodeGroupId != 0 { if subDetails.NodeGroupId != 0 {
nodeGroupId = subDetails.NodeGroupId nodeGroupId = subDetails.NodeGroupId
@ -201,20 +209,13 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscrib
} }
} }
// 如果所有优先级都没有获取到,返回空节点列表
if nodeGroupId == 0 {
l.Debugw("[GetNodesByGroup] no node_group_id found in any priority, returning no nodes")
return []*node.Node{}, nil
}
l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId) l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId)
// Filter nodes by node_group_id // 查询所有启用的节点
enable := true enable := true
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ _, allNodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
Page: 0, Page: 0,
Size: 1000, Size: 10000,
NodeGroupIds: []int64{nodeGroupId}, // Filter by node_group_ids
Enabled: &enable, Enabled: &enable,
}) })
if err != nil { if err != nil {
@ -222,8 +223,46 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscrib
return nil, err return nil, err
} }
l.Debugf("[GetNodesByGroup] Found %d nodes for node_group_id=%d", len(nodes), nodeGroupId) // 过滤节点
return nodes, nil var resultNodes []*node.Node
nodeIdMap := make(map[int64]bool)
for _, n := range allNodes {
// 1. 公共节点node_group_ids 为空),所有人可见
if len(n.NodeGroupIds) == 0 {
if !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
continue
}
// 2. 如果有节点组,检查节点是否属于该节点组
if nodeGroupId != 0 {
for _, gid := range n.NodeGroupIds {
if gid == nodeGroupId {
if !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
break
}
}
}
}
// 3. 添加直接分配的节点
if len(directNodeIds) > 0 {
for _, n := range allNodes {
if tool.Contains(directNodeIds, n.Id) && !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
}
}
l.Debugf("[GetNodesByGroup] Found %d nodes (group=%d, direct=%d)", len(resultNodes), nodeGroupId, len(directNodeIds))
return resultNodes, nil
} }
// getNodesByTag gets nodes based on subscribe node_ids and tags // getNodesByTag gets nodes based on subscribe node_ids and tags
@ -236,7 +275,13 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe)
nodeIds := tool.StringToInt64Slice(subDetails.Nodes) nodeIds := tool.StringToInt64Slice(subDetails.Nodes)
tags := strings.Split(subDetails.NodeTags, ",") tags := strings.Split(subDetails.NodeTags, ",")
newTags := make([]string, 0)
for _, t := range tags {
if t != "" {
newTags = append(newTags, t)
}
}
tags = newTags
l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags)
enable := true enable := true