hi-server/internal/logic/subscribe/subscribeLogic.go
EUForest 39310d5b9a Features:
- Node group CRUD operations with traffic-based filtering
  - Three grouping modes: average distribution, subscription-based, and traffic-based
  - Automatic and manual group recalculation with history tracking
  - Group assignment preview before applying changes
  - User subscription group locking to prevent automatic reassignment
  - Subscribe-to-group mapping configuration
  - Group calculation history and detailed reports
  - System configuration for group management (enabled/mode/auto_create)

  Database:
  - Add node_group table for group definitions
  - Add group_history and group_history_detail tables for tracking
  - Add node_group_ids (JSON) to nodes and subscribe tables
  - Add node_group_id and group_locked fields to user_subscribe table
  - Add migration files for schema changes
2026-03-08 23:22:38 +08:00

425 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package subscribe
import (
"fmt"
"net/url"
"strings"
"time"
"github.com/perfect-panel/server/adapter"
"github.com/perfect-panel/server/internal/model/client"
"github.com/perfect-panel/server/internal/model/log"
"github.com/perfect-panel/server/internal/model/node"
"github.com/perfect-panel/server/internal/report"
"github.com/perfect-panel/server/internal/model/user"
"github.com/gin-gonic/gin"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
)
// SubscribeLogic renders the subscription configuration for a user-subscribe
// token on behalf of a single HTTP request.
//
//goland:noinspection GoNameStartsWithPackageName
type SubscribeLogic struct {
	// ctx is the gin context of the current request; its Request provides the
	// User-Agent, Host, RequestURI and client IP read by the handler.
	ctx *gin.Context
	// svc bundles the shared dependencies (models, DB handle, config).
	svc *svc.ServiceContext
	// Logger is bound to the request's context.Context in NewSubscribeLogic.
	logger.Logger
}
// NewSubscribeLogic creates a SubscribeLogic for one request. The embedded
// logger is derived from the request's context so log lines carry its
// request-scoped values.
func NewSubscribeLogic(ctx *gin.Context, svc *svc.ServiceContext) *SubscribeLogic {
	requestCtx := ctx.Request.Context()
	logic := &SubscribeLogic{ctx: ctx, svc: svc}
	logic.Logger = logger.WithContext(requestCtx)
	return logic
}
// Handler renders the subscription config for the token in req.
//
// Client selection: the first client whose lower-cased UserAgent pattern is a
// substring of the lower-cased request User-Agent wins. A request User-Agent
// containing "stash" only matches patterns that also contain "stash". Clients
// with an empty pattern are skipped: an empty string is a substring of every
// User-Agent and would otherwise capture all requests and shadow the default
// client. If nothing matches, the client flagged IsDefault is used; if there
// is none, an error is returned.
//
// On success the response carries the rendered config bytes and a header
// string of the form "upload=..;download=..;total=..;expire=.." (expire in
// Unix seconds). A subscribe activity log entry is written only on success.
func (l *SubscribeLogic) Handler(req *types.SubscribeRequest) (resp *types.SubscribeResponse, err error) {
	ctx := l.ctx.Request.Context()

	// query client list
	clients, err := l.svc.ClientModel.List(ctx)
	if err != nil {
		l.Errorw("[SubscribeLogic] Query client list failed", logger.Field("error", err.Error()))
		return nil, err
	}

	userAgent := strings.ToLower(l.ctx.Request.UserAgent())
	var targetApp, defaultApp *client.SubscribeApplication
	for _, item := range clients {
		if item.IsDefault {
			defaultApp = item
		}
		pattern := strings.ToLower(item.UserAgent)
		if pattern == "" || !strings.Contains(userAgent, pattern) {
			continue
		}
		// Special handling for Stash: presumably its User-Agent also embeds
		// other client names, so only a Stash-specific pattern may claim it.
		if strings.Contains(userAgent, "stash") && !strings.Contains(pattern, "stash") {
			continue
		}
		targetApp = item
		break
	}
	if targetApp == nil {
		l.Debugf("[SubscribeLogic] No matching client found, userAgent: %s", userAgent)
		if defaultApp == nil {
			return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "No matching client found for user agent: %s", userAgent)
		}
		targetApp = defaultApp
	}

	// Find user subscribe by token
	userSubscribe, err := l.getUserSubscribe(req.Token)
	if err != nil {
		l.Errorw("[SubscribeLogic] Get user subscribe failed", logger.Field("error", err.Error()), logger.Field("token", req.Token))
		return nil, err
	}

	// The deferred call always runs, but logSubscribeActivity only writes an
	// entry when subscribeStatus has been flipped to true on the success path.
	subscribeStatus := false
	defer func() {
		l.logSubscribeActivity(subscribeStatus, userSubscribe, req)
	}()

	// find subscribe info
	subscribeInfo, err := l.svc.SubscribeModel.FindOne(ctx, userSubscribe.SubscribeId)
	if err != nil {
		l.Errorw("[SubscribeLogic] Find subscribe info failed", logger.Field("error", err.Error()), logger.Field("subscribeId", userSubscribe.SubscribeId))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Find subscribe info failed: %v", err.Error())
	}

	// Find server list by user subscribe
	servers, err := l.getServers(userSubscribe)
	if err != nil {
		return nil, err
	}

	subscribeURL := l.getSubscribeV2URL()
	a := adapter.NewAdapter(
		targetApp.SubscribeTemplate,
		adapter.WithServers(servers),
		adapter.WithSiteName(l.svc.Config.Site.SiteName),
		adapter.WithSubscribeName(subscribeInfo.Name),
		adapter.WithOutputFormat(targetApp.OutputFormat),
		adapter.WithUserInfo(adapter.User{
			Password:     userSubscribe.UUID,
			ExpiredAt:    userSubscribe.ExpireTime,
			Download:     userSubscribe.Download,
			Upload:       userSubscribe.Upload,
			Traffic:      userSubscribe.Traffic,
			SubscribeURL: subscribeURL,
		}),
		adapter.WithParams(req.Params),
	)
	l.Debugf("[SubscribeLogic] Building client config for user %d with URI %s", userSubscribe.UserId, subscribeURL)

	// Get client config
	adapterClient, err := a.Client()
	if err != nil {
		l.Errorw("[SubscribeLogic] Client error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(500), "Client error: %v", err.Error())
	}
	bytes, err := adapterClient.Build()
	if err != nil {
		l.Errorw("[SubscribeLogic] Build client config failed", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(500), "Build client config failed: %v", err.Error())
	}

	// Serve known file formats as a download named after the site. The
	// RFC 5987 "filename*=UTF-8''" form requires percent-encoding, so use
	// url.PathEscape (space -> "%20"); url.QueryEscape would emit "+", which
	// clients keep literally in the file/profile name.
	switch format := strings.ToLower(targetApp.OutputFormat); format {
	case "json", "yaml", "conf":
		l.ctx.Header("content-disposition", fmt.Sprintf("attachment;filename*=UTF-8''%s.%s", url.PathEscape(l.svc.Config.Site.SiteName), format))
		l.ctx.Header("Content-Type", "application/octet-stream; charset=UTF-8")
	}

	resp = &types.SubscribeResponse{
		Config: bytes,
		Header: fmt.Sprintf(
			"upload=%d;download=%d;total=%d;expire=%d",
			userSubscribe.Upload, userSubscribe.Download, userSubscribe.Traffic, userSubscribe.ExpireTime.Unix(),
		),
	}
	subscribeStatus = true
	return resp, nil
}
// getSubscribeV2URL returns the absolute https URL of the current
// subscription request, which is embedded into the rendered client config.
//
// In gateway mode the request URI is prefixed with "/sub". The host is the
// first non-blank line of Config.Subscribe.SubscribeDomain, with surrounding
// whitespace (including a CRLF's trailing "\r") trimmed. When no usable domain
// is configured, the request's Host is used instead.
func (l *SubscribeLogic) getSubscribeV2URL() string {
	uri := l.ctx.Request.RequestURI
	// is gateway mode, add /sub prefix
	if report.IsGatewayMode() {
		uri = "/sub" + uri
	}

	// Prefer a configured custom domain; fall back to the current request host.
	host := l.ctx.Request.Host
	for _, domain := range strings.Split(l.svc.Config.Subscribe.SubscribeDomain, "\n") {
		if domain = strings.TrimSpace(domain); domain != "" {
			host = domain
			break
		}
	}
	return fmt.Sprintf("https://%s%s", host, uri)
}
// getUserSubscribe loads the user subscription identified by token. Any
// lookup failure is wrapped as xerr.DatabaseQueryError.
//
// The subscription's status and expiry are deliberately not checked here:
// getServers substitutes placeholder "Subscribe Expired" nodes for expired
// subscriptions instead of failing the request.
func (l *SubscribeLogic) getUserSubscribe(token string) (*user.Subscribe, error) {
	userSub, err := l.svc.UserModel.FindOneSubscribeByToken(l.ctx.Request.Context(), token)
	if err != nil {
		l.Infow("[Generate Subscribe]find subscribe error", logger.Field("error", err.Error()), logger.Field("token", token))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe error: %v", err.Error())
	}
	return userSub, nil
}
// logSubscribeActivity records a successful subscription fetch as a
// log.TypeSubscribe system log entry keyed by the user's ID. It does nothing
// when subscribeStatus is false.
//
// Logging is best effort: failures are logged and never surfaced to the
// caller. If the entry cannot be marshalled, it is skipped rather than
// inserted with empty content.
func (l *SubscribeLogic) logSubscribeActivity(subscribeStatus bool, userSub *user.Subscribe, req *types.SubscribeRequest) {
	if !subscribeStatus {
		return
	}
	subscribeLog := log.Subscribe{
		Token:           req.Token,
		UserAgent:       req.UA,
		ClientIP:        l.ctx.ClientIP(),
		UserSubscribeId: userSub.Id,
	}
	content, err := subscribeLog.Marshal()
	if err != nil {
		l.Errorw("[Generate Subscribe]marshal subscribe log error", logger.Field("error", err.Error()))
		return
	}
	err = l.svc.LogModel.Insert(l.ctx.Request.Context(), &log.SystemLog{
		Type:     log.TypeSubscribe.Uint8(),
		ObjectID: userSub.UserId, // log user id
		Date:     time.Now().Format(time.DateOnly),
		Content:  string(content),
	})
	if err != nil {
		l.Errorw("[Generate Subscribe]insert subscribe log error", logger.Field("error", err.Error()))
	}
}
// getServers returns the nodes to render for userSub.
//
// Expired subscriptions receive the placeholder nodes from
// createExpiredServers. Otherwise the selection depends on isGroupEnabled:
//   - group mode: the nodes of the resolved node group plus the public nodes
//     (see getGroupModeServers);
//   - tag mode: enabled nodes whose ID is listed in subscribe.Nodes or that
//     carry one of the comma-separated subscribe.NodeTags.
func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, error) {
	if l.isSubscriptionExpired(userSub) {
		return l.createExpiredServers(), nil
	}
	ctx := l.ctx.Request.Context()
	subDetails, err := l.svc.SubscribeModel.FindOne(ctx, userSub.SubscribeId)
	if err != nil {
		l.Errorw("[Generate Subscribe]find subscribe details error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error())
	}

	if l.isGroupEnabled() {
		// Resolve the node group by priority:
		// user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0].
		// nodeGroupId stays 0 when none is set, in which case only public nodes are served.
		var nodeGroupId int64
		source := ""
		switch {
		case userSub.NodeGroupId != 0:
			nodeGroupId, source = userSub.NodeGroupId, "user_subscribe.node_group_id"
		case subDetails.NodeGroupId != 0:
			nodeGroupId, source = subDetails.NodeGroupId, "subscribe.node_group_id"
		case len(subDetails.NodeGroupIds) > 0:
			nodeGroupId, source = subDetails.NodeGroupIds[0], "subscribe.node_group_ids[0]"
		}
		l.Debugf("[Generate Subscribe]group mode, using %s: %v", source, nodeGroupId)
		return l.getGroupModeServers(nodeGroupId)
	}

	// Tag mode: explicit node IDs and/or node tags configured on the subscribe.
	nodeIds := tool.StringToInt64Slice(subDetails.Nodes)
	tags := tool.RemoveStringElement(strings.Split(subDetails.NodeTags, ","), "")
	l.Debugf("[Generate Subscribe]tag mode, nodes: %v, NodeTags: %v", len(nodeIds), len(tags))
	if len(nodeIds) == 0 && len(tags) == 0 {
		l.Infow("[Generate Subscribe]no subscribe nodes configured")
		return []*node.Node{}, nil
	}
	enable := true
	_, nodes, err := l.svc.NodeModel.FilterNodeList(ctx, &node.FilterNodeParams{
		Page:    1,
		Size:    1000,
		NodeId:  nodeIds,
		Tag:     tool.RemoveDuplicateElements(tags...),
		Preload: true,
		Enabled: &enable,
	})
	if err != nil {
		l.Errorw("[Generate Subscribe]find server details error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error())
	}
	l.Debugf("[Generate Subscribe]found %d nodes in tag mode", len(nodes))
	return nodes, nil
}

// getGroupModeServers returns the enabled nodes of node group nodeGroupId
// (skipped when nodeGroupId <= 0), followed by the enabled public nodes (nodes
// whose NodeGroupIds is empty). Nodes are de-duplicated by ID and kept in
// query order, so the subscription content is stable across fetches. Group
// nodes without a tag are tagged with the node group's name.
//
// NOTE(review): each query is capped at 1000 nodes; anything beyond that cap
// is silently left out.
func (l *SubscribeLogic) getGroupModeServers(nodeGroupId int64) ([]*node.Node, error) {
	ctx := l.ctx.Request.Context()
	enable := true

	// 1. Nodes assigned to the resolved group.
	var groupNodes []*node.Node
	if nodeGroupId > 0 {
		var err error
		_, groupNodes, err = l.svc.NodeModel.FilterNodeList(ctx, &node.FilterNodeParams{
			Page:         0,
			Size:         1000,
			NodeGroupIds: []int64{nodeGroupId},
			Enabled:      &enable,
			Preload:      true,
		})
		if err != nil {
			l.Errorw("[Generate Subscribe]filter nodes by group error", logger.Field("error", err.Error()))
			return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter nodes by group error: %v", err.Error())
		}
		l.Debugf("[Generate Subscribe]found %d nodes for node_group_id=%d", len(groupNodes), nodeGroupId)
	}

	// 2. Public nodes: enabled nodes that belong to no group.
	_, allNodes, err := l.svc.NodeModel.FilterNodeList(ctx, &node.FilterNodeParams{
		Page:    0,
		Size:    1000,
		Enabled: &enable,
		Preload: true,
	})
	if err != nil {
		l.Errorw("[Generate Subscribe]filter all nodes error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter all nodes error: %v", err.Error())
	}
	var publicNodes []*node.Node
	for _, n := range allNodes {
		if len(n.NodeGroupIds) == 0 {
			publicNodes = append(publicNodes, n)
		}
	}
	l.Debugf("[Generate Subscribe]found %d public nodes (node_group_ids is empty)", len(publicNodes))

	// 3. Merge group nodes and then public nodes, de-duplicated by ID. A slice
	// plus a seen-set is used instead of ranging over a map, because Go's map
	// iteration order is randomized and would reshuffle the nodes on every fetch.
	total := len(groupNodes) + len(publicNodes)
	result := make([]*node.Node, 0, total)
	seen := make(map[int64]struct{}, total)
	for _, batch := range [][]*node.Node{groupNodes, publicNodes} {
		for _, n := range batch {
			if _, dup := seen[n.Id]; dup {
				continue
			}
			seen[n.Id] = struct{}{}
			result = append(result, n)
		}
	}
	l.Debugf("[Generate Subscribe]total nodes (group + public): %d (group: %d, public: %d)", len(result), len(groupNodes), len(publicNodes))

	// 4. Use the group name as the tag of untagged group nodes; public nodes
	// are left untouched.
	if nodeGroupId > 0 {
		if groupName := l.findNodeGroupName(nodeGroupId); groupName != "" {
			for _, n := range result {
				if n.Tags == "" && len(n.NodeGroupIds) > 0 {
					n.Tags = groupName
					l.Debugf("[Generate Subscribe]set node_group name as tag for node %d: %s", n.Id, groupName)
				}
			}
		}
	}
	return result, nil
}

// findNodeGroupName returns the name stored in the node_group table for
// nodeGroupId, or "" when the row cannot be loaded (logged at info level).
func (l *SubscribeLogic) findNodeGroupName(nodeGroupId int64) string {
	type NodeGroupInfo struct {
		Id   int64
		Name string
	}
	var nodeGroupInfo NodeGroupInfo
	err := l.svc.DB.Table("node_group").Select("id, name").Where("id = ?", nodeGroupId).First(&nodeGroupInfo).Error
	if err != nil {
		l.Infow("[Generate Subscribe]node group not found", logger.Field("nodeGroupId", nodeGroupId), logger.Field("error", err.Error()))
	}
	if nodeGroupInfo.Id == 0 {
		return ""
	}
	return nodeGroupInfo.Name
}
// isSubscriptionExpired reports whether userSub's expiry lies in the past.
// An ExpireTime at the Unix epoch (Unix() == 0) is treated as "never expires".
// The comparison uses whole Unix seconds.
func (l *SubscribeLogic) isSubscriptionExpired(userSub *user.Subscribe) bool {
	expireAt := userSub.ExpireTime.Unix()
	if expireAt == 0 {
		return false
	}
	return expireAt < time.Now().Unix()
}
// createExpiredServers builds two placeholder shadowsocks nodes pointing at
// 127.0.0.1:18080. They are served instead of real nodes once a subscription
// has expired. The first node is named "Subscribe Expired" and the second is
// named after the first line of Config.Host. Both share one Enabled pointer.
func (l *SubscribeLogic) createExpiredServers() []*node.Node {
	enable := true
	newPlaceholder := func(name string) *node.Node {
		return &node.Node{
			Name:    name,
			Tags:    "",
			Port:    18080,
			Address: "127.0.0.1",
			Server: &node.Server{
				Id:        1,
				Name:      "Subscribe Expired",
				Protocols: `[{"type":"shadowsocks","cipher":"aes-256-gcm","port":1}]`,
			},
			Protocol: "shadowsocks",
			Enabled:  &enable,
		}
	}
	hostName := l.getFirstHostLine()
	return []*node.Node{
		newPlaceholder("Subscribe Expired"),
		newPlaceholder(hostName),
	}
}
// getFirstHostLine returns the first non-blank line of Config.Host, with
// surrounding whitespace (including a CRLF's trailing "\r") trimmed. It
// returns "" when no line is non-blank.
//
// The original len(lines) > 0 guard was dead code: strings.Split always
// returns at least one element.
func (l *SubscribeLogic) getFirstHostLine() string {
	for _, line := range strings.Split(l.svc.Config.Host, "\n") {
		if line = strings.TrimSpace(line); line != "" {
			return line
		}
	}
	return ""
}
// isGroupEnabled reports whether node-group mode is switched on. It reads the
// system row with category "group" and key "enabled"; only the values "true"
// and "1" enable it. A query error is logged at debug level and treated as
// disabled.
func (l *SubscribeLogic) isGroupEnabled() bool {
	var value string
	query := l.svc.DB.Table("system").
		Where("`category` = ? AND `key` = ?", "group", "enabled").
		Select("value")
	if err := query.Scan(&value).Error; err != nil {
		l.Debugf("[SubscribeLogic]check group enabled failed: %v", err)
		return false
	}
	switch value {
	case "true", "1":
		return true
	default:
		return false
	}
}