server/internal/logic/subscribe/subscribeLogic.go
shanshanzhong a52c7142ee feat: 添加在线设备统计功能并优化订阅相关逻辑
- 在DeviceManager中添加GetOnlineDeviceCount方法用于获取在线设备数
- 在统计接口中增加在线设备数返回
- 优化订阅查询逻辑,增加服务组关联节点数量计算
- 添加AnyTLS协议支持及相关URI生成功能
- 重构邀请佣金计算逻辑,支持首购/年付/非首购不同比例
- 修复用户基本信息更新中IsAdmin和Enable字段类型不匹配问题
- 更新数据库迁移脚本和配置文件中邀请相关配置项
2025-08-12 07:46:45 -07:00

382 lines
12 KiB
Go

package subscribe
import (
"fmt"
"net/url"
"strings"
"time"
"github.com/perfect-panel/server/pkg/adapter"
"github.com/perfect-panel/server/pkg/adapter/shadowrocket"
"github.com/perfect-panel/server/pkg/adapter/surfboard"
"github.com/perfect-panel/server/pkg/adapter/surge"
"github.com/perfect-panel/server/internal/model/server"
"github.com/perfect-panel/server/internal/model/user"
"github.com/gin-gonic/gin"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/xerr"
"github.com/pkg/errors"
)
// SubscribeLogic holds the per-request state used to produce a subscription
// response: the gin context of the request, the shared service context, and an
// embedded logger whose methods are callable directly on the logic value.
//
//goland:noinspection GoNameStartsWithPackageName
type SubscribeLogic struct {
	// ctx is the gin context of the request being served.
	ctx *gin.Context
	// svc gives access to the models and configuration used by the methods.
	svc *svc.ServiceContext
	// Logger is created from the request's context in NewSubscribeLogic.
	logger.Logger
}
// NewSubscribeLogic returns a SubscribeLogic for the given gin request context
// and service context. The embedded logger is derived from the context of
// ctx.Request.
func NewSubscribeLogic(ctx *gin.Context, svc *svc.ServiceContext) *SubscribeLogic {
	logic := new(SubscribeLogic)
	logic.ctx = ctx
	logic.svc = svc
	logic.Logger = logger.WithContext(ctx.Request.Context())
	return logic
}
// Generate looks up the subscription identified by req.Token, collects its
// servers and the rule groups, and renders the client configuration.
//
// It returns the rendered configuration together with the usage header string,
// or the first error reported by any of the steps. The subscribe activity
// logger is invoked on exit with a flag telling whether generation succeeded.
func (l *SubscribeLogic) Generate(req *types.SubscribeRequest) (*types.SubscribeResponse, error) {
	userSub, err := l.getUserSubscribe(req.Token)
	if err != nil {
		return nil, err
	}

	// The closure reads succeeded when it runs, so it observes the final
	// value rather than the value at the time of the defer statement.
	succeeded := false
	defer func() { l.logSubscribeActivity(succeeded, userSub, req) }()

	nodes, err := l.getServers(userSub)
	if err != nil {
		return nil, err
	}

	ruleGroups, err := l.getRules()
	if err != nil {
		return nil, err
	}

	payload, header, err := l.buildClientConfig(req, userSub, nodes, ruleGroups)
	if err != nil {
		return nil, err
	}

	succeeded = true
	result := &types.SubscribeResponse{
		Config: payload,
		Header: header,
	}
	return result, nil
}
// getUserSubscribe loads the user subscription that owns token.
//
// It returns a DatabaseQueryError-coded error when the lookup fails and a
// SubscribeNotAvailable-coded error when the subscription status is greater
// than 1.
func (l *SubscribeLogic) getUserSubscribe(token string) (*user.Subscribe, error) {
	userSub, err := l.svc.UserModel.FindOneSubscribeByToken(l.ctx.Request.Context(), token)
	if err != nil {
		// Structured log call: the details travel in the fields, so the
		// message itself carries no format verb.
		l.Infow("[Generate Subscribe]find subscribe error", logger.Field("error", err.Error()), logger.Field("token", token))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe error: %v", err.Error())
	}
	// NOTE(review): statuses above 1 are rejected; the meaning of the
	// individual status values is defined outside this file.
	if userSub.Status > 1 {
		l.Infow("[Generate Subscribe]subscribe is not available", logger.Field("status", int(userSub.Status)), logger.Field("token", token))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.SubscribeNotAvailable), "subscribe is not available")
	}
	return userSub, nil
}
// logSubscribeActivity records a subscribe log entry (user, subscription,
// token, client IP and user agent) for a successful generation. It does
// nothing when subscribeStatus is false. A failure to insert the entry is
// logged and otherwise ignored, so it never affects the response.
func (l *SubscribeLogic) logSubscribeActivity(subscribeStatus bool, userSub *user.Subscribe, req *types.SubscribeRequest) {
	if !subscribeStatus {
		return
	}
	entry := &user.SubscribeLog{
		UserId:          userSub.UserId,
		UserSubscribeId: userSub.Id,
		Token:           req.Token,
		IP:              l.ctx.ClientIP(),
		UserAgent:       l.ctx.Request.UserAgent(),
	}
	if err := l.svc.UserModel.InsertSubscribeLog(l.ctx.Request.Context(), entry); err != nil {
		// Structured log call: the error travels in the field, so the
		// message itself carries no format verb.
		l.Errorw("[Generate Subscribe]insert subscribe log error", logger.Field("error", err.Error()))
	}
}
// getServers returns the servers to put into the generated configuration.
//
// For an expired subscription it returns the placeholder entries from
// createExpiredServers. Otherwise it loads the subscribe plan referenced by
// userSub.SubscribeId, parses its Server and ServerGroup fields into ID lists
// and queries the matching server details.
//
// It returns a DatabaseQueryError-coded error when either query fails.
func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*server.Server, error) {
	if l.isSubscriptionExpired(userSub) {
		return l.createExpiredServers(), nil
	}

	ctx := l.ctx.Request.Context()

	subDetails, err := l.svc.SubscribeModel.FindOne(ctx, userSub.SubscribeId)
	if err != nil {
		l.Errorw("[Generate Subscribe]find subscribe details error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error())
	}

	serverIds := tool.StringToInt64Slice(subDetails.Server)
	groupIds := tool.StringToInt64Slice(subDetails.ServerGroup)
	l.Debugf("[Generate Subscribe]serverIds: %v, groupIds: %v", serverIds, groupIds)

	// The error is checked before the result is touched. The former
	// troubleshooting code for one hard-coded subscribe ID was removed: it
	// loaded every server on each request, dropped that query's error and
	// dereferenced Enable without a nil check.
	servers, err := l.svc.ServerModel.FindServerDetailByGroupIdsAndIds(ctx, groupIds, serverIds)
	if err != nil {
		l.Errorw("[Generate Subscribe]find server details error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error())
	}
	l.Debugf("[Generate Subscribe]found servers: %v", len(servers))
	return servers, nil
}
// isSubscriptionExpired reports whether userSub.ExpireTime lies before the
// current time, compared at one-second resolution. An expiry whose Unix
// timestamp is exactly 0 is never reported as expired.
func (l *SubscribeLogic) isSubscriptionExpired(userSub *user.Subscribe) bool {
	expireAt := userSub.ExpireTime.Unix()
	if expireAt == 0 {
		return false
	}
	return expireAt < time.Now().Unix()
}
// createExpiredServers builds the two placeholder entries returned in place of
// real servers for an expired subscription. Both are shadowsocks entries
// pointing at 127.0.0.1 port 1; the first is named "Subscribe Expired" and the
// second carries the first line of the configured host as its name.
func (l *SubscribeLogic) createExpiredServers() []*server.Server {
	// Both entries share the same Enable pointer.
	active := true
	siteHost := l.getFirstHostLine()

	placeholder := func(name string) *server.Server {
		return &server.Server{
			Name:       name,
			ServerAddr: "127.0.0.1",
			RelayMode:  "none",
			Protocol:   "shadowsocks",
			Config:     "{\"method\":\"aes-256-gcm\",\"port\":1}",
			Enable:     &active,
			Sort:       0,
		}
	}

	return []*server.Server{
		placeholder("Subscribe Expired"),
		placeholder(siteHost),
	}
}
// getFirstHostLine returns the first line of the configured Host value with
// surrounding whitespace removed. Trimming drops the trailing "\r" left
// behind when the value uses CRLF line endings. An empty Host yields "".
func (l *SubscribeLogic) getFirstHostLine() string {
	// SplitN with a non-empty separator always yields at least one element,
	// so indexing [0] is safe and no fallback branch is needed.
	first := strings.SplitN(l.svc.Config.Host, "\n", 2)[0]
	return strings.TrimSpace(first)
}
// getRules loads all rule groups. It returns a DatabaseQueryError-coded error
// when the query fails.
func (l *SubscribeLogic) getRules() ([]*server.RuleGroup, error) {
	// Use the request's context, like the other model calls in this file,
	// rather than the gin context itself.
	rules, err := l.svc.ServerModel.QueryAllRuleGroup(l.ctx.Request.Context())
	if err != nil {
		l.Errorw("[Generate Subscribe]find rule group error", logger.Field("error", err.Error()))
		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find rule group error: %v", err.Error())
	}
	return rules, nil
}
// buildClientConfig renders the subscription for the client detected from req.
//
// It groups servers by tag, hands nodes, rules and tag groups to the adapter,
// and calls the builder matching the detected client type; unknown clients get
// the general format. For clash, loon, surfboard and surge it also sets the
// response headers through the matching set*Headers method.
//
// It returns the rendered configuration, a header string of the form
// "upload=..;download=..;total=..;expire=..", and an error when the tag lookup
// or a clash/sing-box build fails.
func (l *SubscribeLogic) buildClientConfig(req *types.SubscribeRequest, userSub *user.Subscribe, servers []*server.Server, rules []*server.RuleGroup) ([]byte, string, error) {
	// Use the request's context for every model call in this method.
	ctx := l.ctx.Request.Context()

	serverTags, err := l.svc.ServerModel.FindServerTags(ctx)
	if err != nil {
		l.Errorw("[Generate Subscribe]find server tags error", logger.Field("error", err.Error()))
		return nil, "", errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server tags error: %v", err.Error())
	}
	// Deduplicate tags
	serverTags = tool.RemoveDuplicateElements(serverTags...)

	tags := make(map[string][]*server.Server, len(serverTags))
	for _, tag := range serverTags {
		s, err := l.svc.ServerModel.FindServersByTag(ctx, tag)
		if err != nil {
			// Best effort: a tag that cannot be resolved is skipped
			// instead of failing the whole subscription.
			l.Errorw("[Generate Subscribe]find servers by tag error", logger.Field("error", err.Error()), logger.Field("tag", tag))
			continue
		}
		if len(s) > 0 {
			tags[tag] = s
		}
	}

	proxyManager := adapter.NewAdapter(&adapter.Config{
		Nodes: servers,
		Rules: rules,
		Tags:  tags,
	})

	clientType := l.getClientType(req)
	// Infow is the call that takes structured fields.
	l.Infow(fmt.Sprintf("[Generate Subscribe] %s", clientType), logger.Field("ua", req.UA), logger.Field("flag", req.Flag))

	var resp []byte
	switch clientType {
	case "clash":
		resp, err = proxyManager.BuildClash(userSub.UUID)
		if err != nil {
			l.Errorw("[Generate Subscribe] build clash error", logger.Field("error", err.Error()))
			return nil, "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "build clash error: %v", err.Error())
		}
		l.setClashHeaders()
	case "sing-box":
		resp, err = proxyManager.BuildSingbox(userSub.UUID)
		if err != nil {
			l.Errorw("[Generate Subscribe] build sing-box error", logger.Field("error", err.Error()))
			return nil, "", errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "build sing-box error: %v", err.Error())
		}
	case "quantumult":
		resp = []byte(proxyManager.BuildQuantumultX(userSub.UUID))
	case "shadowrocket":
		resp = proxyManager.BuildShadowrocket(userSub.UUID, shadowrocket.UserInfo{
			Upload:       userSub.Upload,
			Download:     userSub.Download,
			TotalTraffic: userSub.Traffic,
			ExpiredDate:  userSub.ExpireTime,
		})
	case "loon":
		resp = proxyManager.BuildLoon(userSub.UUID)
		l.setLoonHeaders()
	case "surfboard":
		subsURL := l.getSubscribeURL(userSub.Token, "surfboard")
		resp = proxyManager.BuildSurfboard(l.svc.Config.Site.SiteName, surfboard.UserInfo{
			Upload:       userSub.Upload,
			Download:     userSub.Download,
			TotalTraffic: userSub.Traffic,
			ExpiredDate:  userSub.ExpireTime,
			UUID:         userSub.UUID,
			SubscribeURL: subsURL,
		})
		l.setSurfboardHeaders()
	case "v2rayn":
		resp = proxyManager.BuildV2rayN(userSub.UUID)
	case "surge":
		subsURL := l.getSubscribeURL(userSub.Token, "surge")
		resp = proxyManager.BuildSurge(l.svc.Config.Site.SiteName, surge.UserInfo{
			UUID:         userSub.UUID,
			Upload:       userSub.Upload,
			Download:     userSub.Download,
			TotalTraffic: userSub.Traffic,
			ExpiredDate:  userSub.ExpireTime,
			SubscribeURL: subsURL,
		})
		l.setSurgeHeaders()
	default:
		resp = proxyManager.BuildGeneral(userSub.UUID)
	}

	headerInfo := fmt.Sprintf("upload=%d;download=%d;total=%d;expire=%d",
		userSub.Upload, userSub.Download, userSub.Traffic, userSub.ExpireTime.Unix())
	return resp, headerInfo, nil
}
// setClashHeaders sets the response headers for a clash subscription: an
// attachment disposition named after the site, a Profile-Update-Interval of
// "24", and an octet-stream content type.
func (l *SubscribeLogic) setClashHeaders() {
	// filename* takes a percent-encoded value in which "+" is a literal plus.
	// QueryEscape writes spaces as "+" (and a real plus as %2B), so every "+"
	// in its output is a space and is rewritten to %20.
	filename := strings.ReplaceAll(url.QueryEscape(l.svc.Config.Site.SiteName), "+", "%20")
	l.ctx.Header("content-disposition", fmt.Sprintf("attachment;filename*=UTF-8''%s", filename))
	l.ctx.Header("Profile-Update-Interval", "24")
	l.ctx.Header("Content-Type", "application/octet-stream; charset=UTF-8")
}
// setSurfboardHeaders sets the response headers for a surfboard subscription:
// an attachment disposition named "<site name>.conf" and an octet-stream
// content type.
func (l *SubscribeLogic) setSurfboardHeaders() {
	// filename* takes a percent-encoded value in which "+" is a literal plus.
	// QueryEscape writes spaces as "+" (and a real plus as %2B), so every "+"
	// in its output is a space and is rewritten to %20.
	filename := strings.ReplaceAll(url.QueryEscape(l.svc.Config.Site.SiteName), "+", "%20")
	l.ctx.Header("content-disposition", fmt.Sprintf("attachment;filename*=UTF-8''%s.conf", filename))
	l.ctx.Header("Content-Type", "application/octet-stream; charset=UTF-8")
}
// setSurgeHeaders sets the response headers for a surge subscription: an
// attachment disposition named "<site name>.conf" and an octet-stream content
// type.
func (l *SubscribeLogic) setSurgeHeaders() {
	// filename* takes a percent-encoded value in which "+" is a literal plus.
	// QueryEscape writes spaces as "+" (and a real plus as %2B), so every "+"
	// in its output is a space and is rewritten to %20.
	filename := strings.ReplaceAll(url.QueryEscape(l.svc.Config.Site.SiteName), "+", "%20")
	l.ctx.Header("content-disposition", fmt.Sprintf("attachment;filename*=UTF-8''%s.conf", filename))
	l.ctx.Header("Content-Type", "application/octet-stream; charset=UTF-8")
}
// setLoonHeaders sets the response headers for a loon subscription: an
// attachment disposition named "<site name>.conf" and an octet-stream content
// type.
func (l *SubscribeLogic) setLoonHeaders() {
	// filename* takes a percent-encoded value in which "+" is a literal plus.
	// QueryEscape writes spaces as "+" (and a real plus as %2B), so every "+"
	// in its output is a space and is rewritten to %20.
	filename := strings.ReplaceAll(url.QueryEscape(l.svc.Config.Site.SiteName), "+", "%20")
	l.ctx.Header("content-disposition", fmt.Sprintf("attachment;filename*=UTF-8''%s.conf", filename))
	l.ctx.Header("Content-Type", "application/octet-stream; charset=UTF-8")
}
// getSubscribeURL returns the URL a client should use to refresh this
// subscription.
//
// With PanDomain enabled it returns only "https://<request host>". Otherwise
// the URL is "https://<host><SubscribePath>?token=<token>&flag=<flag>", where
// host is the first line of the configured SubscribeDomain, or the request
// host when no domain is configured.
func (l *SubscribeLogic) getSubscribeURL(token, flag string) string {
	// NOTE(review): the pan-domain URL carries neither path nor token;
	// presumably the host alone identifies the subscription — confirm.
	if l.svc.Config.Subscribe.PanDomain {
		return fmt.Sprintf("https://%s", l.ctx.Request.Host)
	}

	host := l.ctx.Request.Host
	if l.svc.Config.Subscribe.SubscribeDomain != "" {
		// Only the first configured domain is used; trimming drops a
		// trailing "\r" from CRLF-separated values.
		host = strings.TrimSpace(strings.SplitN(l.svc.Config.Subscribe.SubscribeDomain, "\n", 2)[0])
	}

	// Both branches honor the caller's flag, so a surge profile no longer
	// receives a surfboard refresh URL when no subscribe domain is set.
	return fmt.Sprintf("https://%s%s?token=%s&flag=%s", host, l.svc.Config.Subscribe.SubscribePath, token, flag)
}
// getClientType maps the request to one of the supported client types by
// case-insensitive substring match. req.Flag is checked first and req.UA
// second. It returns "" when neither contains a known keyword.
func (l *SubscribeLogic) getClientType(req *types.SubscribeRequest) string {
	// An ordered list makes the result deterministic when a string contains
	// several keywords; ranging over a map would pick one at random.
	// Client-specific names come before the generic "clash"/"meta" keywords
	// because a User-Agent may name a specific client and also advertise
	// clash compatibility.
	// NOTE(review): confirm this priority against the User-Agents seen in
	// production.
	clientKeywords := []struct {
		keyword    string
		clientType string
	}{
		{"hiddify", "sing-box"},
		{"sing-box", "sing-box"},
		{"surfboard", "surfboard"},
		{"surge", "surge"},
		{"quantumult", "quantumult"},
		{"shadowrocket", "shadowrocket"},
		{"loon", "loon"},
		{"v2rayn", "v2rayn"},
		{"clash", "clash"},
		{"meta", "clash"},
	}

	findClient := func(s string) string {
		s = strings.ToLower(strings.TrimSpace(s))
		if s == "" {
			return ""
		}
		for _, candidate := range clientKeywords {
			if strings.Contains(s, candidate.keyword) {
				return candidate.clientType
			}
		}
		return ""
	}

	// The explicit flag parameter takes precedence.
	if typ := findClient(req.Flag); typ != "" {
		return typ
	}
	// Otherwise fall back to the User-Agent.
	return findClient(req.UA)
}