hi-server/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go
EUForest 06a2425474 feat(subscribe): add traffic limit rules and user traffic stats
- Add subscribe traffic_limit schema and migration\n- Support traffic_limit in admin create/update and list/details\n- Apply traffic_limit when building server user list speed limits\n- Add public user traffic stats API
2026-03-14 12:41:52 +08:00

433 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package subscribe
import (
"context"
"strings"
"time"
"github.com/perfect-panel/server/internal/model/group"
"github.com/perfect-panel/server/internal/model/node"
"github.com/perfect-panel/server/internal/model/user"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/constant"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
)
type QueryUserSubscribeNodeListLogic struct {
logger.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
// Get user subscribe node info
func NewQueryUserSubscribeNodeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *QueryUserSubscribeNodeListLogic {
return &QueryUserSubscribeNodeListLogic{
Logger: logger.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *QueryUserSubscribeNodeListLogic) QueryUserSubscribeNodeList() (resp *types.QueryUserSubscribeNodeListResponse, err error) {
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
if !ok {
logger.Error("current user is not found in context")
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
}
userSubscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, u.Id, 0, 1, 2, 3)
if err != nil {
logger.Errorw("failed to query user subscribe", logger.Field("error", err.Error()), logger.Field("user_id", u.Id))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "DB_ERROR")
}
resp = &types.QueryUserSubscribeNodeListResponse{}
for _, us := range userSubscribes {
userSubscribe, err := l.getUserSubscribe(us.Token)
if err != nil {
l.Errorw("[SubscribeLogic] Get user subscribe failed", logger.Field("error", err.Error()), logger.Field("token", userSubscribe.Token))
return nil, err
}
nodes, err := l.getServers(userSubscribe)
if err != nil {
return nil, err
}
userSubscribeInfo := types.UserSubscribeInfo{
Id: userSubscribe.Id,
Nodes: nodes,
Traffic: userSubscribe.Traffic,
Upload: userSubscribe.Upload,
Download: userSubscribe.Download,
Token: userSubscribe.Token,
UserId: userSubscribe.UserId,
OrderId: userSubscribe.OrderId,
SubscribeId: userSubscribe.SubscribeId,
StartTime: userSubscribe.StartTime.Unix(),
ExpireTime: userSubscribe.ExpireTime.Unix(),
Status: userSubscribe.Status,
CreatedAt: userSubscribe.CreatedAt.Unix(),
UpdatedAt: userSubscribe.UpdatedAt.Unix(),
}
if userSubscribe.FinishedAt != nil {
userSubscribeInfo.FinishedAt = userSubscribe.FinishedAt.Unix()
}
if l.svcCtx.Config.Register.EnableTrial && l.svcCtx.Config.Register.TrialSubscribe == userSubscribe.SubscribeId {
userSubscribeInfo.IsTryOut = true
}
resp.List = append(resp.List, userSubscribeInfo)
}
return
}
func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (userSubscribeNodes []*types.UserSubscribeNodeInfo, err error) {
userSubscribeNodes = make([]*types.UserSubscribeNodeInfo, 0)
if l.isSubscriptionExpired(userSub) {
return l.createExpiredServers(userSub), nil
}
// Check if group management is enabled
var groupEnabled string
err = l.svcCtx.DB.Table("system").
Where("`category` = ? AND `key` = ?", "group", "enabled").
Select("value").Scan(&groupEnabled).Error
if err != nil {
l.Debugw("[GetServers] Failed to check group enabled", logger.Field("error", err.Error()))
// Continue with tag-based filtering
}
isGroupEnabled := (groupEnabled == "true" || groupEnabled == "1")
var nodes []*node.Node
if isGroupEnabled {
// Group mode: use group_ids to filter nodes
nodes, err = l.getNodesByGroup(userSub)
if err != nil {
l.Errorw("[GetServers] Failed to get nodes by group", logger.Field("error", err.Error()))
return nil, err
}
} else {
// Tag mode: use node_ids and tags to filter nodes
nodes, err = l.getNodesByTag(userSub)
if err != nil {
l.Errorw("[GetServers] Failed to get nodes by tag", logger.Field("error", err.Error()))
return nil, err
}
}
// Process nodes and create response
if len(nodes) > 0 {
var serverMapIds = make(map[int64]*node.Server)
for _, n := range nodes {
serverMapIds[n.ServerId] = nil
}
var serverIds []int64
for k := range serverMapIds {
serverIds = append(serverIds, k)
}
servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds)
if err != nil {
l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error())
}
for _, s := range servers {
serverMapIds[s.Id] = s
}
for _, n := range nodes {
server := serverMapIds[n.ServerId]
if server == nil {
continue
}
userSubscribeNode := &types.UserSubscribeNodeInfo{
Id: n.Id,
Name: n.Name,
Uuid: userSub.UUID,
Protocol: n.Protocol,
Protocols: server.Protocols,
Port: n.Port,
Address: n.Address,
Tags: strings.Split(n.Tags, ","),
Country: server.Country,
City: server.City,
Latitude: server.Latitude,
Longitude: server.Longitude,
LongitudeCenter: server.LongitudeCenter,
LatitudeCenter: server.LatitudeCenter,
CreatedAt: n.CreatedAt.Unix(),
}
userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode)
}
}
l.Debugf("[Query Subscribe]found servers: %v", len(nodes))
return userSubscribeNodes, nil
}
// getNodesByGroup gets nodes based on user subscription node_group_id with priority fallback
func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscribe) ([]*node.Node, error) {
// 按优先级获取 node_group_iduser_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0]
nodeGroupId := int64(0)
source := ""
var directNodeIds []int64
// 优先级1: user_subscribe.node_group_id
if userSub.NodeGroupId != 0 {
nodeGroupId = userSub.NodeGroupId
source = "user_subscribe.node_group_id"
}
// 获取 subscribe 详情(用于获取 node_group_id 和直接分配的节点)
subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId)
if err != nil {
l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error()))
return nil, err
}
// 获取直接分配的节点ID
directNodeIds = tool.StringToInt64Slice(subDetails.Nodes)
l.Debugf("[GetNodesByGroup] direct nodes: %v", directNodeIds)
// 如果 user_subscribe 没有 node_group_id从 subscribe 获取
if nodeGroupId == 0 {
// 优先级2: subscribe.node_group_id
if subDetails.NodeGroupId != 0 {
nodeGroupId = subDetails.NodeGroupId
source = "subscribe.node_group_id"
} else if len(subDetails.NodeGroupIds) > 0 {
// 优先级3: subscribe.node_group_ids[0]
nodeGroupId = subDetails.NodeGroupIds[0]
source = "subscribe.node_group_ids[0]"
}
}
l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId)
// 查询所有启用的节点
enable := true
_, allNodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
Page: 0,
Size: 10000,
Enabled: &enable,
})
if err != nil {
l.Errorw("[GetNodesByGroup] FilterNodeList error", logger.Field("error", err.Error()))
return nil, err
}
// 过滤节点
var resultNodes []*node.Node
nodeIdMap := make(map[int64]bool)
for _, n := range allNodes {
// 1. 公共节点node_group_ids 为空),所有人可见
if len(n.NodeGroupIds) == 0 {
if !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
continue
}
// 2. 如果有节点组,检查节点是否属于该节点组
if nodeGroupId != 0 {
for _, gid := range n.NodeGroupIds {
if gid == nodeGroupId {
if !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
break
}
}
}
}
// 3. 添加直接分配的节点
if len(directNodeIds) > 0 {
for _, n := range allNodes {
if tool.Contains(directNodeIds, n.Id) && !nodeIdMap[n.Id] {
resultNodes = append(resultNodes, n)
nodeIdMap[n.Id] = true
}
}
}
l.Debugf("[GetNodesByGroup] Found %d nodes (group=%d, direct=%d)", len(resultNodes), nodeGroupId, len(directNodeIds))
return resultNodes, nil
}
// getNodesByTag gets nodes based on subscribe node_ids and tags
func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe) ([]*node.Node, error) {
subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId)
if err != nil {
l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error()))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error())
}
nodeIds := tool.StringToInt64Slice(subDetails.Nodes)
tags := strings.Split(subDetails.NodeTags, ",")
newTags := make([]string, 0)
for _, t := range tags {
if t != "" {
newTags = append(newTags, t)
}
}
tags = newTags
l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags)
enable := true
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
Page: 0,
Size: 1000,
NodeId: nodeIds,
Tag: tags,
Enabled: &enable, // Only get enabled nodes
})
return nodes, err
}
// getAllNodes returns all enabled nodes
func (l *QueryUserSubscribeNodeListLogic) getAllNodes() ([]*node.Node, error) {
enable := true
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
Page: 0,
Size: 1000,
Enabled: &enable,
})
return nodes, err
}
func (l *QueryUserSubscribeNodeListLogic) isSubscriptionExpired(userSub *user.Subscribe) bool {
return userSub.ExpireTime.Unix() < time.Now().Unix() && userSub.ExpireTime.Unix() != 0
}
func (l *QueryUserSubscribeNodeListLogic) createExpiredServers(userSub *user.Subscribe) []*types.UserSubscribeNodeInfo {
// 1. 查询过期节点组
var expiredGroup group.NodeGroup
err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error
if err != nil {
l.Debugw("no expired node group configured", logger.Field("error", err))
return nil
}
// 2. 检查用户是否在过期天数限制内
expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24)
if expiredDays > expiredGroup.ExpiredDaysLimit {
l.Debugf("user subscription expired %d days, exceeds limit %d days", expiredDays, expiredGroup.ExpiredDaysLimit)
return nil
}
// 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量)
if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 {
usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024)
if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired {
l.Debugf("user expired traffic %d GB, exceeds expired group limit %d GB", usedTrafficGB, *expiredGroup.MaxTrafficGBExpired)
return nil
}
}
// 4. 查询过期节点组的节点
enable := true
_, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{
Page: 0,
Size: 1000,
NodeGroupIds: []int64{expiredGroup.Id},
Enabled: &enable,
})
if err != nil {
l.Errorw("failed to query expired group nodes", logger.Field("error", err))
return nil
}
if len(nodes) == 0 {
l.Debug("no nodes found in expired group")
return nil
}
// 5. 查询服务器信息
var serverMapIds = make(map[int64]*node.Server)
for _, n := range nodes {
serverMapIds[n.ServerId] = nil
}
var serverIds []int64
for k := range serverMapIds {
serverIds = append(serverIds, k)
}
servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds)
if err != nil {
l.Errorw("failed to query servers", logger.Field("error", err))
return nil
}
for _, s := range servers {
serverMapIds[s.Id] = s
}
// 6. 构建节点列表
userSubscribeNodes := make([]*types.UserSubscribeNodeInfo, 0, len(nodes))
for _, n := range nodes {
server := serverMapIds[n.ServerId]
if server == nil {
continue
}
userSubscribeNode := &types.UserSubscribeNodeInfo{
Id: n.Id,
Name: n.Name,
Uuid: userSub.UUID,
Protocol: n.Protocol,
Protocols: server.Protocols,
Port: n.Port,
Address: n.Address,
Tags: strings.Split(n.Tags, ","),
Country: server.Country,
City: server.City,
Latitude: server.Latitude,
Longitude: server.Longitude,
LongitudeCenter: server.LongitudeCenter,
LatitudeCenter: server.LatitudeCenter,
CreatedAt: n.CreatedAt.Unix(),
}
userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode)
}
l.Infof("returned %d nodes from expired group for user %d (expired %d days)", len(userSubscribeNodes), userSub.UserId, expiredDays)
return userSubscribeNodes
}
func (l *QueryUserSubscribeNodeListLogic) getFirstHostLine() string {
host := l.svcCtx.Config.Host
lines := strings.Split(host, "\n")
if len(lines) > 0 {
return lines[0]
}
return host
}
func (l *QueryUserSubscribeNodeListLogic) getUserSubscribe(token string) (*user.Subscribe, error) {
userSub, err := l.svcCtx.UserModel.FindOneSubscribeByToken(l.ctx, token)
if err != nil {
l.Infow("[Generate Subscribe]find subscribe error: %v", logger.Field("error", err.Error()), logger.Field("token", token))
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe error: %v", err.Error())
}
// Ignore expiration check
//if userSub.Status > 1 {
// l.Infow("[Generate Subscribe]subscribe is not available", logger.Field("status", int(userSub.Status)), logger.Field("token", token))
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SubscribeNotAvailable), "subscribe is not available")
//}
return userSub, nil
}