461 lines
14 KiB
Go
461 lines
14 KiB
Go
package subscribe
|
||
|
||
import (
|
||
"context"
|
||
"strings"
|
||
"time"
|
||
|
||
commonLogic "github.com/perfect-panel/server/internal/logic/common"
|
||
"github.com/perfect-panel/server/internal/model/node"
|
||
"github.com/perfect-panel/server/internal/model/user"
|
||
"github.com/perfect-panel/server/internal/svc"
|
||
"github.com/perfect-panel/server/internal/types"
|
||
"github.com/perfect-panel/server/pkg/constant"
|
||
"github.com/perfect-panel/server/pkg/logger"
|
||
"github.com/perfect-panel/server/pkg/tool"
|
||
"github.com/perfect-panel/server/pkg/xerr"
|
||
"github.com/pkg/errors"
|
||
)
|
||
|
||
// QueryUserSubscribeNodeListLogic handles the "query user subscribe node list"
// request: it resolves the subscriptions of the user stored in the request
// context and the nodes each subscription may use.
type QueryUserSubscribeNodeListLogic struct {
	logger.Logger                     // request-scoped logger (built from ctx by the constructor)
	ctx           context.Context     // request context; also carries the current user
	svcCtx        *svc.ServiceContext // shared service dependencies (DB, models, config)
}
|
||
|
||
// Get user subscribe node info
|
||
func NewQueryUserSubscribeNodeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *QueryUserSubscribeNodeListLogic {
|
||
return &QueryUserSubscribeNodeListLogic{
|
||
Logger: logger.WithContext(ctx),
|
||
ctx: ctx,
|
||
svcCtx: svcCtx,
|
||
}
|
||
}
|
||
|
||
func (l *QueryUserSubscribeNodeListLogic) QueryUserSubscribeNodeList() (resp *types.QueryUserSubscribeNodeListResponse, err error) {
|
||
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
||
if !ok {
|
||
logger.Error("current user is not found in context")
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
||
}
|
||
|
||
entitlement, err := commonLogic.ResolveEntitlementUser(l.ctx, l.svcCtx.DB, u.Id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
userSubscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, entitlement.EffectiveUserID, 0, 1, 2, 3)
|
||
if err != nil {
|
||
logger.Errorw("failed to query user subscribe", logger.Field("error", err.Error()), logger.Field("user_id", u.Id))
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "DB_ERROR")
|
||
}
|
||
|
||
resp = &types.QueryUserSubscribeNodeListResponse{}
|
||
for _, us := range userSubscribes {
|
||
userSubscribe, err := l.getUserSubscribe(us.Token)
|
||
if err != nil {
|
||
l.Errorw("[SubscribeLogic] Get user subscribe failed", logger.Field("error", err.Error()), logger.Field("token", userSubscribe.Token))
|
||
return nil, err
|
||
}
|
||
nodes, err := l.getServers(userSubscribe)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
userSubscribeInfo := types.UserSubscribeInfo{
|
||
Id: userSubscribe.Id,
|
||
Nodes: nodes,
|
||
Traffic: userSubscribe.Traffic,
|
||
Upload: userSubscribe.Upload,
|
||
Download: userSubscribe.Download,
|
||
Token: userSubscribe.Token,
|
||
UserId: userSubscribe.UserId,
|
||
OrderId: userSubscribe.OrderId,
|
||
SubscribeId: userSubscribe.SubscribeId,
|
||
StartTime: userSubscribe.StartTime.Unix(),
|
||
ExpireTime: userSubscribe.ExpireTime.Unix(),
|
||
Status: userSubscribe.Status,
|
||
CreatedAt: userSubscribe.CreatedAt.Unix(),
|
||
UpdatedAt: userSubscribe.UpdatedAt.Unix(),
|
||
}
|
||
|
||
if userSubscribe.FinishedAt != nil {
|
||
userSubscribeInfo.FinishedAt = userSubscribe.FinishedAt.Unix()
|
||
}
|
||
|
||
if l.svcCtx.Config.Register.EnableTrial && l.svcCtx.Config.Register.TrialSubscribe == userSubscribe.SubscribeId {
|
||
userSubscribeInfo.IsTryOut = true
|
||
}
|
||
fillUserSubscribeInfoEntitlementFields(&userSubscribeInfo, entitlement)
|
||
resp.List = append(resp.List, userSubscribeInfo)
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func fillUserSubscribeInfoEntitlementFields(sub *types.UserSubscribeInfo, entitlement *commonLogic.EntitlementContext) {
|
||
if sub == nil || entitlement == nil {
|
||
return
|
||
}
|
||
sub.EntitlementSource = entitlement.Source
|
||
sub.EntitlementOwnerUserId = entitlement.OwnerUserID
|
||
sub.ReadOnly = entitlement.ReadOnly
|
||
}
|
||
|
||
func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (userSubscribeNodes []*types.UserSubscribeNodeInfo, err error) {
|
||
userSubscribeNodes = make([]*types.UserSubscribeNodeInfo, 0)
|
||
if l.isSubscriptionExpired(userSub) {
|
||
return l.createExpiredServers(userSub), nil
|
||
}
|
||
|
||
// Check if group management is enabled
|
||
var groupEnabled string
|
||
err = l.svcCtx.DB.Table("system").
|
||
Where("`category` = ? AND `key` = ?", "group", "enabled").
|
||
Select("value").Scan(&groupEnabled).Error
|
||
|
||
if err != nil {
|
||
l.Debugw("[GetServers] Failed to check group enabled", logger.Field("error", err.Error()))
|
||
// Continue with tag-based filtering
|
||
}
|
||
nodeIds := tool.StringToInt64Slice(subDetails.Nodes)
|
||
tags := normalizeSubscribeNodeTags(subDetails.NodeTags)
|
||
|
||
isGroupEnabled := (groupEnabled == "true" || groupEnabled == "1")
|
||
|
||
enable := true
|
||
|
||
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
|
||
Page: 1,
|
||
Size: 1000,
|
||
NodeId: nodeIds,
|
||
Tag: tags,
|
||
Enabled: &enable, // Only get enabled nodes
|
||
})
|
||
|
||
// Process nodes and create response
|
||
if len(nodes) > 0 {
|
||
var serverMapIds = make(map[int64]*node.Server)
|
||
for _, n := range nodes {
|
||
serverMapIds[n.ServerId] = nil
|
||
}
|
||
var serverIds []int64
|
||
for k := range serverMapIds {
|
||
serverIds = append(serverIds, k)
|
||
}
|
||
|
||
servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds)
|
||
if err != nil {
|
||
l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error()))
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error())
|
||
}
|
||
|
||
for _, s := range servers {
|
||
serverMapIds[s.Id] = s
|
||
}
|
||
|
||
for _, n := range nodes {
|
||
server := serverMapIds[n.ServerId]
|
||
if server == nil {
|
||
continue
|
||
}
|
||
userSubscribeNode := &types.UserSubscribeNodeInfo{
|
||
Id: n.Id,
|
||
Name: n.Name,
|
||
Uuid: userSub.UUID,
|
||
Protocol: n.Protocol,
|
||
Protocols: server.Protocols,
|
||
Port: n.Port,
|
||
Address: n.Address,
|
||
Tags: strings.Split(n.Tags, ","),
|
||
Country: server.Country,
|
||
City: server.City,
|
||
Latitude: server.Latitude,
|
||
Longitude: server.Longitude,
|
||
LongitudeCenter: server.LongitudeCenter,
|
||
LatitudeCenter: server.LatitudeCenter,
|
||
CreatedAt: n.CreatedAt.Unix(),
|
||
}
|
||
userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode)
|
||
}
|
||
}
|
||
|
||
l.Debugf("[Query Subscribe]found servers: %v", len(nodes))
|
||
return userSubscribeNodes, nil
|
||
}
|
||
|
||
// getNodesByGroup gets nodes based on user subscription node_group_id with priority fallback
|
||
func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscribe) ([]*node.Node, error) {
|
||
// 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0]
|
||
nodeGroupId := int64(0)
|
||
source := ""
|
||
var directNodeIds []int64
|
||
|
||
// 优先级1: user_subscribe.node_group_id
|
||
if userSub.NodeGroupId != 0 {
|
||
nodeGroupId = userSub.NodeGroupId
|
||
source = "user_subscribe.node_group_id"
|
||
}
|
||
|
||
// 获取 subscribe 详情(用于获取 node_group_id 和直接分配的节点)
|
||
subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId)
|
||
if err != nil {
|
||
l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error()))
|
||
return nil, err
|
||
}
|
||
|
||
// 获取直接分配的节点ID
|
||
directNodeIds = tool.StringToInt64Slice(subDetails.Nodes)
|
||
l.Debugf("[GetNodesByGroup] direct nodes: %v", directNodeIds)
|
||
|
||
// 如果 user_subscribe 没有 node_group_id,从 subscribe 获取
|
||
if nodeGroupId == 0 {
|
||
// 优先级2: subscribe.node_group_id
|
||
if subDetails.NodeGroupId != 0 {
|
||
nodeGroupId = subDetails.NodeGroupId
|
||
source = "subscribe.node_group_id"
|
||
} else if len(subDetails.NodeGroupIds) > 0 {
|
||
// 优先级3: subscribe.node_group_ids[0]
|
||
nodeGroupId = subDetails.NodeGroupIds[0]
|
||
source = "subscribe.node_group_ids[0]"
|
||
}
|
||
}
|
||
|
||
l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId)
|
||
|
||
// 查询所有启用的节点
|
||
enable := true
|
||
_, allNodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
|
||
Page: 0,
|
||
Size: 10000,
|
||
Enabled: &enable,
|
||
})
|
||
if err != nil {
|
||
l.Errorw("[GetNodesByGroup] FilterNodeList error", logger.Field("error", err.Error()))
|
||
return nil, err
|
||
}
|
||
|
||
// 过滤节点
|
||
var resultNodes []*node.Node
|
||
nodeIdMap := make(map[int64]bool)
|
||
|
||
for _, n := range allNodes {
|
||
// 1. 公共节点(node_group_ids 为空),所有人可见
|
||
if len(n.NodeGroupIds) == 0 {
|
||
if !nodeIdMap[n.Id] {
|
||
resultNodes = append(resultNodes, n)
|
||
nodeIdMap[n.Id] = true
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 2. 如果有节点组,检查节点是否属于该节点组
|
||
if nodeGroupId != 0 {
|
||
for _, gid := range n.NodeGroupIds {
|
||
if gid == nodeGroupId {
|
||
if !nodeIdMap[n.Id] {
|
||
resultNodes = append(resultNodes, n)
|
||
nodeIdMap[n.Id] = true
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 添加直接分配的节点
|
||
if len(directNodeIds) > 0 {
|
||
for _, n := range allNodes {
|
||
if tool.Contains(directNodeIds, n.Id) && !nodeIdMap[n.Id] {
|
||
resultNodes = append(resultNodes, n)
|
||
nodeIdMap[n.Id] = true
|
||
}
|
||
}
|
||
}
|
||
|
||
l.Debugf("[GetNodesByGroup] Found %d nodes (group=%d, direct=%d)", len(resultNodes), nodeGroupId, len(directNodeIds))
|
||
return resultNodes, nil
|
||
}
|
||
|
||
// getNodesByTag gets nodes based on subscribe node_ids and tags
|
||
func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe) ([]*node.Node, error) {
|
||
subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId)
|
||
if err != nil {
|
||
l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error()))
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error())
|
||
}
|
||
|
||
nodeIds := tool.StringToInt64Slice(subDetails.Nodes)
|
||
tags := strings.Split(subDetails.NodeTags, ",")
|
||
newTags := make([]string, 0)
|
||
for _, t := range tags {
|
||
if t != "" {
|
||
newTags = append(newTags, t)
|
||
}
|
||
}
|
||
tags = newTags
|
||
l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags)
|
||
|
||
enable := true
|
||
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
|
||
Page: 0,
|
||
Size: 1000,
|
||
NodeId: nodeIds,
|
||
Tag: tags,
|
||
Enabled: &enable, // Only get enabled nodes
|
||
})
|
||
|
||
return nodes, err
|
||
}
|
||
|
||
// getAllNodes returns all enabled nodes
|
||
func (l *QueryUserSubscribeNodeListLogic) getAllNodes() ([]*node.Node, error) {
|
||
enable := true
|
||
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
|
||
Page: 0,
|
||
Size: 1000,
|
||
Enabled: &enable,
|
||
})
|
||
|
||
return nodes, err
|
||
}
|
||
|
||
func (l *QueryUserSubscribeNodeListLogic) isSubscriptionExpired(userSub *user.Subscribe) bool {
|
||
return userSub.ExpireTime.Unix() < time.Now().Unix() && userSub.ExpireTime.Unix() != 0
|
||
}
|
||
|
||
func (l *QueryUserSubscribeNodeListLogic) createExpiredServers(userSub *user.Subscribe) []*types.UserSubscribeNodeInfo {
|
||
// 1. 查询过期节点组
|
||
var expiredGroup group.NodeGroup
|
||
err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error
|
||
if err != nil {
|
||
l.Debugw("no expired node group configured", logger.Field("error", err))
|
||
return nil
|
||
}
|
||
|
||
// 2. 检查用户是否在过期天数限制内
|
||
expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24)
|
||
if expiredDays > expiredGroup.ExpiredDaysLimit {
|
||
l.Debugf("user subscription expired %d days, exceeds limit %d days", expiredDays, expiredGroup.ExpiredDaysLimit)
|
||
return nil
|
||
}
|
||
|
||
// 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量)
|
||
if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 {
|
||
usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024)
|
||
if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired {
|
||
l.Debugf("user expired traffic %d GB, exceeds expired group limit %d GB", usedTrafficGB, *expiredGroup.MaxTrafficGBExpired)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 4. 查询过期节点组的节点
|
||
enable := true
|
||
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
|
||
Page: 0,
|
||
Size: 1000,
|
||
NodeGroupIds: []int64{expiredGroup.Id},
|
||
Enabled: &enable,
|
||
})
|
||
if err != nil {
|
||
l.Errorw("failed to query expired group nodes", logger.Field("error", err))
|
||
return nil
|
||
}
|
||
|
||
if len(nodes) == 0 {
|
||
l.Debug("no nodes found in expired group")
|
||
return nil
|
||
}
|
||
|
||
// 5. 查询服务器信息
|
||
var serverMapIds = make(map[int64]*node.Server)
|
||
for _, n := range nodes {
|
||
serverMapIds[n.ServerId] = nil
|
||
}
|
||
var serverIds []int64
|
||
for k := range serverMapIds {
|
||
serverIds = append(serverIds, k)
|
||
}
|
||
|
||
servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds)
|
||
if err != nil {
|
||
l.Errorw("failed to query servers", logger.Field("error", err))
|
||
return nil
|
||
}
|
||
|
||
for _, s := range servers {
|
||
serverMapIds[s.Id] = s
|
||
}
|
||
|
||
// 6. 构建节点列表
|
||
userSubscribeNodes := make([]*types.UserSubscribeNodeInfo, 0, len(nodes))
|
||
for _, n := range nodes {
|
||
server := serverMapIds[n.ServerId]
|
||
if server == nil {
|
||
continue
|
||
}
|
||
userSubscribeNode := &types.UserSubscribeNodeInfo{
|
||
Id: n.Id,
|
||
Name: n.Name,
|
||
Uuid: userSub.UUID,
|
||
Protocol: n.Protocol,
|
||
Protocols: server.Protocols,
|
||
Port: n.Port,
|
||
Address: n.Address,
|
||
Tags: strings.Split(n.Tags, ","),
|
||
Country: server.Country,
|
||
City: server.City,
|
||
Latitude: server.Latitude,
|
||
Longitude: server.Longitude,
|
||
LongitudeCenter: server.LongitudeCenter,
|
||
LatitudeCenter: server.LatitudeCenter,
|
||
CreatedAt: n.CreatedAt.Unix(),
|
||
}
|
||
userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode)
|
||
}
|
||
|
||
l.Infof("returned %d nodes from expired group for user %d (expired %d days)", len(userSubscribeNodes), userSub.UserId, expiredDays)
|
||
return userSubscribeNodes
|
||
}
|
||
|
||
func (l *QueryUserSubscribeNodeListLogic) getFirstHostLine() string {
|
||
host := l.svcCtx.Config.Host
|
||
lines := strings.Split(host, "\n")
|
||
if len(lines) > 0 {
|
||
return lines[0]
|
||
}
|
||
return host
|
||
}
|
||
func (l *QueryUserSubscribeNodeListLogic) getUserSubscribe(token string) (*user.Subscribe, error) {
|
||
userSub, err := l.svcCtx.UserModel.FindOneSubscribeByToken(l.ctx, token)
|
||
if err != nil {
|
||
l.Infow("[Generate Subscribe]find subscribe error: %v", logger.Field("error", err.Error()), logger.Field("token", token))
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe error: %v", err.Error())
|
||
}
|
||
|
||
// Ignore expiration check
|
||
//if userSub.Status > 1 {
|
||
// l.Infow("[Generate Subscribe]subscribe is not available", logger.Field("status", int(userSub.Status)), logger.Field("token", token))
|
||
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SubscribeNotAvailable), "subscribe is not available")
|
||
//}
|
||
|
||
return userSub, nil
|
||
}
|
||
|
||
func normalizeSubscribeNodeTags(raw string) []string {
|
||
if raw == "" {
|
||
return nil
|
||
}
|
||
|
||
parts := strings.Split(raw, ",")
|
||
cleaned := make([]string, 0, len(parts))
|
||
for _, tag := range parts {
|
||
trimmed := strings.TrimSpace(tag)
|
||
if trimmed == "" {
|
||
continue
|
||
}
|
||
cleaned = append(cleaned, trimmed)
|
||
}
|
||
|
||
return tool.RemoveDuplicateElements(cleaned...)
|
||
}
|