ario_server/internal/logic/admin/server/updateNodeLogic.go
missish e0a4bb028b fix(admin/server): 修复节点更新时Tags字段处理逻辑
修改DeepCopy调用以忽略空值,并完善Tags字段的处理逻辑。当Tags为空数组时清空数据库字段,避免保留旧值。
2025-05-31 11:22:02 +08:00

117 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package server
import (
"context"
"encoding/json"
"strings"
"github.com/perfect-panel/server/pkg/device"
"github.com/hibiken/asynq"
"github.com/perfect-panel/server/internal/svc"
"github.com/perfect-panel/server/internal/types"
"github.com/perfect-panel/server/pkg/logger"
"github.com/perfect-panel/server/pkg/tool"
"github.com/perfect-panel/server/pkg/xerr"
queue "github.com/perfect-panel/server/queue/types"
"github.com/pkg/errors"
)
// UpdateNodeLogic carries the per-request state used by UpdateNode.
type UpdateNodeLogic struct {
	// Embedded logger; the constructor binds it to the request context.
	logger.Logger
	// ctx is the request context passed to the server model calls.
	ctx context.Context
	// svcCtx provides the shared dependencies UpdateNode uses
	// (ServerModel, Queue, DeviceManager).
	svcCtx *svc.ServiceContext
}
// NewUpdateNodeLogic returns an UpdateNodeLogic bound to ctx and svcCtx,
// with its embedded logger created from ctx.
func NewUpdateNodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateNodeLogic {
	logic := new(UpdateNodeLogic)
	logic.Logger = logger.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// UpdateNode loads the server record identified by req.Id, overlays the
// request onto it, persists the result, and then enqueues a
// queue.ForthwithGetCountry task and broadcasts device.SubscribeUpdate.
//
// For the "vless" protocol the config is post-processed: a missing Reality
// key pair is generated when Security is "reality", and empty
// RealityServerAddr / RealityServerPort fall back to the SNI and 443.
//
// It returns a wrapped xerr code error when the lookup, any JSON
// (un)marshalling, key generation, the database update, or the task
// enqueue fails. Note that the enqueue happens after the database update,
// so an enqueue failure is reported although the record was already saved.
func (l *UpdateNodeLogic) UpdateNode(req *types.UpdateNodeRequest) error {
	// Check server exist
	nodeInfo, err := l.svcCtx.ServerModel.FindOne(l.ctx, req.Id)
	if err != nil {
		return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server error: %v", err)
	}

	// Overlay the request onto the stored record.
	// NOTE(review): the exact semantics of CopyWithIgnoreEmpty(false) are not
	// visible here; the fields assigned explicitly below do not rely on it.
	tool.DeepCopy(nodeInfo, req, tool.CopyWithIgnoreEmpty(false))

	config, err := json.Marshal(req.Config)
	if err != nil {
		return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "marshal config error: %v", err)
	}
	nodeInfo.Config = string(config)

	nodeRelay, err := json.Marshal(req.RelayNode)
	if err != nil {
		l.Errorw("[UpdateNode] Marshal RelayNode Error: ", logger.Field("error", err.Error()))
		return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "marshal relay node error: %v", err)
	}
	nodeInfo.RelayNode = string(nodeRelay)

	// strings.Join yields "" for an empty slice, so an empty Tags list
	// clears the stored value instead of keeping the previous one.
	nodeInfo.Tags = strings.Join(req.Tags, ",")
	nodeInfo.City = req.City
	nodeInfo.Country = req.Country

	if req.Protocol == "vless" {
		var cfg types.Vless
		if err := json.Unmarshal(config, &cfg); err != nil {
			return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "json.Unmarshal error: %v", err.Error())
		}
		// Generate a Reality key pair only when none was supplied.
		if cfg.Security == "reality" && cfg.SecurityConfig.RealityPublicKey == "" {
			public, private, err := tool.Curve25519Genkey(false, "")
			if err != nil {
				return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "generate curve25519 key error: %v", err)
			}
			cfg.SecurityConfig.RealityPublicKey = public
			cfg.SecurityConfig.RealityPrivateKey = private
			cfg.SecurityConfig.RealityShortId = tool.GenerateShortID(private)
		}
		if cfg.SecurityConfig.RealityServerAddr == "" {
			cfg.SecurityConfig.RealityServerAddr = cfg.SecurityConfig.SNI
		}
		if cfg.SecurityConfig.RealityServerPort == 0 {
			cfg.SecurityConfig.RealityServerPort = 443
		}
		// Re-marshal the completed config; a failure here must not be
		// ignored, otherwise an empty config would be persisted.
		config, err = json.Marshal(cfg)
		if err != nil {
			return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "marshal vless config error: %v", err)
		}
		nodeInfo.Config = string(config)
	}

	if err = l.svcCtx.ServerModel.Update(l.ctx, nodeInfo); err != nil {
		l.Errorw("[UpdateNode] Update Database Error: ", logger.Field("error", err.Error()))
		// NOTE(review): the error code is kept as DatabaseInsertError for
		// compatibility; switch to an update-specific code if xerr defines one.
		return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "update server error: %v", err)
	}

	// Marshal the task payload
	payload, err := json.Marshal(queue.GetNodeCountry{
		Protocol:   nodeInfo.Protocol,
		ServerAddr: nodeInfo.ServerAddr,
	})
	if err != nil {
		l.Errorw("[GetNodeCountry]: Marshal Error", logger.Field("error", err.Error()))
		return errors.Wrap(xerr.NewErrCode(xerr.ERROR), "Failed to marshal task payload")
	}

	// Create a queue task and enqueue it
	task := asynq.NewTask(queue.ForthwithGetCountry, payload)
	taskInfo, err := l.svcCtx.Queue.Enqueue(task)
	if err != nil {
		l.Errorw("[GetNodeCountry]: Enqueue Error", logger.Field("error", err.Error()), logger.Field("payload", string(payload)))
		return errors.Wrap(xerr.NewErrCode(xerr.ERROR), "Failed to enqueue task")
	}
	l.Infow("[GetNodeCountry]: Enqueue Success", logger.Field("taskID", taskInfo.ID), logger.Field("payload", string(payload)))

	l.svcCtx.DeviceManager.Broadcast(device.SubscribeUpdate)
	return nil
}