From e215ffcae9a664c15e94cdec1d4f7d1772d67c76 Mon Sep 17 00:00:00 2001 From: EUForest Date: Fri, 6 Mar 2026 13:25:01 +0800 Subject: [PATCH 01/18] fix(subscribe): invalidate user subscription cache when plan is updated When administrators update subscription plan configurations (traffic limits, nodes, speed limits, etc.), existing subscribers were not seeing the updated settings immediately. This was caused by stale cache entries that were not being invalidated. The issue occurred because: - User subscription queries cache the entire result including preloaded plan details - Plan update/delete operations only cleared the plan's own cache keys - User subscription cache keys (cache:user:subscribe:user:{userId}) remained stale This fix ensures that when a subscription plan is updated or deleted, all associated user subscription caches are properly invalidated by: - Querying all active users subscribed to the plan - Building cache keys for each affected user - Clearing both plan and user subscription caches atomically Users will now immediately see updated plan configurations without waiting for cache expiration. --- internal/model/subscribe/default.go | 48 +++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/internal/model/subscribe/default.go b/internal/model/subscribe/default.go index 29e748c..c35d161 100644 --- a/internal/model/subscribe/default.go +++ b/internal/model/subscribe/default.go @@ -119,13 +119,35 @@ func (m *defaultSubscribeModel) Update(ctx context.Context, data *Subscribe, tx if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } + + // 获取所有使用该套餐的用户订阅缓存 key + var userIds []int64 + err = m.QueryNoCacheCtx(ctx, &userIds, func(conn *gorm.DB, v interface{}) error { + return conn.Table("user_subscribe"). + Where("subscribe_id = ? AND status IN (0, 1)", data.Id). + Distinct("user_id"). + Pluck("user_id", &userIds).Error + }) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // 构建用户订阅缓存 key 列表 + userSubscribeCacheKeys := make([]string, 0, len(userIds)) + for _, userId := range userIds { + userSubscribeCacheKeys = append(userSubscribeCacheKeys, fmt.Sprintf("cache:user:subscribe:user:%d", userId)) + } + + // 合并套餐缓存 key 和用户订阅缓存 key + allCacheKeys := append(m.getCacheKeys(old), userSubscribeCacheKeys...) + err = m.ExecCtx(ctx, func(conn *gorm.DB) error { db := conn if len(tx) > 0 { db = tx[0] } return db.Save(data).Error - }, m.getCacheKeys(old)...) + }, allCacheKeys...) return err } @@ -137,13 +159,35 @@ func (m *defaultSubscribeModel) Delete(ctx context.Context, id int64, tx ...*gor } return err } + + // 获取所有使用该套餐的用户订阅缓存 key + var userIds []int64 + err = m.QueryNoCacheCtx(ctx, &userIds, func(conn *gorm.DB, v interface{}) error { + return conn.Table("user_subscribe"). + Where("subscribe_id = ? AND status IN (0, 1)", id). + Distinct("user_id"). + Pluck("user_id", &userIds).Error + }) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // 构建用户订阅缓存 key 列表 + userSubscribeCacheKeys := make([]string, 0, len(userIds)) + for _, userId := range userIds { + userSubscribeCacheKeys = append(userSubscribeCacheKeys, fmt.Sprintf("cache:user:subscribe:user:%d", userId)) + } + + // 合并套餐缓存 key 和用户订阅缓存 key + allCacheKeys := append(m.getCacheKeys(data), userSubscribeCacheKeys...) + err = m.ExecCtx(ctx, func(conn *gorm.DB) error { db := conn if len(tx) > 0 { db = tx[0] } return db.Delete(&Subscribe{}, id).Error - }, m.getCacheKeys(data)...) + }, allCacheKeys...) return err } From 39310d5b9a77e263bbc02445d3cc3c2a934f2732 Mon Sep 17 00:00:00 2001 From: EUForest Date: Sun, 8 Mar 2026 23:22:38 +0800 Subject: [PATCH 02/18] Features: - Node group CRUD operations with traffic-based filtering - Three grouping modes: average distribution, subscription-based, and traffic-based - Automatic and manual group recalculation with history tracking - Group assignment preview before applying changes - User subscription group locking to prevent automatic reassignment - Subscribe-to-group mapping configuration - Group calculation history and detailed reports - System configuration for group management (enabled/mode/auto_create) Database: - Add node_group table for group definitions - Add group_history and group_history_detail tables for tracking - Add node_group_ids (JSON) to nodes and subscribe tables - Add node_group_id and group_locked fields to user_subscribe table - Add migration files for schema changes --- apis/admin/group.api | 207 +++++ apis/admin/redemption.api | 5 +- apis/admin/server.api | 63 +- apis/admin/subscribe.api | 14 +- apis/admin/user.api | 2 + apis/public/redemption.api | 1 + apis/public/subscribe.api | 87 +- apis/public/user.api | 86 +- apis/types.api | 40 +- .../database/02131_add_groups.down.sql | 28 + .../migrate/database/02131_add_groups.up.sql | 130 +++ .../admin/group/createNodeGroupHandler.go | 26 + .../admin/group/deleteNodeGroupHandler.go | 29 + .../admin/group/exportGroupResultHandler.go | 36 + .../admin/group/getGroupConfigHandler.go | 26 + .../group/getGroupHistoryDetailHandler.go | 26 + .../admin/group/getGroupHistoryHandler.go | 26 + .../admin/group/getNodeGroupListHandler.go | 26 + .../group/getRecalculationStatusHandler.go | 18 + .../group/getSubscribeGroupMappingHandler.go | 26 + .../admin/group/previewUserNodesHandler.go | 26 + .../admin/group/recalculateGroupHandler.go | 26 + .../handler/admin/group/resetGroupsHandler.go | 17 + .../admin/group/updateGroupConfigHandler.go | 26 + .../admin/group/updateNodeGroupHandler.go | 33 + internal/handler/routes.go | 48 ++ .../logic/admin/group/createNodeGroupLogic.go | 46 + .../logic/admin/group/deleteNodeGroupLogic.go | 61 ++ .../admin/group/exportGroupResultLogic.go | 128 +++ .../logic/admin/group/getGroupConfigLogic.go | 125 +++ .../admin/group/getGroupHistoryDetailLogic.go | 109 +++ .../logic/admin/group/getGroupHistoryLogic.go | 87 ++ .../admin/group/getNodeGroupListLogic.go | 89 ++ .../group/getRecalculationStatusLogic.go | 57 ++ .../group/getSubscribeGroupMappingLogic.go | 71 ++ .../admin/group/previewUserNodesLogic.go | 466 ++++++++++ .../admin/group/recalculateGroupLogic.go | 814 ++++++++++++++++++ .../logic/admin/group/resetGroupsLogic.go | 82 ++ .../admin/group/updateGroupConfigLogic.go | 188 ++++ .../logic/admin/group/updateNodeGroupLogic.go | 140 +++ .../logic/admin/server/createNodeLogic.go | 15 +- .../logic/admin/server/filterNodeListLogic.go | 36 +- .../logic/admin/server/updateNodeLogic.go | 1 + .../admin/subscribe/createSubscribeLogic.go | 2 + .../admin/subscribe/getSubscribeListLogic.go | 20 +- .../admin/subscribe/updateSubscribeLogic.go | 2 + .../admin/user/createUserSubscribeLogic.go | 57 ++ .../admin/user/updateUserBasicInfoLogic.go | 26 +- .../admin/user/updateUserSubscribeLogic.go | 2 + internal/logic/auth/registerLimitLogic.go | 4 + internal/logic/auth/userRegisterLogic.go | 71 +- .../queryUserSubscribeNodeListLogic.go | 137 ++- .../logic/server/getServerUserListLogic.go | 72 +- internal/logic/subscribe/subscribeLogic.go | 144 +++- internal/model/group/history.go | 54 ++ internal/model/group/model.go | 14 + internal/model/group/node_group.go | 30 + internal/model/node/model.go | 31 +- internal/model/node/node.go | 72 +- internal/model/subscribe/model.go | 77 ++ internal/model/subscribe/subscribe.go | 99 ++- internal/model/user/model.go | 1 + internal/model/user/user.go | 49 ++ internal/types/types.go | 228 ++++- pkg/turnstile/service.go | 4 +- ppanel.api | 1 + queue/handler/routes.go | 4 + queue/logic/group/recalculateGroupLogic.go | 87 ++ queue/logic/order/activateOrderLogic.go | 74 +- .../subscription/checkSubscriptionLogic.go | 2 +- queue/types/scheduler.go | 1 + scheduler/scheduler.go | 6 + 72 files changed, 4682 insertions(+), 282 deletions(-) create mode 100644 apis/admin/group.api create mode 100644 initialize/migrate/database/02131_add_groups.down.sql create mode 100644 initialize/migrate/database/02131_add_groups.up.sql create mode 100644 internal/handler/admin/group/createNodeGroupHandler.go create mode 100644 internal/handler/admin/group/deleteNodeGroupHandler.go create mode 100644 internal/handler/admin/group/exportGroupResultHandler.go create mode 100644 internal/handler/admin/group/getGroupConfigHandler.go create mode 100644 internal/handler/admin/group/getGroupHistoryDetailHandler.go create mode 100644 internal/handler/admin/group/getGroupHistoryHandler.go create mode 100644 internal/handler/admin/group/getNodeGroupListHandler.go create mode 100644 internal/handler/admin/group/getRecalculationStatusHandler.go create mode 100644 internal/handler/admin/group/getSubscribeGroupMappingHandler.go create mode 100644 internal/handler/admin/group/previewUserNodesHandler.go create mode 100644 internal/handler/admin/group/recalculateGroupHandler.go create mode 100644 internal/handler/admin/group/resetGroupsHandler.go create mode 100644 internal/handler/admin/group/updateGroupConfigHandler.go create mode 100644 internal/handler/admin/group/updateNodeGroupHandler.go create mode 100644 internal/logic/admin/group/createNodeGroupLogic.go create mode 100644 internal/logic/admin/group/deleteNodeGroupLogic.go create mode 100644 internal/logic/admin/group/exportGroupResultLogic.go create mode 100644 internal/logic/admin/group/getGroupConfigLogic.go create mode 100644 internal/logic/admin/group/getGroupHistoryDetailLogic.go create mode 100644 internal/logic/admin/group/getGroupHistoryLogic.go create mode 100644 internal/logic/admin/group/getNodeGroupListLogic.go create mode 100644 internal/logic/admin/group/getRecalculationStatusLogic.go create mode 100644 internal/logic/admin/group/getSubscribeGroupMappingLogic.go create mode 100644 internal/logic/admin/group/previewUserNodesLogic.go create mode 100644 internal/logic/admin/group/recalculateGroupLogic.go create mode 100644 internal/logic/admin/group/resetGroupsLogic.go create mode 100644 internal/logic/admin/group/updateGroupConfigLogic.go create mode 100644 internal/logic/admin/group/updateNodeGroupLogic.go create mode 100644 internal/model/group/history.go create mode 100644 internal/model/group/model.go create mode 100644 internal/model/group/node_group.go create mode 100644 queue/logic/group/recalculateGroupLogic.go diff --git a/apis/admin/group.api b/apis/admin/group.api new file mode 100644 index 0000000..e229fc4 --- /dev/null +++ b/apis/admin/group.api @@ -0,0 +1,207 @@ +syntax = "v1" + +info ( + title: "Group API" + desc: "API for user group and node group management" + author: "Tension" + email: "tension@ppanel.com" + version: "0.0.1" +) + +import ( + "../types.api" + "./server.api" +) + +type ( + // ===== 节点组管理 ===== + // GetNodeGroupListRequest + GetNodeGroupListRequest { + Page int `form:"page"` + Size int `form:"size"` + GroupId string `form:"group_id,omitempty"` + } + // GetNodeGroupListResponse + GetNodeGroupListResponse { + Total int64 `json:"total"` + List []NodeGroup `json:"list"` + } + // CreateNodeGroupRequest + CreateNodeGroupRequest { + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + } + // UpdateNodeGroupRequest + UpdateNodeGroupRequest { + Id int64 `json:"id" validate:"required"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + } + // DeleteNodeGroupRequest + DeleteNodeGroupRequest { + Id int64 `json:"id" validate:"required"` + } + // ===== 分组配置管理 ===== + // GetGroupConfigRequest + GetGroupConfigRequest { + Keys []string `form:"keys,omitempty"` + } + // GetGroupConfigResponse + GetGroupConfigResponse { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` + State RecalculationState `json:"state"` + } + // UpdateGroupConfigRequest + UpdateGroupConfigRequest { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` + } + // RecalculationState + RecalculationState { + State string `json:"state"` + Progress int `json:"progress"` + Total int `json:"total"` + } + // ===== 分组操作 ===== + // RecalculateGroupRequest + RecalculateGroupRequest { + Mode string `json:"mode" validate:"required"` + TriggerType string `json:"trigger_type"` // "manual" or "scheduled" + } + // GetGroupHistoryRequest + GetGroupHistoryRequest { + Page int `form:"page"` + Size int `form:"size"` + GroupMode string `form:"group_mode,omitempty"` + TriggerType string `form:"trigger_type,omitempty"` + } + // GetGroupHistoryResponse + GetGroupHistoryResponse { + Total int64 `json:"total"` + List []GroupHistory `json:"list"` + } + // GetGroupHistoryDetailRequest + GetGroupHistoryDetailRequest { + Id int64 `form:"id" validate:"required"` + } + // GetGroupHistoryDetailResponse + GetGroupHistoryDetailResponse { + GroupHistoryDetail + } + // PreviewUserNodesRequest + PreviewUserNodesRequest { + UserId int64 `form:"user_id" validate:"required"` + } + // PreviewUserNodesResponse + PreviewUserNodesResponse { + UserId int64 `json:"user_id"` + NodeGroups []NodeGroupItem `json:"node_groups"` + } + // NodeGroupItem + NodeGroupItem { + Id int64 `json:"id"` + Name string `json:"name"` + Nodes []Node `json:"nodes"` + } + // ExportGroupResultRequest + ExportGroupResultRequest { + HistoryId *int64 `form:"history_id,omitempty"` + } + // ===== Reset Groups ===== + // ResetGroupsRequest + ResetGroupsRequest { + Confirm bool `json:"confirm" validate:"required"` + } + // ===== 套餐分组映射 ===== + // SubscribeGroupMappingItem + SubscribeGroupMappingItem { + SubscribeName string `json:"subscribe_name"` + NodeGroupName string `json:"node_group_name"` + } + // GetSubscribeGroupMappingRequest + GetSubscribeGroupMappingRequest {} + // GetSubscribeGroupMappingResponse + GetSubscribeGroupMappingResponse { + List []SubscribeGroupMappingItem `json:"list"` + } +) + +@server ( + prefix: v1/admin/group + group: admin/group + jwt: JwtAuth + middleware: AuthMiddleware +) +service ppanel { + // ===== 节点组管理 ===== + @doc "Get node group list" + @handler GetNodeGroupList + get /node/list (GetNodeGroupListRequest) returns (GetNodeGroupListResponse) + + @doc "Create node group" + @handler CreateNodeGroup + post /node (CreateNodeGroupRequest) + + @doc "Update node group" + @handler UpdateNodeGroup + put /node (UpdateNodeGroupRequest) + + @doc "Delete node group" + @handler DeleteNodeGroup + delete /node (DeleteNodeGroupRequest) + + // ===== 分组配置管理 ===== + @doc "Get group config" + @handler GetGroupConfig + get /config (GetGroupConfigRequest) returns (GetGroupConfigResponse) + + @doc "Update group config" + @handler UpdateGroupConfig + put /config (UpdateGroupConfigRequest) + + // ===== 分组操作 ===== + @doc "Recalculate group" + @handler RecalculateGroup + post /recalculate (RecalculateGroupRequest) + + @doc "Get recalculation status" + @handler GetRecalculationStatus + get /recalculation/status returns (RecalculationState) + + @doc "Get group history" + @handler GetGroupHistory + get /history (GetGroupHistoryRequest) returns (GetGroupHistoryResponse) + + @doc "Export group result" + @handler ExportGroupResult + get /export (ExportGroupResultRequest) + + // Routes with query parameters + @doc "Get group history detail" + @handler GetGroupHistoryDetail + get /history/detail (GetGroupHistoryDetailRequest) returns (GetGroupHistoryDetailResponse) + + @doc "Preview user nodes" + @handler PreviewUserNodes + get /preview (PreviewUserNodesRequest) returns (PreviewUserNodesResponse) + + @doc "Reset all groups" + @handler ResetGroups + post /reset (ResetGroupsRequest) + + @doc "Get subscribe group mapping" + @handler GetSubscribeGroupMapping + get /subscribe/mapping (GetSubscribeGroupMappingRequest) returns (GetSubscribeGroupMappingResponse) +} + diff --git a/apis/admin/redemption.api b/apis/admin/redemption.api index ca5632c..ff2240b 100644 --- a/apis/admin/redemption.api +++ b/apis/admin/redemption.api @@ -27,8 +27,8 @@ type ( Status int64 `json:"status,omitempty" validate:"omitempty,oneof=0 1"` } ToggleRedemptionCodeStatusRequest { - Id int64 `json:"id" validate:"required"` - Status int64 `json:"status" validate:"oneof=0 1"` + Id int64 `json:"id" validate:"required"` + Status int64 `json:"status" validate:"oneof=0 1"` } DeleteRedemptionCodeRequest { Id int64 `json:"id" validate:"required"` @@ -93,3 +93,4 @@ service ppanel { @handler GetRedemptionRecordList get /record/list (GetRedemptionRecordListRequest) returns (GetRedemptionRecordListResponse) } + diff --git a/apis/admin/server.api b/apis/admin/server.api index 87d1f7e..5f5479d 100644 --- a/apis/admin/server.api +++ b/apis/admin/server.api @@ -80,36 +80,40 @@ type ( Protocols []Protocol `json:"protocols"` } Node { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` - Sort int `json:"sort,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + Sort int `json:"sort,omitempty"` + NodeGroupId int64 `json:"node_group_id,omitempty"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } CreateNodeRequest { - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } UpdateNodeRequest { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } ToggleNodeStatusRequest { Id int64 `json:"id"` @@ -119,9 +123,10 @@ type ( Id int64 `json:"id"` } FilterNodeListRequest { - Page int `form:"page"` - Size int `form:"size"` - Search string `form:"search,omitempty"` + Page int `form:"page"` + Size int `form:"size"` + Search string `form:"search,omitempty"` + NodeGroupId *int64 `form:"node_group_id,omitempty"` } FilterNodeListResponse { Total int64 `json:"total"` diff --git a/apis/admin/subscribe.api b/apis/admin/subscribe.api index a832b3a..881f021 100644 --- a/apis/admin/subscribe.api +++ b/apis/admin/subscribe.api @@ -48,6 +48,8 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -55,6 +57,7 @@ type ( ResetCycle int64 `json:"reset_cycle"` RenewalReset *bool `json:"renewal_reset"` ShowOriginalPrice bool `json:"show_original_price"` + AutoCreateGroup bool `json:"auto_create_group"` } UpdateSubscribeRequest { Id int64 `json:"id" validate:"required"` @@ -72,6 +75,8 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` @@ -85,10 +90,11 @@ type ( Sort []SortItem `json:"sort"` } GetSubscribeListRequest { - Page int64 `form:"page" validate:"required"` - Size int64 `form:"size" validate:"required"` - Language string `form:"language,omitempty"` - Search string `form:"search,omitempty"` + Page int64 `form:"page" validate:"required"` + Size int64 `form:"size" validate:"required"` + Language string `form:"language,omitempty"` + Search string `form:"search,omitempty"` + NodeGroupId int64 `form:"node_group_id,omitempty"` } SubscribeItem { Subscribe diff --git a/apis/admin/user.api b/apis/admin/user.api index 669ec90..bbc583d 100644 --- a/apis/admin/user.api +++ b/apis/admin/user.api @@ -78,6 +78,8 @@ type ( OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` ResetTime int64 `json:"reset_time"` diff --git a/apis/public/redemption.api b/apis/public/redemption.api index 8bbb3ae..787fa35 100644 --- a/apis/public/redemption.api +++ b/apis/public/redemption.api @@ -30,3 +30,4 @@ service ppanel { @handler RedeemCode post / (RedeemCodeRequest) returns (RedeemCodeResponse) } + diff --git a/apis/public/subscribe.api b/apis/public/subscribe.api index 4c0d2aa..65c2f92 100644 --- a/apis/public/subscribe.api +++ b/apis/public/subscribe.api @@ -14,48 +14,45 @@ type ( QuerySubscribeListRequest { Language string `form:"language"` } - - QueryUserSubscribeNodeListResponse { - List []UserSubscribeInfo `json:"list"` - } - - UserSubscribeInfo { - Id int64 `json:"id"` - UserId int64 `json:"user_id"` - OrderId int64 `json:"order_id"` - SubscribeId int64 `json:"subscribe_id"` - StartTime int64 `json:"start_time"` - ExpireTime int64 `json:"expire_time"` - FinishedAt int64 `json:"finished_at"` - ResetTime int64 `json:"reset_time"` - Traffic int64 `json:"traffic"` - Download int64 `json:"download"` - Upload int64 `json:"upload"` - Token string `json:"token"` - Status uint8 `json:"status"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - IsTryOut bool `json:"is_try_out"` - Nodes []*UserSubscribeNodeInfo `json:"nodes"` - } - - UserSubscribeNodeInfo{ - Id int64 `json:"id"` - Name string `json:"name"` - Uuid string `json:"uuid"` - Protocol string `json:"protocol"` - Protocols string `json:"protocols"` - Port uint16 `json:"port"` - Address string `json:"address"` - Tags []string `json:"tags"` - Country string `json:"country"` - City string `json:"city"` - Longitude string `json:"longitude"` - Latitude string `json:"latitude"` - LatitudeCenter string `json:"latitude_center"` - LongitudeCenter string `json:"longitude_center"` - CreatedAt int64 `json:"created_at"` - } + QueryUserSubscribeNodeListResponse { + List []UserSubscribeInfo `json:"list"` + } + UserSubscribeInfo { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + OrderId int64 `json:"order_id"` + SubscribeId int64 `json:"subscribe_id"` + StartTime int64 `json:"start_time"` + ExpireTime int64 `json:"expire_time"` + FinishedAt int64 `json:"finished_at"` + ResetTime int64 `json:"reset_time"` + Traffic int64 `json:"traffic"` + Download int64 `json:"download"` + Upload int64 `json:"upload"` + Token string `json:"token"` + Status uint8 `json:"status"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + IsTryOut bool `json:"is_try_out"` + Nodes []*UserSubscribeNodeInfo `json:"nodes"` + } + UserSubscribeNodeInfo { + Id int64 `json:"id"` + Name string `json:"name"` + Uuid string `json:"uuid"` + Protocol string `json:"protocol"` + Protocols string `json:"protocols"` + Port uint16 `json:"port"` + Address string `json:"address"` + Tags []string `json:"tags"` + Country string `json:"country"` + City string `json:"city"` + Longitude string `json:"longitude"` + Latitude string `json:"latitude"` + LatitudeCenter string `json:"latitude_center"` + LongitudeCenter string `json:"longitude_center"` + CreatedAt int64 `json:"created_at"` + } ) @server ( @@ -68,8 +65,8 @@ service ppanel { @handler QuerySubscribeList get /list (QuerySubscribeListRequest) returns (QuerySubscribeListResponse) - @doc "Get user subscribe node info" - @handler QueryUserSubscribeNodeList - get /node/list returns (QueryUserSubscribeNodeListResponse) + @doc "Get user subscribe node info" + @handler QueryUserSubscribeNodeList + get /node/list returns (QueryUserSubscribeNodeListResponse) } diff --git a/apis/public/user.api b/apis/public/user.api index 820a160..a6eb50f 100644 --- a/apis/public/user.api +++ b/apis/public/user.api @@ -66,7 +66,6 @@ type ( UnbindOAuthRequest { Method string `json:"method"` } - GetLoginLogRequest { Page int `form:"page"` Size int `form:"size"` @@ -95,21 +94,17 @@ type ( Email string `json:"email" validate:"required"` Code string `json:"code" validate:"required"` } - - GetDeviceListResponse { - List []UserDevice `json:"list"` - Total int64 `json:"total"` - } - - UnbindDeviceRequest { - Id int64 `json:"id" validate:"required"` - } - + GetDeviceListResponse { + List []UserDevice `json:"list"` + Total int64 `json:"total"` + } + UnbindDeviceRequest { + Id int64 `json:"id" validate:"required"` + } UpdateUserSubscribeNoteRequest { UserSubscribeId int64 `json:"user_subscribe_id" validate:"required"` Note string `json:"note" validate:"max=500"` } - UpdateUserRulesRequest { Rules []string `json:"rules" validate:"required"` } @@ -135,23 +130,20 @@ type ( List []WithdrawalLog `json:"list"` Total int64 `json:"total"` } - - - GetDeviceOnlineStatsResponse { - WeeklyStats []WeeklyStat `json:"weekly_stats"` - ConnectionRecords ConnectionRecords `json:"connection_records"` - } - - WeeklyStat { - Day int `json:"day"` - DayName string `json:"day_name"` - Hours float64 `json:"hours"` - } - ConnectionRecords { - CurrentContinuousDays int64 `json:"current_continuous_days"` - HistoryContinuousDays int64 `json:"history_continuous_days"` - LongestSingleConnection int64 `json:"longest_single_connection"` - } + GetDeviceOnlineStatsResponse { + WeeklyStats []WeeklyStat `json:"weekly_stats"` + ConnectionRecords ConnectionRecords `json:"connection_records"` + } + WeeklyStat { + Day int `json:"day"` + DayName string `json:"day_name"` + Hours float64 `json:"hours"` + } + ConnectionRecords { + CurrentContinuousDays int64 `json:"current_continuous_days"` + HistoryContinuousDays int64 `json:"history_continuous_days"` + LongestSingleConnection int64 `json:"longest_single_connection"` + } ) @server ( @@ -248,9 +240,9 @@ service ppanel { @handler UpdateBindEmail put /bind_email (UpdateBindEmailRequest) - @doc "Get Device List" - @handler GetDeviceList - get /devices returns (GetDeviceListResponse) + @doc "Get Device List" + @handler GetDeviceList + get /devices returns (GetDeviceListResponse) @doc "Unbind Device" @handler UnbindDevice @@ -272,23 +264,23 @@ service ppanel { @handler QueryWithdrawalLog get /withdrawal_log (QueryWithdrawalLogListRequest) returns (QueryWithdrawalLogListResponse) - @doc "Device Online Statistics" - @handler DeviceOnlineStatistics - get /device_online_statistics returns (GetDeviceOnlineStatsResponse) - - @doc "Delete Current User Account" - @handler DeleteCurrentUserAccount - delete /current_user_account + @doc "Device Online Statistics" + @handler DeviceOnlineStatistics + get /device_online_statistics returns (GetDeviceOnlineStatsResponse) + @doc "Delete Current User Account" + @handler DeleteCurrentUserAccount + delete /current_user_account } -@server( - prefix: v1/public/user - group: public/user/ws - middleware: AuthMiddleware + +@server ( + prefix: v1/public/user + group: public/user/ws + middleware: AuthMiddleware ) - service ppanel { - @doc "Webosocket Device Connect" - @handler DeviceWsConnect - get /device_ws_connect + @doc "Webosocket Device Connect" + @handler DeviceWsConnect + get /device_ws_connect } + diff --git a/apis/types.api b/apis/types.api index 1fc1725..e5ffbd1 100644 --- a/apis/types.api +++ b/apis/types.api @@ -206,7 +206,7 @@ type ( CurrencySymbol string `json:"currency_symbol"` } SubscribeDiscount { - Quantity int64 `json:"quantity"` + Quantity int64 `json:"quantity"` Discount float64 `json:"discount"` } Subscribe { @@ -225,6 +225,8 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -486,6 +488,8 @@ type ( OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` FinishedAt int64 `json:"finished_at"` @@ -697,7 +701,6 @@ type ( List []SubscribeGroup `json:"list"` Total int64 `json:"total"` } - GetUserSubscribeTrafficLogsRequest { Page int `form:"page"` Size int `form:"size"` @@ -874,5 +877,38 @@ type ( ResetUserSubscribeTokenRequest { UserSubscribeId int64 `json:"user_subscribe_id"` } + // ===== 分组功能类型定义 ===== + // NodeGroup 节点组 + NodeGroup { + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + } + // GroupHistory 分组历史记录 + GroupHistory { + Id int64 `json:"id"` + GroupMode string `json:"group_mode"` + TriggerType string `json:"trigger_type"` + TotalUsers int `json:"total_users"` + SuccessCount int `json:"success_count"` + FailedCount int `json:"failed_count"` + StartTime *int64 `json:"start_time,omitempty"` + EndTime *int64 `json:"end_time,omitempty"` + Operator string `json:"operator,omitempty"` + ErrorLog string `json:"error_log,omitempty"` + CreatedAt int64 `json:"created_at"` + } + // GroupHistoryDetail 分组历史详情 + GroupHistoryDetail { + GroupHistory + ConfigSnapshot map[string]interface{} `json:"config_snapshot,omitempty"` + } ) diff --git a/initialize/migrate/database/02131_add_groups.down.sql b/initialize/migrate/database/02131_add_groups.down.sql new file mode 100644 index 0000000..6765acc --- /dev/null +++ b/initialize/migrate/database/02131_add_groups.down.sql @@ -0,0 +1,28 @@ +-- Purpose: Rollback node group management tables +-- Author: Tension +-- Date: 2025-02-23 +-- Updated: 2025-03-06 + +-- ===== Remove system configuration entries ===== +DELETE FROM `system` WHERE `category` = 'group' AND `key` IN ('enabled', 'mode', 'auto_create_group'); + +-- ===== Remove columns and indexes from subscribe table ===== +ALTER TABLE `subscribe` DROP INDEX IF EXISTS `idx_node_group_id`; +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `node_group_id`; +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `node_group_ids`; + +-- ===== Remove columns and indexes from user_subscribe table ===== +ALTER TABLE `user_subscribe` DROP INDEX IF EXISTS `idx_node_group_id`; +ALTER TABLE `user_subscribe` DROP COLUMN IF EXISTS `node_group_id`; + +-- ===== Remove columns and indexes from nodes table ===== +ALTER TABLE `nodes` DROP COLUMN IF EXISTS `node_group_ids`; + +-- ===== Drop group_history_detail table ===== +DROP TABLE IF EXISTS `group_history_detail`; + +-- ===== Drop group_history table ===== +DROP TABLE IF EXISTS `group_history`; + +-- ===== Drop node_group table ===== +DROP TABLE IF EXISTS `node_group`; diff --git a/initialize/migrate/database/02131_add_groups.up.sql b/initialize/migrate/database/02131_add_groups.up.sql new file mode 100644 index 0000000..a4150ce --- /dev/null +++ b/initialize/migrate/database/02131_add_groups.up.sql @@ -0,0 +1,130 @@ +-- Purpose: Add node group management tables with multi-group support +-- Author: Tension +-- Date: 2025-02-23 +-- Updated: 2025-03-06 + +-- ===== Create node_group table ===== +DROP TABLE IF EXISTS `node_group`; +CREATE TABLE IF NOT EXISTS `node_group` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Name', + `description` varchar(500) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT 'Group Description', + `sort` int NOT NULL DEFAULT '0' COMMENT 'Sort Order', + `for_calculation` tinyint(1) NOT NULL DEFAULT 1 COMMENT 'For Grouping Calculation: 0=false, 1=true', + `min_traffic_gb` bigint DEFAULT 0 COMMENT 'Minimum Traffic (GB) for this node group', + `max_traffic_gb` bigint DEFAULT 0 COMMENT 'Maximum Traffic (GB) for this node group', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + `updated_at` datetime(3) DEFAULT NULL COMMENT 'Update Time', + PRIMARY KEY (`id`), + KEY `idx_sort` (`sort`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Node Groups'; + +-- ===== Create group_history table ===== +DROP TABLE IF EXISTS `group_history`; +CREATE TABLE IF NOT EXISTS `group_history` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `group_mode` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Group Mode: average/subscribe/traffic', + `trigger_type` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Trigger Type: manual/auto/schedule', + `state` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'State: pending/running/completed/failed', + `total_users` int NOT NULL DEFAULT '0' COMMENT 'Total Users', + `success_count` int NOT NULL DEFAULT '0' COMMENT 'Success Count', + `failed_count` int NOT NULL DEFAULT '0' COMMENT 'Failed Count', + `start_time` datetime(3) DEFAULT NULL COMMENT 'Start Time', + `end_time` datetime(3) DEFAULT NULL COMMENT 'End Time', + `operator` varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT 'Operator', + `error_message` text CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci COMMENT 'Error Message', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + PRIMARY KEY (`id`), + KEY `idx_group_mode` (`group_mode`), + KEY `idx_trigger_type` (`trigger_type`), + KEY `idx_state` (`state`), + KEY `idx_created_at` (`created_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Group Calculation History'; + +-- ===== Create group_history_detail table ===== +-- Note: user_group_id column removed, using user_data JSON field instead +DROP TABLE IF EXISTS `group_history_detail`; +CREATE TABLE IF NOT EXISTS `group_history_detail` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `history_id` bigint NOT NULL COMMENT 'History ID', + `node_group_id` bigint NOT NULL COMMENT 'Node Group ID', + `user_count` int NOT NULL DEFAULT '0' COMMENT 'User Count', + `node_count` int NOT NULL DEFAULT '0' COMMENT 'Node Count', + `user_data` TEXT COMMENT 'User data JSON (id and email/phone)', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + PRIMARY KEY (`id`), + KEY `idx_history_id` (`history_id`), + KEY `idx_node_group_id` (`node_group_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Group History Details'; + +-- ===== Add columns to nodes table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'nodes' AND COLUMN_NAME = 'node_group_ids'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `nodes` ADD COLUMN `node_group_ids` JSON COMMENT ''Node Group IDs (JSON array, multiple groups)''', + 'SELECT ''Column node_group_ids already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add node_group_id column to user_subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND COLUMN_NAME = 'node_group_id'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `user_subscribe` ADD COLUMN `node_group_id` bigint NOT NULL DEFAULT 0 COMMENT ''Node Group ID (single ID)''', + 'SELECT ''Column node_group_id already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add index for user_subscribe.node_group_id ===== +SET @index_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND INDEX_NAME = 'idx_node_group_id'); +SET @sql = IF(@index_exists = 0, + 'ALTER TABLE `user_subscribe` ADD INDEX `idx_node_group_id` (`node_group_id`)', + 'SELECT ''Index idx_node_group_id already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add group_locked column to user_subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND COLUMN_NAME = 'group_locked'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `user_subscribe` ADD COLUMN `group_locked` tinyint(1) NOT NULL DEFAULT 0 COMMENT ''Group Locked''', + 'SELECT ''Column group_locked already exists in user_subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add columns to subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND COLUMN_NAME = 'node_group_ids'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `node_group_ids` JSON COMMENT ''Node Group IDs (JSON array, multiple groups)''', + 'SELECT ''Column node_group_ids already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add default node_group_id column to subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND COLUMN_NAME = 'node_group_id'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `node_group_id` bigint NOT NULL DEFAULT 0 COMMENT ''Default Node Group ID (single ID)''', + 'SELECT ''Column node_group_id already exists in subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add index for subscribe.node_group_id ===== +SET @index_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND INDEX_NAME = 'idx_node_group_id'); +SET @sql = IF(@index_exists = 0, + 'ALTER TABLE `subscribe` ADD INDEX `idx_node_group_id` (`node_group_id`)', + 'SELECT ''Index idx_node_group_id already exists in subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Insert system configuration entries ===== +INSERT INTO `system` (`category`, `key`, `value`, `desc`) VALUES + ('group', 'enabled', 'false', 'Group Management Enabled'), + ('group', 'mode', 'average', 'Group Mode: average/subscribe/traffic'), + ('group', 'auto_create_group', 'false', 'Auto-create user group when creating subscribe product') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); diff --git a/internal/handler/admin/group/createNodeGroupHandler.go b/internal/handler/admin/group/createNodeGroupHandler.go new file mode 100644 index 0000000..eaba8cf --- /dev/null +++ b/internal/handler/admin/group/createNodeGroupHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Create node group +func CreateNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.CreateNodeGroupRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewCreateNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.CreateNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/deleteNodeGroupHandler.go b/internal/handler/admin/group/deleteNodeGroupHandler.go new file mode 100644 index 0000000..b93120d --- /dev/null +++ b/internal/handler/admin/group/deleteNodeGroupHandler.go @@ -0,0 +1,29 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Delete node group +func DeleteNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.DeleteNodeGroupRequest + if err := c.ShouldBind(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewDeleteNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.DeleteNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/exportGroupResultHandler.go b/internal/handler/admin/group/exportGroupResultHandler.go new file mode 100644 index 0000000..69b065f --- /dev/null +++ b/internal/handler/admin/group/exportGroupResultHandler.go @@ -0,0 +1,36 @@ +package group + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Export group result +func ExportGroupResultHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.ExportGroupResultRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewExportGroupResultLogic(c.Request.Context(), svcCtx) + data, filename, err := l.ExportGroupResult(&req) + if err != nil { + result.HttpResult(c, nil, err) + return + } + + // 设置响应头 + c.Header("Content-Type", "text/csv") + c.Header("Content-Disposition", "attachment; filename="+filename) + c.Data(http.StatusOK, "text/csv", data) + } +} diff --git a/internal/handler/admin/group/getGroupConfigHandler.go b/internal/handler/admin/group/getGroupConfigHandler.go new file mode 100644 index 0000000..ef24311 --- /dev/null +++ b/internal/handler/admin/group/getGroupConfigHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group config +func GetGroupConfigHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupConfigRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupConfigLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupConfig(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getGroupHistoryDetailHandler.go b/internal/handler/admin/group/getGroupHistoryDetailHandler.go new file mode 100644 index 0000000..fa58f3c --- /dev/null +++ b/internal/handler/admin/group/getGroupHistoryDetailHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group history detail +func GetGroupHistoryDetailHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupHistoryDetailRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupHistoryDetailLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupHistoryDetail(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getGroupHistoryHandler.go b/internal/handler/admin/group/getGroupHistoryHandler.go new file mode 100644 index 0000000..b6b5490 --- /dev/null +++ b/internal/handler/admin/group/getGroupHistoryHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group history +func GetGroupHistoryHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupHistoryRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupHistoryLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupHistory(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getNodeGroupListHandler.go b/internal/handler/admin/group/getNodeGroupListHandler.go new file mode 100644 index 0000000..501138f --- /dev/null +++ b/internal/handler/admin/group/getNodeGroupListHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get node group list +func GetNodeGroupListHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetNodeGroupListRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetNodeGroupListLogic(c.Request.Context(), svcCtx) + resp, err := l.GetNodeGroupList(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getRecalculationStatusHandler.go b/internal/handler/admin/group/getRecalculationStatusHandler.go new file mode 100644 index 0000000..e9b76b8 --- /dev/null +++ b/internal/handler/admin/group/getRecalculationStatusHandler.go @@ -0,0 +1,18 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Get recalculation status +func GetRecalculationStatusHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := group.NewGetRecalculationStatusLogic(c.Request.Context(), svcCtx) + resp, err := l.GetRecalculationStatus() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getSubscribeGroupMappingHandler.go b/internal/handler/admin/group/getSubscribeGroupMappingHandler.go new file mode 100644 index 0000000..4da798c --- /dev/null +++ b/internal/handler/admin/group/getSubscribeGroupMappingHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get subscribe group mapping +func GetSubscribeGroupMappingHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetSubscribeGroupMappingRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetSubscribeGroupMappingLogic(c.Request.Context(), svcCtx) + resp, err := l.GetSubscribeGroupMapping(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/previewUserNodesHandler.go b/internal/handler/admin/group/previewUserNodesHandler.go new file mode 100644 index 0000000..da3560d --- /dev/null +++ b/internal/handler/admin/group/previewUserNodesHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Preview user nodes +func PreviewUserNodesHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.PreviewUserNodesRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewPreviewUserNodesLogic(c.Request.Context(), svcCtx) + resp, err := l.PreviewUserNodes(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/recalculateGroupHandler.go b/internal/handler/admin/group/recalculateGroupHandler.go new file mode 100644 index 0000000..848363d --- /dev/null +++ b/internal/handler/admin/group/recalculateGroupHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Recalculate group +func RecalculateGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.RecalculateGroupRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewRecalculateGroupLogic(c.Request.Context(), svcCtx) + err := l.RecalculateGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/resetGroupsHandler.go b/internal/handler/admin/group/resetGroupsHandler.go new file mode 100644 index 0000000..e0af912 --- /dev/null +++ b/internal/handler/admin/group/resetGroupsHandler.go @@ -0,0 +1,17 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Reset all groups +func ResetGroupsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + l := group.NewResetGroupsLogic(c.Request.Context(), svcCtx) + err := l.ResetGroups() + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/updateGroupConfigHandler.go b/internal/handler/admin/group/updateGroupConfigHandler.go new file mode 100644 index 0000000..6f2ea1c --- /dev/null +++ b/internal/handler/admin/group/updateGroupConfigHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Update group config +func UpdateGroupConfigHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UpdateGroupConfigRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewUpdateGroupConfigLogic(c.Request.Context(), svcCtx) + err := l.UpdateGroupConfig(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/updateNodeGroupHandler.go b/internal/handler/admin/group/updateNodeGroupHandler.go new file mode 100644 index 0000000..e9f9058 --- /dev/null +++ b/internal/handler/admin/group/updateNodeGroupHandler.go @@ -0,0 +1,33 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Update node group +func UpdateNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UpdateNodeGroupRequest + if err := c.ShouldBindUri(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + if err := c.ShouldBind(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewUpdateNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.UpdateNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/routes.go b/internal/handler/routes.go index ff26e9f..d811058 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -12,6 +12,7 @@ import ( adminConsole "github.com/perfect-panel/server/internal/handler/admin/console" adminCoupon "github.com/perfect-panel/server/internal/handler/admin/coupon" adminDocument "github.com/perfect-panel/server/internal/handler/admin/document" + adminGroup "github.com/perfect-panel/server/internal/handler/admin/group" adminLog "github.com/perfect-panel/server/internal/handler/admin/log" adminMarketing "github.com/perfect-panel/server/internal/handler/admin/marketing" adminOrder "github.com/perfect-panel/server/internal/handler/admin/order" @@ -188,6 +189,53 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { adminDocumentGroupRouter.GET("/list", adminDocument.GetDocumentListHandler(serverCtx)) } + adminGroupGroupRouter := router.Group("/v1/admin/group") + adminGroupGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + + { + // Get group config + adminGroupGroupRouter.GET("/config", adminGroup.GetGroupConfigHandler(serverCtx)) + + // Update group config + adminGroupGroupRouter.PUT("/config", adminGroup.UpdateGroupConfigHandler(serverCtx)) + + // Export group result + adminGroupGroupRouter.GET("/export", adminGroup.ExportGroupResultHandler(serverCtx)) + + // Get group history + adminGroupGroupRouter.GET("/history", adminGroup.GetGroupHistoryHandler(serverCtx)) + + // Get group history detail + adminGroupGroupRouter.GET("/history/detail", adminGroup.GetGroupHistoryDetailHandler(serverCtx)) + + // Create node group + adminGroupGroupRouter.POST("/node", adminGroup.CreateNodeGroupHandler(serverCtx)) + + // Update node group + adminGroupGroupRouter.PUT("/node", adminGroup.UpdateNodeGroupHandler(serverCtx)) + + // Delete node group + adminGroupGroupRouter.DELETE("/node", adminGroup.DeleteNodeGroupHandler(serverCtx)) + + // Get node group list + adminGroupGroupRouter.GET("/node/list", adminGroup.GetNodeGroupListHandler(serverCtx)) + + // Preview user nodes + adminGroupGroupRouter.GET("/preview", adminGroup.PreviewUserNodesHandler(serverCtx)) + + // Recalculate group + adminGroupGroupRouter.POST("/recalculate", adminGroup.RecalculateGroupHandler(serverCtx)) + + // Get recalculation status + adminGroupGroupRouter.GET("/recalculation/status", adminGroup.GetRecalculationStatusHandler(serverCtx)) + + // Reset all groups + adminGroupGroupRouter.POST("/reset", adminGroup.ResetGroupsHandler(serverCtx)) + + // Get subscribe group mapping + adminGroupGroupRouter.GET("/subscribe/mapping", adminGroup.GetSubscribeGroupMappingHandler(serverCtx)) + } + adminLogGroupRouter := router.Group("/v1/admin/log") adminLogGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) diff --git a/internal/logic/admin/group/createNodeGroupLogic.go b/internal/logic/admin/group/createNodeGroupLogic.go new file mode 100644 index 0000000..2d361d6 --- /dev/null +++ b/internal/logic/admin/group/createNodeGroupLogic.go @@ -0,0 +1,46 @@ +package group + +import ( + "context" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type CreateNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewCreateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateNodeGroupLogic { + return &CreateNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *CreateNodeGroupLogic) CreateNodeGroup(req *types.CreateNodeGroupRequest) error { + // 创建节点组 + nodeGroup := &group.NodeGroup{ + Name: req.Name, + Description: req.Description, + Sort: req.Sort, + ForCalculation: req.ForCalculation, + MinTrafficGB: req.MinTrafficGB, + MaxTrafficGB: req.MaxTrafficGB, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + if err := l.svcCtx.DB.Create(nodeGroup).Error; err != nil { + logger.Errorf("failed to create node group: %v", err) + return err + } + + logger.Infof("created node group: node_group_id=%d", nodeGroup.Id) + return nil +} diff --git a/internal/logic/admin/group/deleteNodeGroupLogic.go b/internal/logic/admin/group/deleteNodeGroupLogic.go new file mode 100644 index 0000000..947dc49 --- /dev/null +++ b/internal/logic/admin/group/deleteNodeGroupLogic.go @@ -0,0 +1,61 @@ +package group + +import ( + "context" + "errors" + "fmt" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type DeleteNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewDeleteNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *DeleteNodeGroupLogic { + return &DeleteNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *DeleteNodeGroupLogic) DeleteNodeGroup(req *types.DeleteNodeGroupRequest) error { + // 查询节点组信息 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("node group not found") + } + logger.Errorf("failed to find node group: %v", err) + return err + } + + // 检查是否有关联节点 + var nodeCount int64 + if err := l.svcCtx.DB.Table("nodes").Where("node_group_id = ?", nodeGroup.Id).Count(&nodeCount).Error; err != nil { + logger.Errorf("failed to count nodes in group: %v", err) + return err + } + if nodeCount > 0 { + return fmt.Errorf("cannot delete group with %d associated nodes, please migrate nodes first", nodeCount) + } + + // 使用 GORM Transaction 删除节点组 + return l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 删除节点组 + if err := tx.Where("id = ?", req.Id).Delete(&group.NodeGroup{}).Error; err != nil { + logger.Errorf("failed to delete node group: %v", err) + return err // 自动回滚 + } + + logger.Infof("deleted node group: id=%d", nodeGroup.Id) + return nil // 自动提交 + }) +} diff --git a/internal/logic/admin/group/exportGroupResultLogic.go b/internal/logic/admin/group/exportGroupResultLogic.go new file mode 100644 index 0000000..ef2183f --- /dev/null +++ b/internal/logic/admin/group/exportGroupResultLogic.go @@ -0,0 +1,128 @@ +package group + +import ( + "bytes" + "context" + "encoding/csv" + "fmt" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type ExportGroupResultLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewExportGroupResultLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ExportGroupResultLogic { + return &ExportGroupResultLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +// ExportGroupResult 导出分组结果为 CSV +// 返回:CSV 数据(字节切片)、文件名、错误 +func (l *ExportGroupResultLogic) ExportGroupResult(req *types.ExportGroupResultRequest) ([]byte, string, error) { + var records [][]string + + // CSV 表头 + records = append(records, []string{"用户ID", "节点组ID", "节点组名称"}) + + if req.HistoryId != nil { + // 导出指定历史的详细结果 + // 1. 查询分组历史详情 + var details []group.GroupHistoryDetail + if err := l.svcCtx.DB.Where("history_id = ?", *req.HistoryId).Find(&details).Error; err != nil { + logger.Errorf("failed to get group history details: %v", err) + return nil, "", err + } + + // 2. 为每个组生成记录 + for _, detail := range details { + // 从 UserData JSON 解析用户信息 + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + var users []UserInfo + if err := l.svcCtx.DB.Raw("SELECT * FROM JSON_ARRAY(?)", detail.UserData).Scan(&users).Error; err != nil { + // 如果解析失败,尝试用标准 JSON 解析 + logger.Errorf("failed to parse user data: %v", err) + continue + } + + // 查询节点组名称 + var nodeGroup group.NodeGroup + l.svcCtx.DB.Where("id = ?", detail.NodeGroupId).First(&nodeGroup) + + // 为每个用户生成记录 + for _, user := range users { + records = append(records, []string{ + fmt.Sprintf("%d", user.Id), + fmt.Sprintf("%d", nodeGroup.Id), + nodeGroup.Name, + }) + } + } + } else { + // 导出当前所有用户的分组情况 + type UserNodeGroupInfo struct { + Id int64 `json:"id"` + NodeGroupId int64 `json:"node_group_id"` + } + var userSubscribes []UserNodeGroupInfo + if err := l.svcCtx.DB.Table("user_subscribe"). + Select("DISTINCT user_id as id, node_group_id"). + Where("node_group_id > ?", 0). + Find(&userSubscribes).Error; err != nil { + logger.Errorf("failed to get users: %v", err) + return nil, "", err + } + + // 为每个用户生成记录 + for _, us := range userSubscribes { + // 查询节点组信息 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", us.NodeGroupId).First(&nodeGroup).Error; err != nil { + logger.Errorf("failed to find node group: %v", err) + // 跳过该用户 + continue + } + + records = append(records, []string{ + fmt.Sprintf("%d", us.Id), + fmt.Sprintf("%d", nodeGroup.Id), + nodeGroup.Name, + }) + } + } + + // 生成 CSV 数据 + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + writer.WriteAll(records) + writer.Flush() + + if err := writer.Error(); err != nil { + logger.Errorf("failed to write csv: %v", err) + return nil, "", err + } + + // 添加 UTF-8 BOM + bom := []byte{0xEF, 0xBB, 0xBF} + csvData := buf.Bytes() + result := make([]byte, 0, len(bom)+len(csvData)) + result = append(result, bom...) + result = append(result, csvData...) + + // 生成文件名 + filename := fmt.Sprintf("group_result_%d.csv", req.HistoryId) + + return result, filename, nil +} diff --git a/internal/logic/admin/group/getGroupConfigLogic.go b/internal/logic/admin/group/getGroupConfigLogic.go new file mode 100644 index 0000000..2aedb10 --- /dev/null +++ b/internal/logic/admin/group/getGroupConfigLogic.go @@ -0,0 +1,125 @@ +package group + +import ( + "context" + "encoding/json" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type GetGroupConfigLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get group config +func NewGetGroupConfigLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupConfigLogic { + return &GetGroupConfigLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupConfigLogic) GetGroupConfig(req *types.GetGroupConfigRequest) (resp *types.GetGroupConfigResponse, err error) { + // 读取基础配置 + var enabledConfig system.System + var modeConfig system.System + var averageConfig system.System + var subscribeConfig system.System + var trafficConfig system.System + + // 从 system_config 表读取配置 + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "enabled").First(&enabledConfig).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("failed to get group enabled config", logger.Field("error", err.Error())) + return nil, err + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "mode").First(&modeConfig).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("failed to get group mode config", logger.Field("error", err.Error())) + return nil, err + } + + // 读取 JSON 配置 + config := make(map[string]interface{}) + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "average_config").First(&averageConfig).Error; err == nil { + var averageCfg map[string]interface{} + if err := json.Unmarshal([]byte(averageConfig.Value), &averageCfg); err == nil { + config["average_config"] = averageCfg + } + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "subscribe_config").First(&subscribeConfig).Error; err == nil { + var subscribeCfg map[string]interface{} + if err := json.Unmarshal([]byte(subscribeConfig.Value), &subscribeCfg); err == nil { + config["subscribe_config"] = subscribeCfg + } + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "traffic_config").First(&trafficConfig).Error; err == nil { + var trafficCfg map[string]interface{} + if err := json.Unmarshal([]byte(trafficConfig.Value), &trafficCfg); err == nil { + config["traffic_config"] = trafficCfg + } + } + + // 解析基础配置 + enabled := enabledConfig.Value == "true" + mode := modeConfig.Value + if mode == "" { + mode = "average" // 默认模式 + } + + // 获取重算状态 + state, err := l.getRecalculationState() + if err != nil { + l.Errorw("failed to get recalculation state", logger.Field("error", err.Error())) + // 继续执行,不影响配置获取 + state = &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + } + } + + resp = &types.GetGroupConfigResponse{ + Enabled: enabled, + Mode: mode, + Config: config, + State: *state, + } + + return resp, nil +} + +// getRecalculationState 获取重算状态 +func (l *GetGroupConfigLogic) getRecalculationState() (*types.RecalculationState, error) { + var history group.GroupHistory + err := l.svcCtx.DB.Order("id desc").First(&history).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + }, nil + } + return nil, err + } + + state := &types.RecalculationState{ + State: history.State, + Progress: history.TotalUsers, + Total: history.TotalUsers, + } + + return state, nil +} diff --git a/internal/logic/admin/group/getGroupHistoryDetailLogic.go b/internal/logic/admin/group/getGroupHistoryDetailLogic.go new file mode 100644 index 0000000..d868d55 --- /dev/null +++ b/internal/logic/admin/group/getGroupHistoryDetailLogic.go @@ -0,0 +1,109 @@ +package group + +import ( + "context" + "encoding/json" + "errors" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type GetGroupHistoryDetailLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetGroupHistoryDetailLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupHistoryDetailLogic { + return &GetGroupHistoryDetailLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupHistoryDetailLogic) GetGroupHistoryDetail(req *types.GetGroupHistoryDetailRequest) (resp *types.GetGroupHistoryDetailResponse, err error) { + // 查询分组历史记录 + var history group.GroupHistory + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&history).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("group history not found") + } + logger.Errorf("failed to find group history: %v", err) + return nil, err + } + + // 查询分组历史详情 + var details []group.GroupHistoryDetail + if err := l.svcCtx.DB.Where("history_id = ?", req.Id).Find(&details).Error; err != nil { + logger.Errorf("failed to find group history details: %v", err) + return nil, err + } + + // 转换时间格式 + var startTime, endTime *int64 + if history.StartTime != nil { + t := history.StartTime.Unix() + startTime = &t + } + if history.EndTime != nil { + t := history.EndTime.Unix() + endTime = &t + } + + // 构建 GroupHistoryDetail + historyDetail := types.GroupHistoryDetail{ + GroupHistory: types.GroupHistory{ + Id: history.Id, + GroupMode: history.GroupMode, + TriggerType: history.TriggerType, + TotalUsers: history.TotalUsers, + SuccessCount: history.SuccessCount, + FailedCount: history.FailedCount, + StartTime: startTime, + EndTime: endTime, + ErrorLog: history.ErrorMessage, + CreatedAt: history.CreatedAt.Unix(), + }, + } + + // 如果有详情记录,构建 ConfigSnapshot + if len(details) > 0 { + configSnapshot := make(map[string]interface{}) + configSnapshot["group_details"] = details + + // 获取配置快照(从 system_config 读取) + var configValue string + if history.GroupMode == "average" { + l.svcCtx.DB.Table("system_config"). + Where("`key` = ?", "group.average_config"). + Select("value"). + Scan(&configValue) + } else if history.GroupMode == "traffic" { + l.svcCtx.DB.Table("system_config"). + Where("`key` = ?", "group.traffic_config"). + Select("value"). + Scan(&configValue) + } + + // 解析 JSON 配置 + if configValue != "" { + var config map[string]interface{} + if err := json.Unmarshal([]byte(configValue), &config); err == nil { + configSnapshot["config"] = config + } + } + + historyDetail.ConfigSnapshot = configSnapshot + } + + resp = &types.GetGroupHistoryDetailResponse{ + GroupHistoryDetail: historyDetail, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getGroupHistoryLogic.go b/internal/logic/admin/group/getGroupHistoryLogic.go new file mode 100644 index 0000000..6eee9c3 --- /dev/null +++ b/internal/logic/admin/group/getGroupHistoryLogic.go @@ -0,0 +1,87 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetGroupHistoryLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetGroupHistoryLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupHistoryLogic { + return &GetGroupHistoryLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupHistoryLogic) GetGroupHistory(req *types.GetGroupHistoryRequest) (resp *types.GetGroupHistoryResponse, err error) { + var histories []group.GroupHistory + var total int64 + + // 构建查询 + query := l.svcCtx.DB.Model(&group.GroupHistory{}) + + // 添加过滤条件 + if req.GroupMode != "" { + query = query.Where("group_mode = ?", req.GroupMode) + } + if req.TriggerType != "" { + query = query.Where("trigger_type = ?", req.TriggerType) + } + + // 获取总数 + if err := query.Count(&total).Error; err != nil { + logger.Errorf("failed to count group histories: %v", err) + return nil, err + } + + // 分页查询 + offset := (req.Page - 1) * req.Size + if err := query.Order("id DESC").Offset(offset).Limit(req.Size).Find(&histories).Error; err != nil { + logger.Errorf("failed to find group histories: %v", err) + return nil, err + } + + // 转换为响应格式 + var list []types.GroupHistory + for _, h := range histories { + var startTime, endTime *int64 + if h.StartTime != nil { + t := h.StartTime.Unix() + startTime = &t + } + if h.EndTime != nil { + t := h.EndTime.Unix() + endTime = &t + } + + list = append(list, types.GroupHistory{ + Id: h.Id, + GroupMode: h.GroupMode, + TriggerType: h.TriggerType, + TotalUsers: h.TotalUsers, + SuccessCount: h.SuccessCount, + FailedCount: h.FailedCount, + StartTime: startTime, + EndTime: endTime, + ErrorLog: h.ErrorMessage, + CreatedAt: h.CreatedAt.Unix(), + }) + } + + resp = &types.GetGroupHistoryResponse{ + Total: total, + List: list, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getNodeGroupListLogic.go b/internal/logic/admin/group/getNodeGroupListLogic.go new file mode 100644 index 0000000..abd1a6a --- /dev/null +++ b/internal/logic/admin/group/getNodeGroupListLogic.go @@ -0,0 +1,89 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetNodeGroupListLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetNodeGroupListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetNodeGroupListLogic { + return &GetNodeGroupListLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetNodeGroupListLogic) GetNodeGroupList(req *types.GetNodeGroupListRequest) (resp *types.GetNodeGroupListResponse, err error) { + var nodeGroups []group.NodeGroup + var total int64 + + // 构建查询 + query := l.svcCtx.DB.Model(&group.NodeGroup{}) + + // 获取总数 + if err := query.Count(&total).Error; err != nil { + logger.Errorf("failed to count node groups: %v", err) + return nil, err + } + + // 分页查询 + offset := (req.Page - 1) * req.Size + if err := query.Order("sort ASC").Offset(offset).Limit(req.Size).Find(&nodeGroups).Error; err != nil { + logger.Errorf("failed to find node groups: %v", err) + return nil, err + } + + // 转换为响应格式 + var list []types.NodeGroup + for _, ng := range nodeGroups { + // 统计该组的节点数 + var nodeCount int64 + l.svcCtx.DB.Table("nodes").Where("node_group_id = ?", ng.Id).Count(&nodeCount) + + // 处理指针类型的字段 + var forCalculation bool + if ng.ForCalculation != nil { + forCalculation = *ng.ForCalculation + } else { + forCalculation = true // 默认值 + } + + var minTrafficGB, maxTrafficGB int64 + if ng.MinTrafficGB != nil { + minTrafficGB = *ng.MinTrafficGB + } + if ng.MaxTrafficGB != nil { + maxTrafficGB = *ng.MaxTrafficGB + } + + list = append(list, types.NodeGroup{ + Id: ng.Id, + Name: ng.Name, + Description: ng.Description, + Sort: ng.Sort, + ForCalculation: forCalculation, + MinTrafficGB: minTrafficGB, + MaxTrafficGB: maxTrafficGB, + NodeCount: nodeCount, + CreatedAt: ng.CreatedAt.Unix(), + UpdatedAt: ng.UpdatedAt.Unix(), + }) + } + + resp = &types.GetNodeGroupListResponse{ + Total: total, + List: list, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getRecalculationStatusLogic.go b/internal/logic/admin/group/getRecalculationStatusLogic.go new file mode 100644 index 0000000..9a04f80 --- /dev/null +++ b/internal/logic/admin/group/getRecalculationStatusLogic.go @@ -0,0 +1,57 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type GetRecalculationStatusLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get recalculation status +func NewGetRecalculationStatusLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetRecalculationStatusLogic { + return &GetRecalculationStatusLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetRecalculationStatusLogic) GetRecalculationStatus() (resp *types.RecalculationState, err error) { + // 返回最近的一条 GroupHistory 记录 + var history group.GroupHistory + err = l.svcCtx.DB.Order("id desc").First(&history).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // 如果没有历史记录,返回空闲状态 + resp = &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + } + return resp, nil + } + l.Errorw("failed to get group history", logger.Field("error", err.Error())) + return nil, err + } + + // 转换为 RecalculationState 格式 + // Progress = 已处理的用户数(成功+失败),Total = 总用户数 + processedUsers := history.SuccessCount + history.FailedCount + resp = &types.RecalculationState{ + State: history.State, + Progress: processedUsers, + Total: history.TotalUsers, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getSubscribeGroupMappingLogic.go b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go new file mode 100644 index 0000000..fb3ed90 --- /dev/null +++ b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go @@ -0,0 +1,71 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetSubscribeGroupMappingLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get subscribe group mapping +func NewGetSubscribeGroupMappingLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetSubscribeGroupMappingLogic { + return &GetSubscribeGroupMappingLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetSubscribeGroupMappingLogic) GetSubscribeGroupMapping(req *types.GetSubscribeGroupMappingRequest) (resp *types.GetSubscribeGroupMappingResponse, err error) { + // 1. 查询所有订阅套餐 + var subscribes []subscribe.Subscribe + if err := l.svcCtx.DB.Table("subscribe").Find(&subscribes).Error; err != nil { + l.Errorw("[GetSubscribeGroupMapping] failed to query subscribes", logger.Field("error", err.Error())) + return nil, err + } + + // 2. 查询所有节点组 + var nodeGroups []group.NodeGroup + if err := l.svcCtx.DB.Table("node_group").Find(&nodeGroups).Error; err != nil { + l.Errorw("[GetSubscribeGroupMapping] failed to query node groups", logger.Field("error", err.Error())) + return nil, err + } + + // 创建 node_group_id -> node_group_name 的映射 + nodeGroupMap := make(map[int64]string) + for _, ng := range nodeGroups { + nodeGroupMap[ng.Id] = ng.Name + } + + // 3. 构建映射结果:套餐 -> 默认节点组(一对一) + var mappingList []types.SubscribeGroupMappingItem + + for _, sub := range subscribes { + // 获取套餐的默认节点组(node_group_ids 数组的第一个) + nodeGroupName := "" + if len(sub.NodeGroupIds) > 0 { + defaultNodeGroupId := sub.NodeGroupIds[0] + nodeGroupName = nodeGroupMap[defaultNodeGroupId] + } + + mappingList = append(mappingList, types.SubscribeGroupMappingItem{ + SubscribeName: sub.Name, + NodeGroupName: nodeGroupName, + }) + } + + resp = &types.GetSubscribeGroupMappingResponse{ + List: mappingList, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/previewUserNodesLogic.go b/internal/logic/admin/group/previewUserNodesLogic.go new file mode 100644 index 0000000..ba91f4e --- /dev/null +++ b/internal/logic/admin/group/previewUserNodesLogic.go @@ -0,0 +1,466 @@ +package group + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" +) + +type PreviewUserNodesLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewPreviewUserNodesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PreviewUserNodesLogic { + return &PreviewUserNodesLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequest) (resp *types.PreviewUserNodesResponse, err error) { + logger.Infof("[PreviewUserNodes] userId: %v", req.UserId) + + // 1. 查询用户的所有有效订阅(只查询可用状态:0-Pending, 1-Active) + type UserSubscribe struct { + Id int64 + UserId int64 + SubscribeId int64 + NodeGroupId int64 // 用户订阅的 node_group_id(单个ID) + } + var userSubscribes []UserSubscribe + err = l.svcCtx.DB.Table("user_subscribe"). + Select("id, user_id, subscribe_id, node_group_id"). + Where("user_id = ? AND status IN ?", req.UserId, []int8{0, 1}). + Find(&userSubscribes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get user subscribes: %v", err) + return nil, err + } + + if len(userSubscribes) == 0 { + logger.Infof("[PreviewUserNodes] no user subscribes found") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + logger.Infof("[PreviewUserNodes] found %v user subscribes", len(userSubscribes)) + + // 2. 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + // 收集所有订阅ID以便批量查询 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + // 批量查询订阅信息 + type SubscribeInfo struct { + Id int64 + NodeGroupId int64 + NodeGroupIds string // JSON string + } + var subscribeInfos []SubscribeInfo + err = l.svcCtx.DB.Table("subscribe"). + Select("id, node_group_id, node_group_ids"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get subscribe infos: %v", err) + return nil, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 按优先级获取每个用户订阅的 node_group_id + var allNodeGroupIds []int64 + for _, us := range userSubscribes { + nodeGroupId := int64(0) + + // 优先级1: user_subscribe.node_group_id + if us.NodeGroupId != 0 { + nodeGroupId = us.NodeGroupId + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using node_group_id=%d", us.Id, nodeGroupId) + } else { + // 优先级2: subscribe.node_group_id + subInfo, ok := subInfoMap[us.SubscribeId] + if ok { + if subInfo.NodeGroupId != 0 { + nodeGroupId = subInfo.NodeGroupId + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using subscribe.node_group_id=%d", us.Id, nodeGroupId) + } else if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "null" && subInfo.NodeGroupIds != "[]" { + // 优先级3: subscribe.node_group_ids[0] + var nodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &nodeGroupIds); err == nil && len(nodeGroupIds) > 0 { + nodeGroupId = nodeGroupIds[0] + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using subscribe.node_group_ids[0]=%d", us.Id, nodeGroupId) + } + } + } + } + + if nodeGroupId != 0 { + allNodeGroupIds = append(allNodeGroupIds, nodeGroupId) + } + } + + // 去重 + allNodeGroupIds = removeDuplicateInt64(allNodeGroupIds) + + logger.Infof("[PreviewUserNodes] collected node_group_ids with priority: %v", allNodeGroupIds) + + // 4. 判断分组功能是否启用 + var groupEnabled string + l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled) + + logger.Infof("[PreviewUserNodes] groupEnabled: %v", groupEnabled) + + isGroupEnabled := groupEnabled == "true" || groupEnabled == "1" + + var filteredNodes []node.Node + + if isGroupEnabled { + // === 启用分组功能:通过用户订阅的 node_group_id 查询节点 === + logger.Infof("[PreviewUserNodes] using group-based node filtering") + + if len(allNodeGroupIds) == 0 { + logger.Infof("[PreviewUserNodes] no node groups found in user subscribes") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + // 5. 查询所有启用的节点 + var dbNodes []node.Node + err = l.svcCtx.DB.Table("nodes"). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err + } + + // 6. 过滤出包含至少一个匹配节点组的节点 + // node_group_ids 为空 = 公共节点,所有人可见 + // node_group_ids 与订阅的 node_group_id 匹配 = 该节点可见 + for _, n := range dbNodes { + // 公共节点(node_group_ids 为空),所有人可见 + if len(n.NodeGroupIds) == 0 { + filteredNodes = append(filteredNodes, n) + continue + } + + // 检查节点的 node_group_ids 是否与订阅的 node_group_id 有交集 + for _, nodeGroupId := range n.NodeGroupIds { + if tool.Contains(allNodeGroupIds, nodeGroupId) { + filteredNodes = append(filteredNodes, n) + break + } + } + } + + logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes)) + + } else { + // === 未启用分组功能:通过订阅的 node_tags 查询节点 === + logger.Infof("[PreviewUserNodes] using tag-based node filtering") + + // 5. 获取所有订阅的 subscribeId 列表 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + // 6. 查询这些订阅的 node_tags + type SubscribeNodeTags struct { + Id int64 + NodeTags string + } + var subscribeNodeTagsList []SubscribeNodeTags + err = l.svcCtx.DB.Table("subscribe"). + Where("id IN ?", subscribeIds). + Select("id, node_tags"). + Find(&subscribeNodeTagsList).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get subscribe node tags: %v", err) + return nil, err + } + + // 7. 合并所有标签 + var allTags []string + for _, snt := range subscribeNodeTagsList { + if snt.NodeTags != "" { + tags := strings.Split(snt.NodeTags, ",") + allTags = append(allTags, tags...) + } + } + // 去重 + allTags = tool.RemoveDuplicateElements(allTags...) + // 去除空字符串 + allTags = tool.RemoveStringElement(allTags, "") + + logger.Infof("[PreviewUserNodes] merged tags from subscribes: %v", allTags) + + if len(allTags) == 0 { + logger.Infof("[PreviewUserNodes] no tags found in subscribes") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + // 8. 查询所有启用的节点 + var dbNodes []node.Node + err = l.svcCtx.DB.Table("nodes"). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err + } + + // 9. 过滤出包含至少一个匹配标签的节点 + for _, n := range dbNodes { + if n.Tags == "" { + continue + } + nodeTags := strings.Split(n.Tags, ",") + // 检查是否有交集 + for _, tag := range nodeTags { + if tag != "" && tool.Contains(allTags, tag) { + filteredNodes = append(filteredNodes, n) + break + } + } + } + + logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes)) + } + + // 10. 转换为 types.Node 并按节点组分组 + type NodeWithGroup struct { + Node node.Node + NodeGroupIds []int64 + } + + nodesWithGroup := make([]NodeWithGroup, 0, len(filteredNodes)) + for _, n := range filteredNodes { + nodesWithGroup = append(nodesWithGroup, NodeWithGroup{ + Node: n, + NodeGroupIds: []int64(n.NodeGroupIds), + }) + } + + // 11. 按节点组分组节点 + type NodeGroupMap struct { + Id int64 + Nodes []types.Node + } + + // 创建节点组映射:group_id -> nodes + groupMap := make(map[int64]*NodeGroupMap) + + // 获取所有涉及的节点组ID + allGroupIds := make([]int64, 0) + for _, ng := range nodesWithGroup { + if len(ng.NodeGroupIds) > 0 { + // 如果节点属于节点组,按第一个节点组分组(或者可以按所有节点组) + // 这里使用节点的第一个节点组 + firstGroupId := ng.NodeGroupIds[0] + if _, exists := groupMap[firstGroupId]; !exists { + groupMap[firstGroupId] = &NodeGroupMap{ + Id: firstGroupId, + Nodes: []types.Node{}, + } + allGroupIds = append(allGroupIds, firstGroupId) + } + + // 转换节点 + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[firstGroupId].Nodes = append(groupMap[firstGroupId].Nodes, node) + } else { + // 没有节点组的节点,使用 group_id = 0 作为"无节点组"分组 + if _, exists := groupMap[0]; !exists { + groupMap[0] = &NodeGroupMap{ + Id: 0, + Nodes: []types.Node{}, + } + } + + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[0].Nodes = append(groupMap[0].Nodes, node) + } + } + + // 12. 查询节点组信息并构建响应 + nodeGroupInfoMap := make(map[int64]string) + validGroupIds := make([]int64, 0) // 存储在数据库中实际存在的节点组ID + + if len(allGroupIds) > 0 { + type NodeGroupInfo struct { + Id int64 + Name string + } + var nodeGroupInfos []NodeGroupInfo + err = l.svcCtx.DB.Table("node_group"). + Select("id, name"). + Where("id IN ?", allGroupIds). + Find(&nodeGroupInfos).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get node group infos: %v", err) + return nil, err + } + + logger.Infof("[PreviewUserNodes] found %v node group infos from %v requested", len(nodeGroupInfos), len(allGroupIds)) + + // 创建节点组信息映射和有效节点组ID列表 + for _, ngInfo := range nodeGroupInfos { + nodeGroupInfoMap[ngInfo.Id] = ngInfo.Name + validGroupIds = append(validGroupIds, ngInfo.Id) + logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name) + } + + // 记录无效的节点组ID(节点有这个ID但数据库中不存在) + for _, requestedId := range allGroupIds { + found := false + for _, validId := range validGroupIds { + if requestedId == validId { + found = true + break + } + } + if !found { + logger.Infof("[PreviewUserNodes] node_group_id %d not found in database, treating as public nodes", requestedId) + } + } + } + + // 13. 构建响应:根据有效节点组ID重新分组节点 + nodeGroupItems := make([]types.NodeGroupItem, 0) + publicNodes := make([]types.Node, 0) // 公共节点(包括无效节点组和无节点组的节点) + + // 遍历所有分组,重新分类节点 + for groupId, gm := range groupMap { + if groupId == 0 { + // 本来就是无节点组的节点 + publicNodes = append(publicNodes, gm.Nodes...) + continue + } + + // 检查这个节点组ID是否有效(在数据库中存在) + isValid := false + for _, validId := range validGroupIds { + if groupId == validId { + isValid = true + break + } + } + + if isValid { + // 节点组有效,添加到对应的分组 + groupName := nodeGroupInfoMap[groupId] + if groupName == "" { + groupName = fmt.Sprintf("Group %d", groupId) + } + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: groupId, + Name: groupName, + Nodes: gm.Nodes, + }) + logger.Infof("[PreviewUserNodes] adding node group: id=%d, name=%s, nodes=%d", groupId, groupName, len(gm.Nodes)) + } else { + // 节点组无效,节点归入公共节点组 + logger.Infof("[PreviewUserNodes] node_group_id %d invalid, moving %d nodes to public group", groupId, len(gm.Nodes)) + publicNodes = append(publicNodes, gm.Nodes...) + } + } + + // 最后添加公共节点组(如果有) + if len(publicNodes) > 0 { + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: 0, + Name: "", + Nodes: publicNodes, + }) + logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes)) + } + + // 14. 返回结果 + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: nodeGroupItems, + } + + logger.Infof("[PreviewUserNodes] returning %v node groups for user %v", len(resp.NodeGroups), req.UserId) + return resp, nil +} + +// removeDuplicateInt64 去重 []int64 +func removeDuplicateInt64(slice []int64) []int64 { + keys := make(map[int64]bool) + var list []int64 + for _, entry := range slice { + if !keys[entry] { + keys[entry] = true + list = append(list, entry) + } + } + return list +} diff --git a/internal/logic/admin/group/recalculateGroupLogic.go b/internal/logic/admin/group/recalculateGroupLogic.go new file mode 100644 index 0000000..d43557c --- /dev/null +++ b/internal/logic/admin/group/recalculateGroupLogic.go @@ -0,0 +1,814 @@ +package group + +import ( + "context" + "encoding/json" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type RecalculateGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Recalculate group +func NewRecalculateGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RecalculateGroupLogic { + return &RecalculateGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *RecalculateGroupLogic) RecalculateGroup(req *types.RecalculateGroupRequest) error { + // 验证 mode 参数 + if req.Mode != "average" && req.Mode != "subscribe" && req.Mode != "traffic" { + return errors.New("invalid mode, must be one of: average, subscribe, traffic") + } + + // 创建 GroupHistory 记录(state=pending) + triggerType := req.TriggerType + if triggerType == "" { + triggerType = "manual" // 默认为手动触发 + } + + history := &group.GroupHistory{ + GroupMode: req.Mode, + TriggerType: triggerType, + TotalUsers: 0, + SuccessCount: 0, + FailedCount: 0, + } + now := time.Now() + history.StartTime = &now + + // 使用 GORM Transaction 执行分组重算 + err := l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 创建历史记录 + if err := tx.Create(history).Error; err != nil { + l.Errorw("failed to create group history", logger.Field("error", err.Error())) + return err + } + + // 更新状态为 running + if err := tx.Model(history).Update("state", "running").Error; err != nil { + l.Errorw("failed to update history state to running", logger.Field("error", err.Error())) + return err + } + + // 根据 mode 执行不同的分组算法 + var affectedCount int + var err error + + switch req.Mode { + case "average": + affectedCount, err = l.executeAverageGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute average grouping", logger.Field("error", err.Error())) + return err + } + case "subscribe": + affectedCount, err = l.executeSubscribeGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute subscribe grouping", logger.Field("error", err.Error())) + return err + } + case "traffic": + affectedCount, err = l.executeTrafficGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute traffic grouping", logger.Field("error", err.Error())) + return err + } + } + + // 更新 GroupHistory 记录(state=completed, 统计成功/失败数) + endTime := time.Now() + updates := map[string]interface{}{ + "state": "completed", + "total_users": affectedCount, + "success_count": affectedCount, // 暂时假设所有都成功 + "failed_count": 0, + "end_time": endTime, + } + + if err := tx.Model(history).Updates(updates).Error; err != nil { + l.Errorw("failed to update history state to completed", logger.Field("error", err.Error())) + return err + } + + l.Infof("group recalculation completed: mode=%s, affected_users=%d", req.Mode, affectedCount) + return nil + }) + + if err != nil { + // 如果失败,更新历史记录状态为 failed + updateErr := l.svcCtx.DB.Model(history).Updates(map[string]interface{}{ + "state": "failed", + "error_message": err.Error(), + "end_time": time.Now(), + }).Error + if updateErr != nil { + l.Errorw("failed to update history state to failed", logger.Field("error", updateErr.Error())) + } + return err + } + + return nil +} + +// getUserEmail 查询用户的邮箱 +func (l *RecalculateGroupLogic) getUserEmail(tx *gorm.DB, userId int64) string { + type UserAuthMethod struct { + AuthIdentifier string `json:"auth_identifier"` + } + + var authMethod UserAuthMethod + if err := tx.Table("user_auth_methods"). + Select("auth_identifier"). + Where("user_id = ? AND (auth_type = ? OR auth_type = ?)", userId, "email", "6"). + First(&authMethod).Error; err != nil { + return "" + } + + return authMethod.AuthIdentifier +} + +// executeAverageGrouping 实现平均分组算法(随机分配节点组到用户订阅) +// 新逻辑:获取所有有效用户订阅,从订阅的节点组ID中随机选择一个,设置到用户订阅的 node_group_id 字段 +func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 1. 查询所有有效且未锁定的用户订阅(status IN (0, 1)) + type UserSubscribeInfo struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + SubscribeId int64 `json:"subscribe_id"` + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Table("user_subscribe"). + Select("id, user_id, subscribe_id"). + Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询未锁定且有效的用户订阅 + Scan(&userSubscribes).Error; err != nil { + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infof("average grouping: no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infof("average grouping: found %d valid and unlocked user subscribes", len(userSubscribes)) + + // 1.5 查询所有参与计算的节点组ID + var calculationNodeGroups []group.NodeGroup + if err := tx.Table("node_group"). + Select("id"). + Where("for_calculation = ?", true). + Scan(&calculationNodeGroups).Error; err != nil { + l.Errorw("failed to query calculation node groups", logger.Field("error", err.Error())) + return 0, err + } + + // 创建参与计算的节点组ID集合(用于快速查找) + calculationNodeGroupIds := make(map[int64]bool) + for _, ng := range calculationNodeGroups { + calculationNodeGroupIds[ng.Id] = true + } + + l.Infof("average grouping: found %d node groups with for_calculation=true", len(calculationNodeGroupIds)) + + // 2. 批量查询订阅的节点组ID信息 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + type SubscribeInfo struct { + Id int64 `json:"id"` + NodeGroupIds string `json:"node_group_ids"` // JSON string + } + var subscribeInfos []SubscribeInfo + if err := tx.Table("subscribe"). + Select("id, node_group_ids"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error; err != nil { + l.Errorw("failed to query subscribe infos", logger.Field("error", err.Error())) + return 0, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 用于存储统计信息(按节点组ID统计用户数) + groupUsersMap := make(map[int64][]struct { + Id int64 `json:"id"` + Email string `json:"email"` + }) + nodeGroupUserCount := make(map[int64]int) // node_group_id -> user_count + nodeGroupNodeCount := make(map[int64]int) // node_group_id -> node_count + + // 3. 遍历所有用户订阅,按序平均分配节点组 + affectedCount := 0 + failedCount := 0 + + // 为每个订阅维护一个分配索引,用于按序循环分配 + subscribeAllocationIndex := make(map[int64]int) // subscribe_id -> current_index + + for _, us := range userSubscribes { + subInfo, ok := subInfoMap[us.SubscribeId] + if !ok { + l.Infow("subscribe not found", + logger.Field("user_subscribe_id", us.Id), + logger.Field("subscribe_id", us.SubscribeId)) + failedCount++ + continue + } + + // 解析订阅的节点组ID列表,并过滤出参与计算的节点组 + var nodeGroupIds []int64 + if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "[]" { + var allNodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &allNodeGroupIds); err != nil { + l.Errorw("failed to parse node_group_ids", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("node_group_ids", subInfo.NodeGroupIds), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只保留参与计算的节点组 + for _, ngId := range allNodeGroupIds { + if calculationNodeGroupIds[ngId] { + nodeGroupIds = append(nodeGroupIds, ngId) + } + } + + if len(nodeGroupIds) == 0 && len(allNodeGroupIds) > 0 { + l.Debugw("all node_group_ids are not for calculation, setting to 0", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("total_node_groups", len(allNodeGroupIds))) + } + } + + // 如果没有节点组ID,跳过 + if len(nodeGroupIds) == 0 { + l.Debugf("no valid node_group_ids for subscribe_id=%d, setting to 0", subInfo.Id) + if err := tx.Table("user_subscribe"). + Where("id = ?", us.Id). + Update("node_group_id", 0).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + } + + // 按序选择节点组ID(循环轮询分配) + selectedNodeGroupId := int64(0) + if len(nodeGroupIds) > 0 { + // 获取当前订阅的分配索引 + currentIndex := subscribeAllocationIndex[us.SubscribeId] + // 选择当前索引对应的节点组 + selectedNodeGroupId = nodeGroupIds[currentIndex] + // 更新索引,循环使用(轮询) + subscribeAllocationIndex[us.SubscribeId] = (currentIndex + 1) % len(nodeGroupIds) + + l.Debugf("assigning user_subscribe_id=%d (subscribe_id=%d) to node_group_id=%d (index=%d, total_options=%d, mode=sequential)", + us.Id, us.SubscribeId, selectedNodeGroupId, currentIndex, len(nodeGroupIds)) + } + + // 更新 user_subscribe 的 node_group_id 字段(单个ID) + if err := tx.Table("user_subscribe"). + Where("id = ?", us.Id). + Update("node_group_id", selectedNodeGroupId).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只统计有节点组的用户 + if selectedNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + groupUsersMap[selectedNodeGroupId] = append(groupUsersMap[selectedNodeGroupId], struct { + Id int64 `json:"id"` + Email string `json:"email"` + }{ + Id: us.UserId, + Email: email, + }) + nodeGroupUserCount[selectedNodeGroupId]++ + } + + affectedCount++ + } + + l.Infof("average grouping completed: affected=%d, failed=%d", affectedCount, failedCount) + + // 4. 创建分组历史详情记录(按节点组ID统计) + for nodeGroupId, users := range groupUsersMap { + userCount := len(users) + if userCount == 0 { + continue + } + + // 统计该节点组的节点数 + var nodeCount int64 = 0 + if nodeGroupId > 0 { + if err := tx.Table("nodes"). + Where("JSON_CONTAINS(node_group_ids, ?)", nodeGroupId). + Count(&nodeCount).Error; err != nil { + l.Errorw("failed to count nodes", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + nodeGroupNodeCount[nodeGroupId] = int(nodeCount) + + // 序列化用户信息为 JSON + userDataJSON := "[]" + if jsonData, err := json.Marshal(users); err == nil { + userDataJSON = string(jsonData) + } else { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + // 创建历史详情(使用 node_group_id 作为分组标识) + detail := &group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: int(nodeCount), + UserData: userDataJSON, + } + + if err := tx.Create(detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + l.Infof("Average Group (node_group_id=%d): users=%d, nodes=%d", + nodeGroupId, userCount, nodeCount) + } + + return affectedCount, nil +} + +// executeSubscribeGrouping 实现基于订阅套餐的分组算法 +// 逻辑:查询有效订阅 → 获取订阅的 node_group_ids → 取第一个 node_group_id(如果有) → 更新 user_subscribe.node_group_id +// 订阅过期的用户 → 设置 node_group_id 为 0 +func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 1. 查询所有有效且未锁定的用户订阅(status IN (0, 1), group_locked = 0) + type UserSubscribeInfo struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + SubscribeId int64 `json:"subscribe_id"` + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Table("user_subscribe"). + Select("id, user_id, subscribe_id"). + Where("group_locked = ? AND status IN (0, 1)", 0). + Scan(&userSubscribes).Error; err != nil { + l.Errorw("failed to query user subscribes", logger.Field("error", err.Error())) + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infof("subscribe grouping: no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infof("subscribe grouping: found %d valid and unlocked user subscribes", len(userSubscribes)) + + // 1.5 查询所有参与计算的节点组ID + var calculationNodeGroups []group.NodeGroup + if err := tx.Table("node_group"). + Select("id"). + Where("for_calculation = ?", true). + Scan(&calculationNodeGroups).Error; err != nil { + l.Errorw("failed to query calculation node groups", logger.Field("error", err.Error())) + return 0, err + } + + // 创建参与计算的节点组ID集合(用于快速查找) + calculationNodeGroupIds := make(map[int64]bool) + for _, ng := range calculationNodeGroups { + calculationNodeGroupIds[ng.Id] = true + } + + l.Infof("subscribe grouping: found %d node groups with for_calculation=true", len(calculationNodeGroupIds)) + + // 2. 批量查询订阅的节点组ID信息 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + type SubscribeInfo struct { + Id int64 `json:"id"` + NodeGroupIds string `json:"node_group_ids"` // JSON string + } + var subscribeInfos []SubscribeInfo + if err := tx.Table("subscribe"). + Select("id, node_group_ids"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error; err != nil { + l.Errorw("failed to query subscribe infos", logger.Field("error", err.Error())) + return 0, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 用于存储统计信息(按节点组ID统计用户数) + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + groupUsersMap := make(map[int64][]UserInfo) + nodeGroupUserCount := make(map[int64]int) // node_group_id -> user_count + nodeGroupNodeCount := make(map[int64]int) // node_group_id -> node_count + + // 3. 遍历所有用户订阅,取第一个节点组ID + affectedCount := 0 + failedCount := 0 + + for _, us := range userSubscribes { + subInfo, ok := subInfoMap[us.SubscribeId] + if !ok { + l.Infow("subscribe not found", + logger.Field("user_subscribe_id", us.Id), + logger.Field("subscribe_id", us.SubscribeId)) + failedCount++ + continue + } + + // 解析订阅的节点组ID列表,并过滤出参与计算的节点组 + var nodeGroupIds []int64 + if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "[]" { + var allNodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &allNodeGroupIds); err != nil { + l.Errorw("failed to parse node_group_ids", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("node_group_ids", subInfo.NodeGroupIds), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只保留参与计算的节点组 + for _, ngId := range allNodeGroupIds { + if calculationNodeGroupIds[ngId] { + nodeGroupIds = append(nodeGroupIds, ngId) + } + } + + if len(nodeGroupIds) == 0 && len(allNodeGroupIds) > 0 { + l.Debugw("all node_group_ids are not for calculation, setting to 0", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("total_node_groups", len(allNodeGroupIds))) + } + } + + // 取第一个参与计算的节点组ID(如果有),否则设置为 0 + selectedNodeGroupId := int64(0) + if len(nodeGroupIds) > 0 { + selectedNodeGroupId = nodeGroupIds[0] + } + + l.Debugf("assigning user_subscribe_id=%d (subscribe_id=%d) to node_group_id=%d (total_options=%d, selected_first)", + us.Id, us.SubscribeId, selectedNodeGroupId, len(nodeGroupIds)) + + // 更新 user_subscribe 的 node_group_id 字段 + if err := tx.Table("user_subscribe"). + Where("id = ?", us.Id). + Update("node_group_id", selectedNodeGroupId).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只统计有节点组的用户 + if selectedNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + groupUsersMap[selectedNodeGroupId] = append(groupUsersMap[selectedNodeGroupId], UserInfo{ + Id: us.UserId, + Email: email, + }) + nodeGroupUserCount[selectedNodeGroupId]++ + } + + affectedCount++ + } + + l.Infof("subscribe grouping completed: affected=%d, failed=%d", affectedCount, failedCount) + + // 4. 处理订阅过期/失效的用户,设置 node_group_id 为 0 + // 查询所有没有有效订阅且未锁定的用户订阅记录 + var expiredUserSubscribes []struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + } + + if err := tx.Raw(` + SELECT us.id, us.user_id + FROM user_subscribe as us + WHERE us.group_locked = 0 + AND us.status NOT IN (0, 1) + `).Scan(&expiredUserSubscribes).Error; err != nil { + l.Errorw("failed to query expired user subscribes", logger.Field("error", err.Error())) + // 继续处理,不因为过期用户查询失败而影响 + } else { + l.Infof("found %d expired user subscribes for subscribe-based grouping, will set node_group_id to 0", len(expiredUserSubscribes)) + + expiredAffectedCount := 0 + for _, eu := range expiredUserSubscribes { + // 更新 user_subscribe 表的 node_group_id 字段到 0 + if err := tx.Table("user_subscribe"). + Where("id = ?", eu.Id). + Update("node_group_id", 0).Error; err != nil { + l.Errorw("failed to update expired user subscribe node_group_id", + logger.Field("user_subscribe_id", eu.Id), + logger.Field("error", err.Error())) + continue + } + + expiredAffectedCount++ + } + + l.Infof("expired user subscribes grouping completed: affected=%d", expiredAffectedCount) + } + + // 5. 创建分组历史详情记录(按节点组ID统计) + for nodeGroupId, users := range groupUsersMap { + userCount := len(users) + if userCount == 0 { + continue + } + + // 统计该节点组的节点数 + var nodeCount int64 = 0 + if nodeGroupId > 0 { + if err := tx.Table("nodes"). + Where("JSON_CONTAINS(node_group_ids, ?)", nodeGroupId). + Count(&nodeCount).Error; err != nil { + l.Errorw("failed to count nodes", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + nodeGroupNodeCount[nodeGroupId] = int(nodeCount) + + // 序列化用户信息为 JSON + userDataJSON := "[]" + if jsonData, err := json.Marshal(users); err == nil { + userDataJSON = string(jsonData) + } else { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + // 创建历史详情 + detail := &group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: int(nodeCount), + UserData: userDataJSON, + } + + if err := tx.Create(detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + l.Infof("Subscribe Group (node_group_id=%d): users=%d, nodes=%d", + nodeGroupId, userCount, nodeCount) + } + + return affectedCount, nil +} + +// executeTrafficGrouping 实现基于流量的分组算法 +// 逻辑:根据配置的流量范围,将用户分配到对应的用户组 +func (l *RecalculateGroupLogic) executeTrafficGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 用于存储每个节点组的用户信息(id 和 email) + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + groupUsersMap := make(map[int64][]UserInfo) // node_group_id -> []UserInfo + + // 1. 获取所有设置了流量区间的节点组 + var nodeGroups []group.NodeGroup + if err := tx.Where("for_calculation = ?", true). + Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)"). + Find(&nodeGroups).Error; err != nil { + l.Errorw("failed to query node groups", logger.Field("error", err.Error())) + return 0, err + } + + if len(nodeGroups) == 0 { + l.Infow("no node groups with traffic ranges configured") + return 0, nil + } + + l.Infow("executeTrafficGrouping loaded node groups", + logger.Field("node_groups_count", len(nodeGroups))) + + // 2. 查询所有有效且未锁定的用户订阅及其已用流量 + type UserSubscribeInfo struct { + Id int64 + UserId int64 + Upload int64 + Download int64 + UsedTraffic int64 // 已用流量 = upload + download (bytes) + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Table("user_subscribe"). + Select("id, user_id, upload, download, (upload + download) as used_traffic"). + Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询有效且未锁定的用户订阅 + Scan(&userSubscribes).Error; err != nil { + l.Errorw("failed to query user subscribes", logger.Field("error", err.Error())) + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infow("no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infow("found user subscribes for traffic-based grouping", logger.Field("count", len(userSubscribes))) + + // 3. 根据流量范围分配节点组ID到用户订阅 + affectedCount := 0 + groupUserCount := make(map[int64]int) // node_group_id -> user_count + + for _, us := range userSubscribes { + // 将字节转换为 GB + usedTrafficGB := float64(us.UsedTraffic) / (1024 * 1024 * 1024) + + // 查找匹配的流量范围(使用左闭右开区间 [Min, Max)) + var targetNodeGroupId int64 = 0 + for _, ng := range nodeGroups { + if ng.MinTrafficGB == nil || ng.MaxTrafficGB == nil { + continue + } + minTraffic := float64(*ng.MinTrafficGB) + maxTraffic := float64(*ng.MaxTrafficGB) + + // 检查是否在区间内 [min, max) + if usedTrafficGB >= minTraffic && usedTrafficGB < maxTraffic { + targetNodeGroupId = ng.Id + break + } + } + + // 如果没有匹配到任何范围,targetNodeGroupId 保持为 0(不分配节点组) + + // 更新 user_subscribe 的 node_group_id 字段 + if err := tx.Table("user_subscribe"). + Where("id = ?", us.Id). + Update("node_group_id", targetNodeGroupId).Error; err != nil { + l.Errorw("failed to update user subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("target_node_group_id", targetNodeGroupId), + logger.Field("error", err.Error())) + continue + } + + // 只有分配了节点组的用户才记录到历史 + if targetNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + userInfo := UserInfo{ + Id: us.UserId, + Email: email, + } + groupUsersMap[targetNodeGroupId] = append(groupUsersMap[targetNodeGroupId], userInfo) + groupUserCount[targetNodeGroupId]++ + + l.Debugf("assigned user subscribe %d (traffic: %.2fGB) to node group %d", + us.Id, usedTrafficGB, targetNodeGroupId) + } else { + l.Debugf("user subscribe %d (traffic: %.2fGB) not assigned to any node group", + us.Id, usedTrafficGB) + } + + affectedCount++ + } + + l.Infof("traffic-based grouping completed: affected_subscribes=%d", affectedCount) + + // 4. 创建分组历史详情记录(只统计有用户的节点组) + nodeGroupCount := make(map[int64]int) // node_group_id -> node_count + for _, ng := range nodeGroups { + nodeGroupCount[ng.Id] = 1 // 每个节点组计为1 + } + + for nodeGroupId, userCount := range groupUserCount { + userDataJSON, err := json.Marshal(groupUsersMap[nodeGroupId]) + if err != nil { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + continue + } + + detail := group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: nodeGroupCount[nodeGroupId], + UserData: string(userDataJSON), + } + if err := tx.Create(&detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("history_id", historyId), + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + + return affectedCount, nil +} + +// containsIgnoreCase checks if a string contains another substring (case-insensitive) +func containsIgnoreCase(s, substr string) bool { + if len(substr) == 0 { + return true + } + if len(s) < len(substr) { + return false + } + + // Simple case-insensitive contains check + sLower := toLower(s) + substrLower := toLower(substr) + + return contains(sLower, substrLower) +} + +// toLower converts a string to lowercase +func toLower(s string) string { + result := make([]rune, len(s)) + for i, r := range s { + if r >= 'A' && r <= 'Z' { + result[i] = r + ('a' - 'A') + } else { + result[i] = r + } + } + return string(result) +} + +// contains checks if a string contains another substring (case-sensitive) +func contains(s, substr string) bool { + return len(s) >= len(substr) && indexOf(s, substr) >= 0 +} + +// indexOf returns the index of the first occurrence of substr in s, or -1 if not found +func indexOf(s, substr string) int { + n := len(substr) + if n == 0 { + return 0 + } + if n > len(s) { + return -1 + } + + // Simple string search + for i := 0; i <= len(s)-n; i++ { + if s[i:i+n] == substr { + return i + } + } + return -1 +} diff --git a/internal/logic/admin/group/resetGroupsLogic.go b/internal/logic/admin/group/resetGroupsLogic.go new file mode 100644 index 0000000..eaaa098 --- /dev/null +++ b/internal/logic/admin/group/resetGroupsLogic.go @@ -0,0 +1,82 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/logger" +) + +type ResetGroupsLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// NewResetGroupsLogic Reset all groups (delete all node groups and reset related data) +func NewResetGroupsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetGroupsLogic { + return &ResetGroupsLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *ResetGroupsLogic) ResetGroups() error { + // 1. Delete all node groups + err := l.svcCtx.DB.Where("1 = 1").Delete(&group.NodeGroup{}).Error + if err != nil { + l.Errorw("Failed to delete all node groups", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully deleted all node groups") + + // 2. Clear node_group_ids for all subscribes (products) + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}).Where("1 = 1").Update("node_group_ids", "[]").Error + if err != nil { + l.Errorw("Failed to clear subscribes' node_group_ids", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully cleared all subscribes' node_group_ids") + + // 3. Clear node_group_ids for all nodes + err = l.svcCtx.DB.Model(&node.Node{}).Where("1 = 1").Update("node_group_ids", "[]").Error + if err != nil { + l.Errorw("Failed to clear nodes' node_group_ids", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully cleared all nodes' node_group_ids") + + // 4. Clear group history + err = l.svcCtx.DB.Where("1 = 1").Delete(&group.GroupHistory{}).Error + if err != nil { + l.Errorw("Failed to clear group history", logger.Field("error", err.Error())) + // Non-critical error, continue anyway + } else { + l.Infow("Successfully cleared group history") + } + + // 7. Clear group history details + err = l.svcCtx.DB.Where("1 = 1").Delete(&group.GroupHistoryDetail{}).Error + if err != nil { + l.Errorw("Failed to clear group history details", logger.Field("error", err.Error())) + // Non-critical error, continue anyway + } else { + l.Infow("Successfully cleared group history details") + } + + // 5. Delete all group config settings + err = l.svcCtx.DB.Where("`category` = ?", "group").Delete(&system.System{}).Error + if err != nil { + l.Errorw("Failed to delete group config", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully deleted all group config settings") + + l.Infow("Group reset completed successfully") + return nil +} diff --git a/internal/logic/admin/group/updateGroupConfigLogic.go b/internal/logic/admin/group/updateGroupConfigLogic.go new file mode 100644 index 0000000..0980373 --- /dev/null +++ b/internal/logic/admin/group/updateGroupConfigLogic.go @@ -0,0 +1,188 @@ +package group + +import ( + "context" + "encoding/json" + + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type UpdateGroupConfigLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Update group config +func NewUpdateGroupConfigLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateGroupConfigLogic { + return &UpdateGroupConfigLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *UpdateGroupConfigLogic) UpdateGroupConfig(req *types.UpdateGroupConfigRequest) error { + // 验证 mode 是否为合法值 + if req.Mode != "" { + if req.Mode != "average" && req.Mode != "subscribe" && req.Mode != "traffic" { + return errors.New("invalid mode, must be one of: average, subscribe, traffic") + } + } + + // 使用 GORM Transaction 更新配置 + err := l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 更新 enabled 配置(使用 Upsert 逻辑) + enabledValue := "false" + if req.Enabled { + enabledValue = "true" + } + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "enabled"). + Update("value", enabledValue) + if result.Error != nil { + l.Errorw("failed to update group enabled config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "enabled", + Value: enabledValue, + Desc: "Group Feature Enabled", + }).Error; err != nil { + l.Errorw("failed to create group enabled config", logger.Field("error", err.Error())) + return err + } + } + + // 更新 mode 配置(使用 Upsert 逻辑) + if req.Mode != "" { + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "mode"). + Update("value", req.Mode) + if result.Error != nil { + l.Errorw("failed to update group mode config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "mode", + Value: req.Mode, + Desc: "Group Mode", + }).Error; err != nil { + l.Errorw("failed to create group mode config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 JSON 配置 + if req.Config != nil { + // 更新 average_config + if averageConfig, ok := req.Config["average_config"]; ok { + jsonBytes, err := json.Marshal(averageConfig) + if err != nil { + l.Errorw("failed to marshal average_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal average_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "average_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group average_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "average_config", + Value: string(jsonBytes), + Desc: "Average Group Config", + }).Error; err != nil { + l.Errorw("failed to create group average_config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 subscribe_config + if subscribeConfig, ok := req.Config["subscribe_config"]; ok { + jsonBytes, err := json.Marshal(subscribeConfig) + if err != nil { + l.Errorw("failed to marshal subscribe_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal subscribe_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "subscribe_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group subscribe_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "subscribe_config", + Value: string(jsonBytes), + Desc: "Subscribe Group Config", + }).Error; err != nil { + l.Errorw("failed to create group subscribe_config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 traffic_config + if trafficConfig, ok := req.Config["traffic_config"]; ok { + jsonBytes, err := json.Marshal(trafficConfig) + if err != nil { + l.Errorw("failed to marshal traffic_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal traffic_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "traffic_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group traffic_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "traffic_config", + Value: string(jsonBytes), + Desc: "Traffic Group Config", + }).Error; err != nil { + l.Errorw("failed to create group traffic_config", logger.Field("error", err.Error())) + return err + } + } + } + } + + return nil + }) + + if err != nil { + l.Errorw("failed to update group config", logger.Field("error", err.Error())) + return err + } + + l.Infof("group config updated successfully: enabled=%v, mode=%s", req.Enabled, req.Mode) + return nil +} diff --git a/internal/logic/admin/group/updateNodeGroupLogic.go b/internal/logic/admin/group/updateNodeGroupLogic.go new file mode 100644 index 0000000..eb299d5 --- /dev/null +++ b/internal/logic/admin/group/updateNodeGroupLogic.go @@ -0,0 +1,140 @@ +package group + +import ( + "context" + "errors" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type UpdateNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewUpdateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateNodeGroupLogic { + return &UpdateNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest) error { + // 检查节点组是否存在 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("node group not found") + } + logger.Errorf("failed to find node group: %v", err) + return err + } + + // 构建更新数据 + updates := map[string]interface{}{ + "updated_at": time.Now(), + } + if req.Name != "" { + updates["name"] = req.Name + } + if req.Description != "" { + updates["description"] = req.Description + } + if req.Sort != 0 { + updates["sort"] = req.Sort + } + if req.ForCalculation != nil { + updates["for_calculation"] = *req.ForCalculation + } + + // 获取新的流量区间值 + newMinTraffic := nodeGroup.MinTrafficGB + newMaxTraffic := nodeGroup.MaxTrafficGB + if req.MinTrafficGB != nil { + newMinTraffic = req.MinTrafficGB + updates["min_traffic_gb"] = *req.MinTrafficGB + } + if req.MaxTrafficGB != nil { + newMaxTraffic = req.MaxTrafficGB + updates["max_traffic_gb"] = *req.MaxTrafficGB + } + + // 校验流量区间 + if err := l.validateTrafficRange(int(req.Id), newMinTraffic, newMaxTraffic); err != nil { + return err + } + + // 执行更新 + if err := l.svcCtx.DB.Model(&nodeGroup).Updates(updates).Error; err != nil { + logger.Errorf("failed to update node group: %v", err) + return err + } + + logger.Infof("updated node group: id=%d", req.Id) + return nil +} + +// validateTrafficRange 校验流量区间:不能重叠、不能留空档、最小值不能大于最大值 +func (l *UpdateNodeGroupLogic) validateTrafficRange(currentNodeGroupId int, newMin, newMax *int64) error { + // 处理指针值 + minVal := int64(0) + maxVal := int64(0) + if newMin != nil { + minVal = *newMin + } + if newMax != nil { + maxVal = *newMax + } + + // 检查最小值是否大于最大值 + if minVal > maxVal { + return errors.New("minimum traffic cannot exceed maximum traffic") + } + + // 如果两个值都为0,表示不参与流量分组,不需要校验 + if minVal == 0 && maxVal == 0 { + return nil + } + + // 查询所有其他设置了流量区间的节点组 + var otherGroups []group.NodeGroup + if err := l.svcCtx.DB. + Where("id != ?", currentNodeGroupId). + Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)"). + Find(&otherGroups).Error; err != nil { + logger.Errorf("failed to query other node groups: %v", err) + return err + } + + // 检查是否有重叠 + for _, other := range otherGroups { + otherMin := int64(0) + otherMax := int64(0) + if other.MinTrafficGB != nil { + otherMin = *other.MinTrafficGB + } + if other.MaxTrafficGB != nil { + otherMax = *other.MaxTrafficGB + } + + // 如果对方也没设置区间,跳过 + if otherMin == 0 && otherMax == 0 { + continue + } + + // 检查是否有重叠: 如果两个区间相交,就是重叠 + // 不重叠的条件是: newMax <= otherMin OR newMin >= otherMax + if !(maxVal <= otherMin || minVal >= otherMax) { + return errors.New("traffic range overlaps with another node group") + } + } + + return nil +} diff --git a/internal/logic/admin/server/createNodeLogic.go b/internal/logic/admin/server/createNodeLogic.go index 78ce987..38044de 100644 --- a/internal/logic/admin/server/createNodeLogic.go +++ b/internal/logic/admin/server/createNodeLogic.go @@ -29,13 +29,14 @@ func NewCreateNodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Create func (l *CreateNodeLogic) CreateNode(req *types.CreateNodeRequest) error { data := node.Node{ - Name: req.Name, - Tags: tool.StringSliceToString(req.Tags), - Enabled: req.Enabled, - Port: req.Port, - Address: req.Address, - ServerId: req.ServerId, - Protocol: req.Protocol, + Name: req.Name, + Tags: tool.StringSliceToString(req.Tags), + Enabled: req.Enabled, + Port: req.Port, + Address: req.Address, + ServerId: req.ServerId, + Protocol: req.Protocol, + NodeGroupIds: node.JSONInt64Slice(req.NodeGroupIds), } err := l.svcCtx.NodeModel.InsertNode(l.ctx, &data) if err != nil { diff --git a/internal/logic/admin/server/filterNodeListLogic.go b/internal/logic/admin/server/filterNodeListLogic.go index 2e41cec..47f8574 100644 --- a/internal/logic/admin/server/filterNodeListLogic.go +++ b/internal/logic/admin/server/filterNodeListLogic.go @@ -29,10 +29,17 @@ func NewFilterNodeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Fi } func (l *FilterNodeListLogic) FilterNodeList(req *types.FilterNodeListRequest) (resp *types.FilterNodeListResponse, err error) { + // Convert NodeGroupId to []int64 for model + var nodeGroupIds []int64 + if req.NodeGroupId != nil { + nodeGroupIds = []int64{*req.NodeGroupId} + } + total, data, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ - Page: req.Page, - Size: req.Size, - Search: req.Search, + Page: req.Page, + Size: req.Size, + Search: req.Search, + NodeGroupIds: nodeGroupIds, }) if err != nil { @@ -43,17 +50,18 @@ func (l *FilterNodeListLogic) FilterNodeList(req *types.FilterNodeListRequest) ( list := make([]types.Node, 0) for _, datum := range data { list = append(list, types.Node{ - Id: datum.Id, - Name: datum.Name, - Tags: tool.RemoveDuplicateElements(strings.Split(datum.Tags, ",")...), - Port: datum.Port, - Address: datum.Address, - ServerId: datum.ServerId, - Protocol: datum.Protocol, - Enabled: datum.Enabled, - Sort: datum.Sort, - CreatedAt: datum.CreatedAt.UnixMilli(), - UpdatedAt: datum.UpdatedAt.UnixMilli(), + Id: datum.Id, + Name: datum.Name, + Tags: tool.RemoveDuplicateElements(strings.Split(datum.Tags, ",")...), + Port: datum.Port, + Address: datum.Address, + ServerId: datum.ServerId, + Protocol: datum.Protocol, + Enabled: datum.Enabled, + Sort: datum.Sort, + NodeGroupIds: []int64(datum.NodeGroupIds), + CreatedAt: datum.CreatedAt.UnixMilli(), + UpdatedAt: datum.UpdatedAt.UnixMilli(), }) } diff --git a/internal/logic/admin/server/updateNodeLogic.go b/internal/logic/admin/server/updateNodeLogic.go index 2af8a4d..3b0c291 100644 --- a/internal/logic/admin/server/updateNodeLogic.go +++ b/internal/logic/admin/server/updateNodeLogic.go @@ -40,6 +40,7 @@ func (l *UpdateNodeLogic) UpdateNode(req *types.UpdateNodeRequest) error { data.Address = req.Address data.Protocol = req.Protocol data.Enabled = req.Enabled + data.NodeGroupIds = node.JSONInt64Slice(req.NodeGroupIds) err = l.svcCtx.NodeModel.UpdateNode(l.ctx, data) if err != nil { l.Errorw("[UpdateNode] Update Database Error: ", logger.Field("error", err.Error())) diff --git a/internal/logic/admin/subscribe/createSubscribeLogic.go b/internal/logic/admin/subscribe/createSubscribeLogic.go index 6309e2b..3aeb751 100644 --- a/internal/logic/admin/subscribe/createSubscribeLogic.go +++ b/internal/logic/admin/subscribe/createSubscribeLogic.go @@ -50,6 +50,8 @@ func (l *CreateSubscribeLogic) CreateSubscribe(req *types.CreateSubscribeRequest Quota: req.Quota, Nodes: tool.Int64SliceToString(req.Nodes), NodeTags: tool.StringSliceToString(req.NodeTags), + NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), + NodeGroupId: req.NodeGroupId, Show: req.Show, Sell: req.Sell, Sort: 0, diff --git a/internal/logic/admin/subscribe/getSubscribeListLogic.go b/internal/logic/admin/subscribe/getSubscribeListLogic.go index e8c7866..130d682 100644 --- a/internal/logic/admin/subscribe/getSubscribeListLogic.go +++ b/internal/logic/admin/subscribe/getSubscribeListLogic.go @@ -30,12 +30,20 @@ func NewGetSubscribeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } func (l *GetSubscribeListLogic) GetSubscribeList(req *types.GetSubscribeListRequest) (resp *types.GetSubscribeListResponse, err error) { - total, list, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ + // Build filter params + filterParams := &subscribe.FilterParams{ Page: int(req.Page), Size: int(req.Size), Language: req.Language, Search: req.Search, - }) + } + + // Add NodeGroupId filter if provided + if req.NodeGroupId > 0 { + filterParams.NodeGroupId = &req.NodeGroupId + } + + total, list, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, filterParams) if err != nil { l.Logger.Error("[GetSubscribeListLogic] get subscribe list failed: ", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get subscribe list failed: %v", err.Error()) @@ -56,6 +64,14 @@ func (l *GetSubscribeListLogic) GetSubscribeList(req *types.GetSubscribeListRequ } sub.Nodes = tool.StringToInt64Slice(item.Nodes) sub.NodeTags = strings.Split(item.NodeTags, ",") + // Handle NodeGroupIds - convert from JSONInt64Slice to []int64 + if item.NodeGroupIds != nil { + sub.NodeGroupIds = []int64(item.NodeGroupIds) + } else { + sub.NodeGroupIds = []int64{} + } + // NodeGroupId is already int64, should be copied by DeepCopy + sub.NodeGroupId = item.NodeGroupId resultList = append(resultList, sub) } diff --git a/internal/logic/admin/subscribe/updateSubscribeLogic.go b/internal/logic/admin/subscribe/updateSubscribeLogic.go index b79fdfe..a60a6a0 100644 --- a/internal/logic/admin/subscribe/updateSubscribeLogic.go +++ b/internal/logic/admin/subscribe/updateSubscribeLogic.go @@ -58,6 +58,8 @@ func (l *UpdateSubscribeLogic) UpdateSubscribe(req *types.UpdateSubscribeRequest Quota: req.Quota, Nodes: tool.Int64SliceToString(req.Nodes), NodeTags: tool.StringSliceToString(req.NodeTags), + NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), + NodeGroupId: req.NodeGroupId, Show: req.Show, Sell: req.Sell, Sort: req.Sort, diff --git a/internal/logic/admin/user/createUserSubscribeLogic.go b/internal/logic/admin/user/createUserSubscribeLogic.go index 08876f8..e294cb5 100644 --- a/internal/logic/admin/user/createUserSubscribeLogic.go +++ b/internal/logic/admin/user/createUserSubscribeLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/google/uuid" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -64,6 +65,7 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs Upload: 0, Token: uuidx.SubscribeToken(fmt.Sprintf("adminCreate:%d", time.Now().UnixMilli())), UUID: uuid.New().String(), + NodeGroupId: sub.NodeGroupId, Status: 1, } if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, &userSub); err != nil { @@ -71,6 +73,60 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "InsertSubscribe error: %v", err.Error()) } + // Trigger user group recalculation (runs in background) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + l.Debugf("Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + l.Errorw("Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + l.Debugf("Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svcCtx) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + l.Errorw("Failed to recalculate user group", + logger.Field("user_id", userInfo.Id), + logger.Field("error", err.Error()), + ) + return + } + + l.Infow("Successfully recalculated user group after admin created subscription", + logger.Field("user_id", userInfo.Id), + logger.Field("subscribe_id", userSub.Id), + logger.Field("mode", groupMode), + ) + }() + err = l.svcCtx.UserModel.UpdateUserCache(l.ctx, userInfo) if err != nil { l.Errorw("UpdateUserCache error", logger.Field("error", err.Error())) @@ -81,5 +137,6 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs if err != nil { logger.Errorw("ClearSubscribe error", logger.Field("error", err.Error())) } + return nil } diff --git a/internal/logic/admin/user/updateUserBasicInfoLogic.go b/internal/logic/admin/user/updateUserBasicInfoLogic.go index faa7930..100f919 100644 --- a/internal/logic/admin/user/updateUserBasicInfoLogic.go +++ b/internal/logic/admin/user/updateUserBasicInfoLogic.go @@ -120,7 +120,31 @@ func (l *UpdateUserBasicInfoLogic) UpdateUserBasicInfo(req *types.UpdateUserBasi } userInfo.Commission = req.Commission } - tool.DeepCopy(userInfo, req) + + // 只更新指定的字段,不使用 DeepCopy 避免零值覆盖 + + // 处理头像(只在提供时更新) + if req.Avatar != "" { + userInfo.Avatar = req.Avatar + } + + // 处理推荐码(只在提供时更新) + if req.ReferCode != "" { + userInfo.ReferCode = req.ReferCode + } + + // 处理推荐人ID(只在非零时更新) + if req.RefererId != 0 { + userInfo.RefererId = req.RefererId + } + + // 处理启用状态(始终更新) + userInfo.Enable = &req.Enable + + // 处理管理员状态(始终更新) + userInfo.IsAdmin = &req.IsAdmin + + // 更新其他字段(只有在明确提供时才更新) userInfo.OnlyFirstPurchase = &req.OnlyFirstPurchase userInfo.ReferralPercentage = req.ReferralPercentage diff --git a/internal/logic/admin/user/updateUserSubscribeLogic.go b/internal/logic/admin/user/updateUserSubscribeLogic.go index 23c2d2f..123f459 100644 --- a/internal/logic/admin/user/updateUserSubscribeLogic.go +++ b/internal/logic/admin/user/updateUserSubscribeLogic.go @@ -53,6 +53,7 @@ func (l *UpdateUserSubscribeLogic) UpdateUserSubscribe(req *types.UpdateUserSubs Token: userSub.Token, UUID: userSub.UUID, Status: userSub.Status, + NodeGroupId: userSub.NodeGroupId, }) if err != nil { @@ -74,5 +75,6 @@ func (l *UpdateUserSubscribeLogic) UpdateUserSubscribe(req *types.UpdateUserSubs l.Errorf("ClearServerAllCache error: %v", err.Error()) return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "failed to clear server cache: %v", err.Error()) } + return nil } diff --git a/internal/logic/auth/registerLimitLogic.go b/internal/logic/auth/registerLimitLogic.go index 048ef75..11ab3ca 100644 --- a/internal/logic/auth/registerLimitLogic.go +++ b/internal/logic/auth/registerLimitLogic.go @@ -16,6 +16,10 @@ func registerIpLimit(svcCtx *svc.ServiceContext, ctx context.Context, registerIp return true } + // Add timeout protection for Redis operations + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + // Use a sorted set to track IP registrations with timestamp as score // Key format: register:ip:{ip} key := fmt.Sprintf("%s%s", config.RegisterIpKeyPrefix, registerIp) diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index 7172753..287a39e 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -7,6 +7,7 @@ import ( "time" "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/user" @@ -126,22 +127,76 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * return err } - if l.svcCtx.Config.Register.EnableTrial { - // Active trial - var trialErr error - trialSubscribe, trialErr = l.activeTrial(userInfo.Id) - if trialErr != nil { - return trialErr - } - } return nil }) if err != nil { return nil, err } + // Activate trial subscription after transaction success (moved outside transaction to reduce lock time) + if l.svcCtx.Config.Register.EnableTrial { + trialSubscribe, err = l.activeTrial(userInfo.Id) + if err != nil { + l.Errorw("Failed to activate trial subscription", logger.Field("error", err.Error())) + // Don't fail registration if trial activation fails + } + } + // Clear cache after transaction success if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil { + // Trigger user group recalculation (runs in background) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + l.Debugf("Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + l.Errorw("Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + l.Debugf("Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svcCtx) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + l.Errorw("Failed to recalculate user group", + logger.Field("user_id", userInfo.Id), + logger.Field("error", err.Error()), + ) + return + } + + l.Infow("Successfully recalculated user group after registration", + logger.Field("user_id", userInfo.Id), + logger.Field("mode", groupMode), + ) + }() + // Clear user subscription cache if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil { l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id)) diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go index 7573d89..a619c88 100644 --- a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go @@ -91,25 +91,37 @@ func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (u return l.createExpiredServers(), nil } - subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + // Check if group management is enabled + var groupEnabled string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value").Scan(&groupEnabled).Error + if err != nil { - l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error())) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) + l.Debugw("[GetServers] Failed to check group enabled", logger.Field("error", err.Error())) + // Continue with tag-based filtering } - nodeIds := tool.StringToInt64Slice(subDetails.Nodes) - tags := strings.Split(subDetails.NodeTags, ",") - l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) + isGroupEnabled := (groupEnabled == "true" || groupEnabled == "1") - enable := true - - _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ - Page: 0, - Size: 1000, - NodeId: nodeIds, - Enabled: &enable, // Only get enabled nodes - }) + var nodes []*node.Node + if isGroupEnabled { + // Group mode: use group_ids to filter nodes + nodes, err = l.getNodesByGroup(userSub) + if err != nil { + l.Errorw("[GetServers] Failed to get nodes by group", logger.Field("error", err.Error())) + return nil, err + } + } else { + // Tag mode: use node_ids and tags to filter nodes + nodes, err = l.getNodesByTag(userSub) + if err != nil { + l.Errorw("[GetServers] Failed to get nodes by tag", logger.Field("error", err.Error())) + return nil, err + } + } + // Process nodes and create response if len(nodes) > 0 { var serverMapIds = make(map[int64]*node.Server) for _, n := range nodes { @@ -157,15 +169,100 @@ func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (u } l.Debugf("[Query Subscribe]found servers: %v", len(nodes)) - - if err != nil { - l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error())) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error()) - } - logger.Debugf("[Generate Subscribe]found servers: %v", len(nodes)) return userSubscribeNodes, nil } +// getNodesByGroup gets nodes based on user subscription node_group_id with priority fallback +func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscribe) ([]*node.Node, error) { + // 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + nodeGroupId := int64(0) + source := "" + + // 优先级1: user_subscribe.node_group_id + if userSub.NodeGroupId != 0 { + nodeGroupId = userSub.NodeGroupId + source = "user_subscribe.node_group_id" + } else { + // 优先级2 & 3: 从 subscribe 表获取 + subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + if err != nil { + l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error())) + return nil, err + } + + // 优先级2: subscribe.node_group_id + if subDetails.NodeGroupId != 0 { + nodeGroupId = subDetails.NodeGroupId + source = "subscribe.node_group_id" + } else if len(subDetails.NodeGroupIds) > 0 { + // 优先级3: subscribe.node_group_ids[0] + nodeGroupId = subDetails.NodeGroupIds[0] + source = "subscribe.node_group_ids[0]" + } + } + + // 如果所有优先级都没有获取到,返回空节点列表 + if nodeGroupId == 0 { + l.Debugw("[GetNodesByGroup] no node_group_id found in any priority, returning no nodes") + return []*node.Node{}, nil + } + + l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId) + + // Filter nodes by node_group_id + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{nodeGroupId}, // Filter by node_group_ids + Enabled: &enable, + }) + if err != nil { + l.Errorw("[GetNodesByGroup] FilterNodeList error", logger.Field("error", err.Error())) + return nil, err + } + + l.Debugf("[GetNodesByGroup] Found %d nodes for node_group_id=%d", len(nodes), nodeGroupId) + return nodes, nil +} + +// getNodesByTag gets nodes based on subscribe node_ids and tags +func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe) ([]*node.Node, error) { + subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + if err != nil { + l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) + } + + nodeIds := tool.StringToInt64Slice(subDetails.Nodes) + tags := strings.Split(subDetails.NodeTags, ",") + + l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) + + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeId: nodeIds, + Tag: tags, + Enabled: &enable, // Only get enabled nodes + }) + + return nodes, err +} + +// getAllNodes returns all enabled nodes +func (l *QueryUserSubscribeNodeListLogic) getAllNodes() ([]*node.Node, error) { + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + Enabled: &enable, + }) + + return nodes, err +} + func (l *QueryUserSubscribeNodeListLogic) isSubscriptionExpired(userSub *user.Subscribe) bool { return userSub.ExpireTime.Unix() < time.Now().Unix() && userSub.ExpireTime.Unix() != 0 } diff --git a/internal/logic/server/getServerUserListLogic.go b/internal/logic/server/getServerUserListLogic.go index 70ea51f..6d51326 100644 --- a/internal/logic/server/getServerUserListLogic.go +++ b/internal/logic/server/getServerUserListLogic.go @@ -55,6 +55,7 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR return nil, err } + // 查询该服务器上该协议的所有节点(包括属于节点组的节点) _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ Page: 1, Size: 1000, @@ -65,25 +66,74 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR l.Errorw("FilterNodeList error", logger.Field("error", err.Error())) return nil, err } - var nodeTag []string + + if len(nodes) == 0 { + return &types.GetServerUserListResponse{ + Users: []types.ServerUser{ + { + Id: 1, + UUID: uuidx.NewUUID().String(), + }, + }, + }, nil + } + + // 收集所有唯一的节点组 ID + nodeGroupMap := make(map[int64]bool) // nodeGroupId -> true var nodeIds []int64 + var nodeTags []string + for _, n := range nodes { nodeIds = append(nodeIds, n.Id) if n.Tags != "" { - nodeTag = append(nodeTag, strings.Split(n.Tags, ",")...) + nodeTags = append(nodeTags, strings.Split(n.Tags, ",")...) + } + // 收集节点组 ID + if len(n.NodeGroupIds) > 0 { + for _, gid := range n.NodeGroupIds { + if gid > 0 { + nodeGroupMap[gid] = true + } + } } } - _, subs, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ - Page: 1, - Size: 9999, - Node: nodeIds, - Tags: nodeTag, - }) - if err != nil { - l.Errorw("QuerySubscribeIdsByServerIdAndServerGroupId error", logger.Field("error", err.Error())) - return nil, err + // 获取所有节点组 ID + nodeGroupIds := make([]int64, 0, len(nodeGroupMap)) + for gid := range nodeGroupMap { + nodeGroupIds = append(nodeGroupIds, gid) } + + // 查询订阅: + // 1. 如果有节点组,查询匹配这些节点组的订阅 + // 2. 如果没有节点组,查询使用节点 ID 或 tags 的订阅 + var subs []*subscribe.Subscribe + if len(nodeGroupIds) > 0 { + // 节点组模式:查询 node_group_id 或 node_group_ids 匹配的订阅 + _, subs, err = l.svcCtx.SubscribeModel.FilterListByNodeGroups(l.ctx, &subscribe.FilterByNodeGroupsParams{ + Page: 1, + Size: 9999, + NodeGroupIds: nodeGroupIds, + }) + if err != nil { + l.Errorw("FilterListByNodeGroups error", logger.Field("error", err.Error())) + return nil, err + } + } else { + // 传统模式:查询匹配节点 ID 或 tags 的订阅 + nodeTags = tool.RemoveDuplicateElements(nodeTags...) + _, subs, err = l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ + Page: 1, + Size: 9999, + Node: nodeIds, + Tags: nodeTags, + }) + if err != nil { + l.Errorw("FilterList error", logger.Field("error", err.Error())) + return nil, err + } + } + if len(subs) == 0 { return &types.GetServerUserListResponse{ Users: []types.ServerUser{ diff --git a/internal/logic/subscribe/subscribeLogic.go b/internal/logic/subscribe/subscribeLogic.go index 28f9ecb..823f377 100644 --- a/internal/logic/subscribe/subscribeLogic.go +++ b/internal/logic/subscribe/subscribeLogic.go @@ -215,14 +215,133 @@ func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, erro return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) } + // 判断是否使用分组模式 + isGroupMode := l.isGroupEnabled() + + if isGroupMode { + // === 分组模式:使用 node_group_id 获取节点 === + // 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + nodeGroupId := int64(0) + source := "" + + // 优先级1: user_subscribe.node_group_id + if userSub.NodeGroupId != 0 { + nodeGroupId = userSub.NodeGroupId + source = "user_subscribe.node_group_id" + } else { + // 优先级2 & 3: 从 subscribe 表获取 + if subDetails.NodeGroupId != 0 { + nodeGroupId = subDetails.NodeGroupId + source = "subscribe.node_group_id" + } else if len(subDetails.NodeGroupIds) > 0 { + // 优先级3: subscribe.node_group_ids[0] + nodeGroupId = subDetails.NodeGroupIds[0] + source = "subscribe.node_group_ids[0]" + } + } + + l.Debugf("[Generate Subscribe]group mode, using %s: %v", source, nodeGroupId) + + // 根据 node_group_id 获取节点 + enable := true + + // 1. 获取分组节点 + var groupNodes []*node.Node + if nodeGroupId > 0 { + params := &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{nodeGroupId}, + Enabled: &enable, + Preload: true, + } + _, groupNodes, err = l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), params) + + if err != nil { + l.Errorw("[Generate Subscribe]filter nodes by group error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter nodes by group error: %v", err.Error()) + } + l.Debugf("[Generate Subscribe]found %d nodes for node_group_id=%d", len(groupNodes), nodeGroupId) + } + + // 2. 获取公共节点(NodeGroupIds 为空的节点) + _, allNodes, err := l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ + Page: 0, + Size: 1000, + Enabled: &enable, + Preload: true, + }) + + if err != nil { + l.Errorw("[Generate Subscribe]filter all nodes error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter all nodes error: %v", err.Error()) + } + + // 过滤出公共节点 + var publicNodes []*node.Node + for _, n := range allNodes { + if len(n.NodeGroupIds) == 0 { + publicNodes = append(publicNodes, n) + } + } + l.Debugf("[Generate Subscribe]found %d public nodes (node_group_ids is empty)", len(publicNodes)) + + // 3. 合并分组节点和公共节点 + nodesMap := make(map[int64]*node.Node) + for _, n := range groupNodes { + nodesMap[n.Id] = n + } + for _, n := range publicNodes { + if _, exists := nodesMap[n.Id]; !exists { + nodesMap[n.Id] = n + } + } + + // 转换为切片 + var result []*node.Node + for _, n := range nodesMap { + result = append(result, n) + } + + l.Debugf("[Generate Subscribe]total nodes (group + public): %d (group: %d, public: %d)", len(result), len(groupNodes), len(publicNodes)) + + // 查询节点组信息,获取节点组名称(仅当用户有分组时) + if nodeGroupId > 0 { + type NodeGroupInfo struct { + Id int64 + Name string + } + var nodeGroupInfo NodeGroupInfo + err = l.svc.DB.Table("node_group").Select("id, name").Where("id = ?", nodeGroupId).First(&nodeGroupInfo).Error + if err != nil { + l.Infow("[Generate Subscribe]node group not found", logger.Field("nodeGroupId", nodeGroupId), logger.Field("error", err.Error())) + } + + // 如果节点组信息存在,为没有 tag 的分组节点设置节点组名称为 tag + if nodeGroupInfo.Id != 0 && nodeGroupInfo.Name != "" { + for _, n := range result { + // 只为分组节点设置 tag,公共节点不设置 + if n.Tags == "" && len(n.NodeGroupIds) > 0 { + n.Tags = nodeGroupInfo.Name + l.Debugf("[Generate Subscribe]set node_group name as tag for node %d: %s", n.Id, nodeGroupInfo.Name) + } + } + } + } + + return result, nil + } + + // === 标签模式:使用 node_ids 和 tags 获取节点 === nodeIds := tool.StringToInt64Slice(subDetails.Nodes) tags := tool.RemoveStringElement(strings.Split(subDetails.NodeTags, ","), "") - l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", len(nodeIds), len(tags)) + l.Debugf("[Generate Subscribe]tag mode, nodes: %v, NodeTags: %v", len(nodeIds), len(tags)) if len(nodeIds) == 0 && len(tags) == 0 { - logger.Infow("[Generate Subscribe]no subscribe nodes") + logger.Infow("[Generate Subscribe]no subscribe nodes configured") return []*node.Node{}, nil } + enable := true var nodes []*node.Node _, nodes, err = l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ @@ -231,16 +350,15 @@ func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, erro NodeId: nodeIds, Tag: tool.RemoveDuplicateElements(tags...), Preload: true, - Enabled: &enable, // Only get enabled nodes + Enabled: &enable, }) - l.Debugf("[Query Subscribe]found servers: %v", len(nodes)) - if err != nil { l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error()) } - logger.Debugf("[Generate Subscribe]found servers: %v", len(nodes)) + + l.Debugf("[Generate Subscribe]found %d nodes in tag mode", len(nodes)) return nodes, nil } @@ -290,3 +408,17 @@ func (l *SubscribeLogic) getFirstHostLine() string { } return host } + +// isGroupEnabled 判断分组功能是否启用 +func (l *SubscribeLogic) isGroupEnabled() bool { + var value string + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&value).Error + if err != nil { + l.Debugf("[SubscribeLogic]check group enabled failed: %v", err) + return false + } + return value == "true" || value == "1" +} diff --git a/internal/model/group/history.go b/internal/model/group/history.go new file mode 100644 index 0000000..5ab1b0b --- /dev/null +++ b/internal/model/group/history.go @@ -0,0 +1,54 @@ +package group + +import ( + "time" + + "gorm.io/gorm" +) + +// GroupHistory 分组历史记录模型 +type GroupHistory struct { + Id int64 `gorm:"primaryKey"` + GroupMode string `gorm:"type:varchar(50);not null;index:idx_group_mode;comment:Group Mode: average/subscribe/traffic"` + TriggerType string `gorm:"type:varchar(50);not null;index:idx_trigger_type;comment:Trigger Type: manual/auto/schedule"` + State string `gorm:"type:varchar(50);not null;index:idx_state;comment:State: pending/running/completed/failed"` + TotalUsers int `gorm:"default:0;not null;comment:Total Users"` + SuccessCount int `gorm:"default:0;not null;comment:Success Count"` + FailedCount int `gorm:"default:0;not null;comment:Failed Count"` + StartTime *time.Time `gorm:"comment:Start Time"` + EndTime *time.Time `gorm:"comment:End Time"` + Operator string `gorm:"type:varchar(100);comment:Operator"` + ErrorMessage string `gorm:"type:TEXT;comment:Error Message"` + CreatedAt time.Time `gorm:"<-:create;index:idx_created_at;comment:Create Time"` +} + +// TableName 指定表名 +func (*GroupHistory) TableName() string { + return "group_history" +} + +// BeforeCreate GORM hook - 创建前回调 +func (gh *GroupHistory) BeforeCreate(tx *gorm.DB) error { + return nil +} + +// GroupHistoryDetail 分组历史详情模型 +type GroupHistoryDetail struct { + Id int64 `gorm:"primaryKey"` + HistoryId int64 `gorm:"not null;index:idx_history_id;comment:History ID"` + NodeGroupId int64 `gorm:"not null;index:idx_node_group_id;comment:Node Group ID"` + UserCount int `gorm:"default:0;not null;comment:User Count"` + NodeCount int `gorm:"default:0;not null;comment:Node Count"` + UserData string `gorm:"type:text;comment:User data JSON (id and email/phone)"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` +} + +// TableName 指定表名 +func (*GroupHistoryDetail) TableName() string { + return "group_history_detail" +} + +// BeforeCreate GORM hook - 创建前回调 +func (ghd *GroupHistoryDetail) BeforeCreate(tx *gorm.DB) error { + return nil +} diff --git a/internal/model/group/model.go b/internal/model/group/model.go new file mode 100644 index 0000000..77d5e1b --- /dev/null +++ b/internal/model/group/model.go @@ -0,0 +1,14 @@ +package group + +import ( + "gorm.io/gorm" +) + +// AutoMigrate 自动迁移数据库表 +func AutoMigrate(db *gorm.DB) error { + return db.AutoMigrate( + &NodeGroup{}, + &GroupHistory{}, + &GroupHistoryDetail{}, + ) +} diff --git a/internal/model/group/node_group.go b/internal/model/group/node_group.go new file mode 100644 index 0000000..644580a --- /dev/null +++ b/internal/model/group/node_group.go @@ -0,0 +1,30 @@ +package group + +import ( + "time" + + "gorm.io/gorm" +) + +// NodeGroup 节点组模型 +type NodeGroup struct { + Id int64 `gorm:"primaryKey"` + Name string `gorm:"type:varchar(255);not null;comment:Name"` + Description string `gorm:"type:varchar(500);comment:Description"` + Sort int `gorm:"default:0;index:idx_sort;comment:Sort Order"` + ForCalculation *bool `gorm:"default:true;not null;comment:For Calculation: whether this node group participates in grouping calculation"` + MinTrafficGB *int64 `gorm:"default:0;comment:Minimum Traffic (GB) for this node group"` + MaxTrafficGB *int64 `gorm:"default:0;comment:Maximum Traffic (GB) for this node group"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` +} + +// TableName 指定表名 +func (*NodeGroup) TableName() string { + return "node_group" +} + +// BeforeCreate GORM hook - 创建前回调 +func (ng *NodeGroup) BeforeCreate(tx *gorm.DB) error { + return nil +} diff --git a/internal/model/node/model.go b/internal/model/node/model.go index f5d9eb2..0a769b9 100644 --- a/internal/model/node/model.go +++ b/internal/model/node/model.go @@ -33,15 +33,16 @@ type FilterParams struct { } type FilterNodeParams struct { - Page int // Page Number - Size int // Page Size - NodeId []int64 // Node IDs - ServerId []int64 // Server IDs - Tag []string // Tags - Search string // Search Address or Name - Protocol string // Protocol - Preload bool // Preload Server - Enabled *bool // Enabled + Page int // Page Number + Size int // Page Size + NodeId []int64 // Node IDs + ServerId []int64 // Server IDs + Tag []string // Tags + NodeGroupIds []int64 // Node Group IDs + Search string // Search Address or Name + Protocol string // Protocol + Preload bool // Preload Server + Enabled *bool // Enabled } // FilterServerList Filter Server List @@ -96,6 +97,18 @@ func (m *customServerModel) FilterNodeList(ctx context.Context, params *FilterNo if len(params.Tag) > 0 { query = query.Scopes(InSet("tags", params.Tag)) } + if len(params.NodeGroupIds) > 0 { + // Filter by node_group_ids using JSON_CONTAINS for each group ID + // Multiple group IDs: node must belong to at least one of the groups + var conditions []string + for _, gid := range params.NodeGroupIds { + conditions = append(conditions, fmt.Sprintf("JSON_CONTAINS(node_group_ids, %d)", gid)) + } + if len(conditions) > 0 { + query = query.Where("(" + strings.Join(conditions, " OR ") + ")") + } + } + // If no NodeGroupIds specified, return all nodes (including public nodes) if params.Protocol != "" { query = query.Where("protocol = ?", params.Protocol) } diff --git a/internal/model/node/node.go b/internal/model/node/node.go index 89d665d..787ea32 100644 --- a/internal/model/node/node.go +++ b/internal/model/node/node.go @@ -1,25 +1,73 @@ package node import ( + "database/sql/driver" + "encoding/json" "time" "github.com/perfect-panel/server/pkg/logger" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type Node struct { - Id int64 `gorm:"primary_key"` - Name string `gorm:"type:varchar(100);not null;default:'';comment:Node Name"` - Tags string `gorm:"type:varchar(255);not null;default:'';comment:Tags"` - Port uint16 `gorm:"not null;default:0;comment:Connect Port"` - Address string `gorm:"type:varchar(255);not null;default:'';comment:Connect Address"` - ServerId int64 `gorm:"not null;default:0;comment:Server ID"` - Server *Server `gorm:"foreignKey:ServerId;references:Id"` - Protocol string `gorm:"type:varchar(100);not null;default:'';comment:Protocol"` - Enabled *bool `gorm:"type:boolean;not null;default:true;comment:Enabled"` - Sort int `gorm:"uniqueIndex;not null;default:0;comment:Sort"` - CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primary_key"` + Name string `gorm:"type:varchar(100);not null;default:'';comment:Node Name"` + Tags string `gorm:"type:varchar(255);not null;default:'';comment:Tags"` + Port uint16 `gorm:"not null;default:0;comment:Connect Port"` + Address string `gorm:"type:varchar(255);not null;default:'';comment:Connect Address"` + ServerId int64 `gorm:"not null;default:0;comment:Server ID"` + Server *Server `gorm:"foreignKey:ServerId;references:Id"` + Protocol string `gorm:"type:varchar(100);not null;default:'';comment:Protocol"` + Enabled *bool `gorm:"type:boolean;not null;default:true;comment:Enabled"` + Sort int `gorm:"uniqueIndex;not null;default:0;comment:Sort"` + NodeGroupIds JSONInt64Slice `gorm:"type:json;comment:Node Group IDs (JSON array, multiple groups)"` + CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } func (n *Node) TableName() string { diff --git a/internal/model/subscribe/model.go b/internal/model/subscribe/model.go index 9942046..06c0c6d 100644 --- a/internal/model/subscribe/model.go +++ b/internal/model/subscribe/model.go @@ -2,6 +2,7 @@ package subscribe import ( "context" + "strings" "github.com/perfect-panel/server/pkg/tool" "github.com/redis/go-redis/v9" @@ -19,6 +20,13 @@ type FilterParams struct { Language string // Language DefaultLanguage bool // Default Subscribe Language Data Search string // Search Keywords + NodeGroupId *int64 // Node Group ID +} + +type FilterByNodeGroupsParams struct { + Page int // Page Number + Size int // Page Size + NodeGroupIds []int64 // Node Group IDs (multiple) } func (p *FilterParams) Normalize() { @@ -32,6 +40,7 @@ func (p *FilterParams) Normalize() { type customSubscribeLogicModel interface { FilterList(ctx context.Context, params *FilterParams) (int64, []*Subscribe, error) + FilterListByNodeGroups(ctx context.Context, params *FilterByNodeGroupsParams) (int64, []*Subscribe, error) ClearCache(ctx context.Context, id ...int64) error QuerySubscribeMinSortByIds(ctx context.Context, ids []int64) (int64, error) } @@ -102,6 +111,10 @@ func (m *customSubscribeModel) FilterList(ctx context.Context, params *FilterPar if len(params.Tags) > 0 { query = query.Scopes(InSet("node_tags", params.Tags)) } + if params.NodeGroupId != nil { + // Filter by node_group_ids using JSON_CONTAINS + query = query.Where("JSON_CONTAINS(node_group_ids, ?)", *params.NodeGroupId) + } if lang != "" { query = query.Where("language = ?", lang) } else if params.DefaultLanguage { @@ -154,3 +167,67 @@ func InSet(field string, values []string) func(db *gorm.DB) *gorm.DB { return query } } + +// FilterListByNodeGroups Filter subscribes by node groups +// Match if subscribe's node_group_id OR node_group_ids contains any of the provided node group IDs +func (m *customSubscribeModel) FilterListByNodeGroups(ctx context.Context, params *FilterByNodeGroupsParams) (int64, []*Subscribe, error) { + if params == nil { + params = &FilterByNodeGroupsParams{ + Page: 1, + Size: 10, + } + } + if params.Page <= 0 { + params.Page = 1 + } + if params.Size <= 0 { + params.Size = 10 + } + + var list []*Subscribe + var total int64 + + err := m.QueryNoCacheCtx(ctx, &list, func(conn *gorm.DB, v interface{}) error { + query := conn.Model(&Subscribe{}) + + // Filter by node groups: match if node_group_id or node_group_ids contains any of the provided IDs + if len(params.NodeGroupIds) > 0 { + var conditions []string + var args []interface{} + + // Condition 1: node_group_id IN (...) + placeholders := make([]string, len(params.NodeGroupIds)) + for i, id := range params.NodeGroupIds { + placeholders[i] = "?" + args = append(args, id) + } + conditions = append(conditions, "node_group_id IN ("+strings.Join(placeholders, ",")+")") + + // Condition 2: JSON_CONTAINS(node_group_ids, id) for each id + for _, id := range params.NodeGroupIds { + conditions = append(conditions, "JSON_CONTAINS(node_group_ids, ?)") + args = append(args, id) + } + + // Combine with OR: (node_group_id IN (...) OR JSON_CONTAINS(node_group_ids, id1) OR ...) + query = query.Where("("+strings.Join(conditions, " OR ")+")", args...) + } + + // Count total + if err := query.Count(&total).Error; err != nil { + return err + } + + // Find with pagination + return query.Order("sort ASC"). + Limit(params.Size). + Offset((params.Page - 1) * params.Size). + Find(v).Error + }) + + if err != nil { + return 0, nil, err + } + + return total, list, nil +} diff --git a/internal/model/subscribe/subscribe.go b/internal/model/subscribe/subscribe.go index c9c1046..af58598 100644 --- a/internal/model/subscribe/subscribe.go +++ b/internal/model/subscribe/subscribe.go @@ -1,37 +1,86 @@ package subscribe import ( + "database/sql/driver" + "encoding/json" "time" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type Subscribe struct { - Id int64 `gorm:"primaryKey"` - Name string `gorm:"type:varchar(255);not null;default:'';comment:Subscribe Name"` - Language string `gorm:"type:varchar(255);not null;default:'';comment:Language"` - Description string `gorm:"type:text;comment:Subscribe Description"` - UnitPrice int64 `gorm:"type:int;not null;default:0;comment:Unit Price"` - UnitTime string `gorm:"type:varchar(255);not null;default:'';comment:Unit Time"` - Discount string `gorm:"type:text;comment:Discount"` - Replacement int64 `gorm:"type:int;not null;default:0;comment:Replacement"` - Inventory int64 `gorm:"type:int;not null;default:-1;comment:Inventory"` - Traffic int64 `gorm:"type:int;not null;default:0;comment:Traffic"` - SpeedLimit int64 `gorm:"type:int;not null;default:0;comment:Speed Limit"` - DeviceLimit int64 `gorm:"type:int;not null;default:0;comment:Device Limit"` - Quota int64 `gorm:"type:int;not null;default:0;comment:Quota"` - Nodes string `gorm:"type:varchar(255);comment:Node Ids"` - NodeTags string `gorm:"type:varchar(255);comment:Node Tags"` - Show *bool `gorm:"type:tinyint(1);not null;default:0;comment:Show portal page"` - Sell *bool `gorm:"type:tinyint(1);not null;default:0;comment:Sell"` - Sort int64 `gorm:"type:int;not null;default:0;comment:Sort"` - DeductionRatio int64 `gorm:"type:int;default:0;comment:Deduction Ratio"` - AllowDeduction *bool `gorm:"type:tinyint(1);default:1;comment:Allow deduction"` - ResetCycle int64 `gorm:"type:int;default:0;comment:Reset Cycle: 0: No Reset, 1: 1st, 2: Monthly, 3: Yearly"` - RenewalReset *bool `gorm:"type:tinyint(1);default:0;comment:Renew Reset"` - ShowOriginalPrice bool `gorm:"type:tinyint(1);not null;default:1;comment:Show Original Price"` - CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primaryKey"` + Name string `gorm:"type:varchar(255);not null;default:'';comment:Subscribe Name"` + Language string `gorm:"type:varchar(255);not null;default:'';comment:Language"` + Description string `gorm:"type:text;comment:Subscribe Description"` + UnitPrice int64 `gorm:"type:int;not null;default:0;comment:Unit Price"` + UnitTime string `gorm:"type:varchar(255);not null;default:'';comment:Unit Time"` + Discount string `gorm:"type:text;comment:Discount"` + Replacement int64 `gorm:"type:int;not null;default:0;comment:Replacement"` + Inventory int64 `gorm:"type:int;not null;default:-1;comment:Inventory"` + Traffic int64 `gorm:"type:int;not null;default:0;comment:Traffic"` + SpeedLimit int64 `gorm:"type:int;not null;default:0;comment:Speed Limit"` + DeviceLimit int64 `gorm:"type:int;not null;default:0;comment:Device Limit"` + Quota int64 `gorm:"type:int;not null;default:0;comment:Quota"` + Nodes string `gorm:"type:varchar(255);comment:Node Ids"` + NodeTags string `gorm:"type:varchar(255);comment:Node Tags"` + NodeGroupIds JSONInt64Slice `gorm:"type:json;comment:Node Group IDs (JSON array, multiple groups)"` + NodeGroupId int64 `gorm:"default:0;index:idx_node_group_id;comment:Default Node Group ID (single ID)"` + Show *bool `gorm:"type:tinyint(1);not null;default:0;comment:Show portal page"` + Sell *bool `gorm:"type:tinyint(1);not null;default:0;comment:Sell"` + Sort int64 `gorm:"type:int;not null;default:0;comment:Sort"` + DeductionRatio int64 `gorm:"type:int;default:0;comment:Deduction Ratio"` + AllowDeduction *bool `gorm:"type:tinyint(1);default:1;comment:Allow deduction"` + ResetCycle int64 `gorm:"type:int;default:0;comment:Reset Cycle: 0: No Reset, 1: 1st, 2: Monthly, 3: Yearly"` + RenewalReset *bool `gorm:"type:tinyint(1);default:0;comment:Renew Reset"` + ShowOriginalPrice bool `gorm:"type:tinyint(1);not null;default:1;comment:Show Original Price"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } func (*Subscribe) TableName() string { diff --git a/internal/model/user/model.go b/internal/model/user/model.go index 944869c..0cf3502 100644 --- a/internal/model/user/model.go +++ b/internal/model/user/model.go @@ -27,6 +27,7 @@ type SubscribeDetails struct { OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` Subscribe *subscribe.Subscribe `gorm:"foreignKey:SubscribeId;references:Id"` + NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` diff --git a/internal/model/user/user.go b/internal/model/user/user.go index 9077c87..3976468 100644 --- a/internal/model/user/user.go +++ b/internal/model/user/user.go @@ -1,11 +1,58 @@ package user import ( + "database/sql/driver" + "encoding/json" "time" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type User struct { Id int64 `gorm:"primaryKey"` Password string `gorm:"type:varchar(100);not null;comment:User Password"` @@ -43,6 +90,8 @@ type Subscribe struct { User User `gorm:"foreignKey:UserId;references:Id"` OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` + NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` + GroupLocked *bool `gorm:"type:tinyint(1);not null;default:0;comment:Group Locked"` StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` diff --git a/internal/types/types.go b/internal/types/types.go index 091e0b5..219ec50 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -320,14 +320,24 @@ type CreateDocumentRequest struct { Show *bool `json:"show"` } +type CreateNodeGroupRequest struct { + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` +} + type CreateNodeRequest struct { - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } type CreateOrderRequest struct { @@ -420,6 +430,8 @@ type CreateSubscribeRequest struct { Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -427,6 +439,7 @@ type CreateSubscribeRequest struct { ResetCycle int64 `json:"reset_cycle"` RenewalReset *bool `json:"renewal_reset"` ShowOriginalPrice bool `json:"show_original_price"` + AutoCreateGroup bool `json:"auto_create_group"` } type CreateTicketFollowRequest struct { @@ -505,6 +518,10 @@ type DeleteDocumentRequest struct { Id int64 `json:"id" validate:"required"` } +type DeleteNodeGroupRequest struct { + Id int64 `json:"id" validate:"required"` +} + type DeleteNodeRequest struct { Id int64 `json:"id"` } @@ -600,6 +617,10 @@ type EmailAuthticateConfig struct { DomainSuffixList string `json:"domain_suffix_list"` } +type ExportGroupResultRequest struct { + HistoryId *int64 `form:"history_id,omitempty"` +} + type FilterBalanceLogRequest struct { FilterLogParams UserId int64 `form:"user_id,optional"` @@ -658,9 +679,10 @@ type FilterMobileLogResponse struct { } type FilterNodeListRequest struct { - Page int `form:"page"` - Size int `form:"size"` - Search string `form:"search,omitempty"` + Page int `form:"page"` + Size int `form:"size"` + Search string `form:"search,omitempty"` + NodeGroupId *int64 `form:"node_group_id,omitempty"` } type FilterNodeListResponse struct { @@ -884,6 +906,37 @@ type GetGlobalConfigResponse struct { WebAd bool `json:"web_ad"` } +type GetGroupConfigRequest struct { + Keys []string `form:"keys,omitempty"` +} + +type GetGroupConfigResponse struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` + State RecalculationState `json:"state"` +} + +type GetGroupHistoryDetailRequest struct { + Id int64 `form:"id" validate:"required"` +} + +type GetGroupHistoryDetailResponse struct { + GroupHistoryDetail +} + +type GetGroupHistoryRequest struct { + Page int `form:"page"` + Size int `form:"size"` + GroupMode string `form:"group_mode,omitempty"` + TriggerType string `form:"trigger_type,omitempty"` +} + +type GetGroupHistoryResponse struct { + Total int64 `json:"total"` + List []GroupHistory `json:"list"` +} + type GetLoginLogRequest struct { Page int `form:"page"` Size int `form:"size"` @@ -906,6 +959,17 @@ type GetMessageLogListResponse struct { List []MessageLog `json:"list"` } +type GetNodeGroupListRequest struct { + Page int `form:"page"` + Size int `form:"size"` + GroupId string `form:"group_id,omitempty"` +} + +type GetNodeGroupListResponse struct { + Total int64 `json:"total"` + List []NodeGroup `json:"list"` +} + type GetNodeMultiplierResponse struct { Periods []TimePeriod `json:"periods"` } @@ -1033,11 +1097,19 @@ type GetSubscribeGroupListResponse struct { Total int64 `json:"total"` } +type GetSubscribeGroupMappingRequest struct { +} + +type GetSubscribeGroupMappingResponse struct { + List []SubscribeGroupMappingItem `json:"list"` +} + type GetSubscribeListRequest struct { - Page int64 `form:"page" validate:"required"` - Size int64 `form:"size" validate:"required"` - Language string `form:"language,omitempty"` - Search string `form:"search,omitempty"` + Page int64 `form:"page" validate:"required"` + Size int64 `form:"size" validate:"required"` + Language string `form:"language,omitempty"` + Search string `form:"search,omitempty"` + NodeGroupId int64 `form:"node_group_id,omitempty"` } type GetSubscribeListResponse struct { @@ -1215,6 +1287,25 @@ type GoogleLoginCallbackRequest struct { State string `form:"state"` } +type GroupHistory struct { + Id int64 `json:"id"` + GroupMode string `json:"group_mode"` + TriggerType string `json:"trigger_type"` + TotalUsers int `json:"total_users"` + SuccessCount int `json:"success_count"` + FailedCount int `json:"failed_count"` + StartTime *int64 `json:"start_time,omitempty"` + EndTime *int64 `json:"end_time,omitempty"` + Operator string `json:"operator,omitempty"` + ErrorLog string `json:"error_log,omitempty"` + CreatedAt int64 `json:"created_at"` +} + +type GroupHistoryDetail struct { + GroupHistory + ConfigSnapshot map[string]interface{} `json:"config_snapshot,omitempty"` +} + type HasMigrateSeverNodeResponse struct { HasMigrate bool `json:"has_migrate"` } @@ -1295,17 +1386,19 @@ type ModuleConfig struct { } type Node struct { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` - Sort int `json:"sort,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + Sort int `json:"sort,omitempty"` + NodeGroupId int64 `json:"node_group_id,omitempty"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type NodeConfig struct { @@ -1325,6 +1418,25 @@ type NodeDNS struct { Domains []string `json:"domains"` } +type NodeGroup struct { + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type NodeGroupItem struct { + Id int64 `json:"id"` + Name string `json:"name"` + Nodes []Node `json:"nodes"` +} + type NodeOutbound struct { Name string `json:"name"` Protocol string `json:"protocol"` @@ -1535,6 +1647,15 @@ type PreviewSubscribeTemplateResponse struct { Template string `json:"template"` // 预览的模板内容 } +type PreviewUserNodesRequest struct { + UserId int64 `form:"user_id" validate:"required"` +} + +type PreviewUserNodesResponse struct { + UserId int64 `json:"user_id"` + NodeGroups []NodeGroupItem `json:"node_groups"` +} + type PrivacyPolicyConfig struct { PrivacyPolicy string `json:"privacy_policy"` } @@ -1813,6 +1934,17 @@ type QuotaTask struct { UpdatedAt int64 `json:"updated_at"` } +type RecalculateGroupRequest struct { + Mode string `json:"mode" validate:"required"` + TriggerType string `json:"trigger_type"` // "manual" or "scheduled" +} + +type RecalculationState struct { + State string `json:"state"` + Progress int `json:"progress"` + Total int `json:"total"` +} + type RechargeOrderRequest struct { Amount int64 `json:"amount" validate:"required,gt=0,lte=2000000000"` Payment int64 `json:"payment"` @@ -1890,6 +2022,10 @@ type ResetAllSubscribeTokenResponse struct { Success bool `json:"success"` } +type ResetGroupsRequest struct { + Confirm bool `json:"confirm" validate:"required"` +} + type ResetPasswordRequest struct { Identifier string `json:"identifier"` Email string `json:"email" validate:"required"` @@ -2153,6 +2289,8 @@ type Subscribe struct { Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -2212,6 +2350,11 @@ type SubscribeGroup struct { UpdatedAt int64 `json:"updated_at"` } +type SubscribeGroupMappingItem struct { + SubscribeName string `json:"subscribe_name"` + NodeGroupName string `json:"node_group_name"` +} + type SubscribeItem struct { Subscribe Sold int64 `json:"sold"` @@ -2465,15 +2608,32 @@ type UpdateDocumentRequest struct { Show *bool `json:"show"` } +type UpdateGroupConfigRequest struct { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` +} + +type UpdateNodeGroupRequest struct { + Id int64 `json:"id" validate:"required"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` +} + type UpdateNodeRequest struct { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } type UpdateOrderStatusRequest struct { @@ -2551,6 +2711,8 @@ type UpdateSubscribeRequest struct { Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` diff --git a/pkg/turnstile/service.go b/pkg/turnstile/service.go index e3be9a7..9af71e5 100644 --- a/pkg/turnstile/service.go +++ b/pkg/turnstile/service.go @@ -56,7 +56,9 @@ func (s *service) verify(ctx context.Context, secret string, token string, ip st _ = writer.WriteField("idempotency_key", key) } _ = writer.Close() - client := &http.Client{} + client := &http.Client{ + Timeout: 5 * time.Second, + } req, _ := http.NewRequest("POST", s.url, body) req.Header.Set("Content-Type", writer.FormDataContentType()) firstResult, err := client.Do(req) diff --git a/ppanel.api b/ppanel.api index 3e6f0d9..b142e91 100644 --- a/ppanel.api +++ b/ppanel.api @@ -30,6 +30,7 @@ import ( "apis/admin/ads.api" "apis/admin/marketing.api" "apis/admin/application.api" + "apis/admin/group.api" "apis/public/user.api" "apis/public/subscribe.api" "apis/public/redemption.api" diff --git a/queue/handler/routes.go b/queue/handler/routes.go index 2e96219..c3df06b 100644 --- a/queue/handler/routes.go +++ b/queue/handler/routes.go @@ -3,6 +3,7 @@ package handler import ( "github.com/hibiken/asynq" "github.com/perfect-panel/server/internal/svc" + groupLogic "github.com/perfect-panel/server/queue/logic/group" orderLogic "github.com/perfect-panel/server/queue/logic/order" smslogic "github.com/perfect-panel/server/queue/logic/sms" "github.com/perfect-panel/server/queue/logic/subscription" @@ -43,4 +44,7 @@ func RegisterHandlers(mux *asynq.ServeMux, serverCtx *svc.ServiceContext) { // ForthwithQuotaTask mux.Handle(types.ForthwithQuotaTask, task.NewQuotaTaskLogic(serverCtx)) + + // SchedulerRecalculateGroup + mux.Handle(types.SchedulerRecalculateGroup, groupLogic.NewRecalculateGroupLogic(serverCtx)) } diff --git a/queue/logic/group/recalculateGroupLogic.go b/queue/logic/group/recalculateGroupLogic.go new file mode 100644 index 0000000..91b3685 --- /dev/null +++ b/queue/logic/group/recalculateGroupLogic.go @@ -0,0 +1,87 @@ +package group + +import ( + "context" + "time" + + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + + "github.com/hibiken/asynq" +) + +type RecalculateGroupLogic struct { + svc *svc.ServiceContext +} + +func NewRecalculateGroupLogic(svc *svc.ServiceContext) *RecalculateGroupLogic { + return &RecalculateGroupLogic{ + svc: svc, + } +} + +func (l *RecalculateGroupLogic) ProcessTask(ctx context.Context, t *asynq.Task) error { + logger.Infof("[RecalculateGroup] Starting scheduled group recalculation: %s", time.Now().Format("2006-01-02 15:04:05")) + + // 1. Check if group management is enabled + var enabledConfig struct { + Value string `gorm:"column:value"` + } + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + First(&enabledConfig).Error + if err != nil { + logger.Errorw("[RecalculateGroup] Failed to read group enabled config", logger.Field("error", err.Error())) + return err + } + + // If not enabled, skip execution + if enabledConfig.Value != "true" && enabledConfig.Value != "1" { + logger.Debugf("[RecalculateGroup] Group management is not enabled, skipping") + return nil + } + + // 2. Get grouping mode + var modeConfig struct { + Value string `gorm:"column:value"` + } + err = l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + First(&modeConfig).Error + if err != nil { + logger.Errorw("[RecalculateGroup] Failed to read group mode config", logger.Field("error", err.Error())) + return err + } + + mode := modeConfig.Value + if mode == "" { + mode = "average" // default mode + } + + // 3. Only execute if mode is "traffic" + if mode != "traffic" { + logger.Debugf("[RecalculateGroup] Group mode is not 'traffic' (current: %s), skipping", mode) + return nil + } + + // 4. Execute traffic-based grouping + logger.Infof("[RecalculateGroup] Executing traffic-based grouping") + + logic := group.NewRecalculateGroupLogic(ctx, l.svc) + req := &types.RecalculateGroupRequest{ + Mode: "traffic", + TriggerType: "scheduled", + } + + if err := logic.RecalculateGroup(req); err != nil { + logger.Errorw("[RecalculateGroup] Failed to execute traffic grouping", logger.Field("error", err.Error())) + return err + } + + logger.Infof("[RecalculateGroup] Successfully completed traffic-based grouping: %s", time.Now().Format("2006-01-02 15:04:05")) + return nil +} diff --git a/queue/logic/order/activateOrderLogic.go b/queue/logic/order/activateOrderLogic.go index 32032cf..24a03e2 100644 --- a/queue/logic/order/activateOrderLogic.go +++ b/queue/logic/order/activateOrderLogic.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" @@ -22,9 +23,10 @@ import ( "github.com/perfect-panel/server/internal/model/subscribe" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/uuidx" - "github.com/perfect-panel/server/queue/types" + queueTypes "github.com/perfect-panel/server/queue/types" "gorm.io/gorm" ) @@ -93,8 +95,8 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) } // parsePayload unMarshals the task payload into a structured format -func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (*types.ForthwithActivateOrderPayload, error) { - var p types.ForthwithActivateOrderPayload +func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (*queueTypes.ForthwithActivateOrderPayload, error) { + var p queueTypes.ForthwithActivateOrderPayload if err := json.Unmarshal(payload, &p); err != nil { logger.WithContext(ctx).Error("[ActivateOrderLogic] Unmarshal payload failed", logger.Field("error", err.Error()), @@ -196,6 +198,9 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O return err } + // Trigger user group recalculation (runs in background) + l.triggerUserGroupRecalculation(ctx, userInfo.Id) + // Handle commission in separate goroutine to avoid blocking go l.handleCommission(context.Background(), userInfo, orderInfo) @@ -357,6 +362,7 @@ func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderIn Token: uuidx.SubscribeToken(orderInfo.OrderNo), UUID: uuid.New().String(), Status: 1, + NodeGroupId: sub.NodeGroupId, // Inherit node_group_id from subscription plan } // Check quota limit before creating subscription (final safeguard) @@ -505,6 +511,63 @@ func (l *ActivateOrderLogic) clearServerCache(ctx context.Context, sub *subscrib } } +// triggerUserGroupRecalculation triggers user group recalculation after subscription changes +// This runs asynchronously in background to avoid blocking the main order processing flow +func (l *ActivateOrderLogic) triggerUserGroupRecalculation(ctx context.Context, userId int64) { + go func() { + // Use a new context with timeout for group recalculation + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + logger.Debugf("[Group Trigger] Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + logger.Errorw("[Group Trigger] Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + logger.Debugf("[Group Trigger] Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svc) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + logger.Errorw("[Group Trigger] Failed to recalculate user group", + logger.Field("user_id", userId), + logger.Field("error", err.Error()), + ) + return + } + + logger.Infow("[Group Trigger] Successfully recalculated user group", + logger.Field("user_id", userId), + logger.Field("mode", groupMode), + ) + }() +} + // Renewal handles subscription renewal including subscription extension, // traffic reset (if configured), commission processing, and notifications func (l *ActivateOrderLogic) Renewal(ctx context.Context, orderInfo *order.Order) error { @@ -898,6 +961,7 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * Traffic: us.Traffic, Download: us.Download, Upload: us.Upload, + NodeGroupId: us.NodeGroupId, } break } @@ -984,6 +1048,7 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * Token: uuidx.SubscribeToken(orderInfo.OrderNo), UUID: uuid.New().String(), Status: 1, + NodeGroupId: sub.NodeGroupId, // Inherit node_group_id from subscription plan } err = l.svc.UserModel.InsertSubscribe(ctx, newSubscribe, tx) @@ -1030,6 +1095,9 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * return err } + // Trigger user group recalculation (runs in background) + l.triggerUserGroupRecalculation(ctx, userInfo.Id) + // 7. 清理缓存(关键步骤:让节点获取最新订阅) l.clearServerCache(ctx, sub) diff --git a/queue/logic/subscription/checkSubscriptionLogic.go b/queue/logic/subscription/checkSubscriptionLogic.go index 81b86e7..02fbf01 100644 --- a/queue/logic/subscription/checkSubscriptionLogic.go +++ b/queue/logic/subscription/checkSubscriptionLogic.go @@ -62,7 +62,6 @@ func (l *CheckSubscriptionLogic) ProcessTask(ctx context.Context, _ *asynq.Task) } l.clearServerCache(ctx, list...) logger.Infow("[Check Subscription Traffic] Update subscribe status", logger.Field("user_ids", ids), logger.Field("count", int64(len(ids)))) - } else { logger.Info("[Check Subscription Traffic] No subscribe need to update") } @@ -108,6 +107,7 @@ func (l *CheckSubscriptionLogic) ProcessTask(ctx context.Context, _ *asynq.Task) } else { logger.Info("[Check Subscription Expire] No subscribe need to update") } + return nil }) if err != nil { diff --git a/queue/types/scheduler.go b/queue/types/scheduler.go index 51ef48c..86875ce 100644 --- a/queue/types/scheduler.go +++ b/queue/types/scheduler.go @@ -5,4 +5,5 @@ const ( SchedulerTotalServerData = "scheduler:total:server" SchedulerResetTraffic = "scheduler:reset:traffic" SchedulerTrafficStat = "scheduler:traffic:stat" + SchedulerRecalculateGroup = "scheduler:recalculate:group" ) diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index bf131e8..02afa60 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -52,6 +52,12 @@ func (m *Service) Start() { logger.Errorf("register update exchange rate task failed: %s", err.Error()) } + // schedule recalculate group task: every hour + recalculateGroupTask := asynq.NewTask(types.SchedulerRecalculateGroup, nil) + if _, err := m.server.Register("@every 6h", recalculateGroupTask, asynq.MaxRetry(2)); err != nil { + logger.Errorf("register recalculate group task failed: %s", err.Error()) + } + if err := m.server.Run(); err != nil { logger.Errorf("run scheduler failed: %s", err.Error()) } From 0dbcff85f1b0fa69a8656d6b8107eac9cc1ac172 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:53:13 +0800 Subject: [PATCH 03/18] feat(captcha): add captcha service interface and implementations - Add captcha service interface with Generate and Verify methods - Implement local image captcha using base64Captcha library - Implement Cloudflare Turnstile verification wrapper - Support Redis-based captcha storage with 5-minute expiration - Add factory method for creating captcha service instances --- pkg/captcha/local.go | 98 ++++++++++++++++++++++++++++++++++++++++ pkg/captcha/service.go | 49 ++++++++++++++++++++ pkg/captcha/turnstile.go | 37 +++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 pkg/captcha/local.go create mode 100644 pkg/captcha/service.go create mode 100644 pkg/captcha/turnstile.go diff --git a/pkg/captcha/local.go b/pkg/captcha/local.go new file mode 100644 index 0000000..ba6b917 --- /dev/null +++ b/pkg/captcha/local.go @@ -0,0 +1,98 @@ +package captcha + +import ( + "context" + "fmt" + "time" + + "github.com/mojocn/base64Captcha" + "github.com/redis/go-redis/v9" +) + +type localService struct { + redis *redis.Client + driver base64Captcha.Driver +} + +func newLocalService(redisClient *redis.Client) Service { + // Configure captcha driver + driver := base64Captcha.NewDriverDigit(80, 240, 5, 0.7, 80) + return &localService{ + redis: redisClient, + driver: driver, + } +} + +func (s *localService) Generate(ctx context.Context) (id string, image string, err error) { + // Generate captcha + captcha := base64Captcha.NewCaptcha(s.driver, &redisStore{ + redis: s.redis, + ctx: ctx, + }) + + id, b64s, answer, err := captcha.Generate() + if err != nil { + return "", "", err + } + + // Store answer in Redis with 5 minute expiration + key := fmt.Sprintf("captcha:%s", id) + err = s.redis.Set(ctx, key, answer, 5*time.Minute).Err() + if err != nil { + return "", "", err + } + + return id, b64s, nil +} + +func (s *localService) Verify(ctx context.Context, id string, code string, ip string) (bool, error) { + if id == "" || code == "" { + return false, nil + } + + key := fmt.Sprintf("captcha:%s", id) + + // Get answer from Redis + answer, err := s.redis.Get(ctx, key).Result() + if err != nil { + return false, err + } + + // Delete captcha after verification (one-time use) + s.redis.Del(ctx, key) + + // Verify code + return answer == code, nil +} + +func (s *localService) GetType() CaptchaType { + return CaptchaTypeLocal +} + +// redisStore implements base64Captcha.Store interface +type redisStore struct { + redis *redis.Client + ctx context.Context +} + +func (r *redisStore) Set(id string, value string) error { + key := fmt.Sprintf("captcha:%s", id) + return r.redis.Set(r.ctx, key, value, 5*time.Minute).Err() +} + +func (r *redisStore) Get(id string, clear bool) string { + key := fmt.Sprintf("captcha:%s", id) + val, err := r.redis.Get(r.ctx, key).Result() + if err != nil { + return "" + } + if clear { + r.redis.Del(r.ctx, key) + } + return val +} + +func (r *redisStore) Verify(id, answer string, clear bool) bool { + v := r.Get(id, clear) + return v == answer +} diff --git a/pkg/captcha/service.go b/pkg/captcha/service.go new file mode 100644 index 0000000..4536227 --- /dev/null +++ b/pkg/captcha/service.go @@ -0,0 +1,49 @@ +package captcha + +import ( + "context" + + "github.com/redis/go-redis/v9" +) + +type CaptchaType string + +const ( + CaptchaTypeLocal CaptchaType = "local" + CaptchaTypeTurnstile CaptchaType = "turnstile" +) + +// Service defines the captcha service interface +type Service interface { + // Generate generates a new captcha + // For local captcha: returns id and base64 image + // For turnstile: returns empty strings + Generate(ctx context.Context) (id string, image string, err error) + + // Verify verifies the captcha + // For local captcha: token is captcha id, code is user input + // For turnstile: token is cf-turnstile-response, code is ignored + Verify(ctx context.Context, token string, code string, ip string) (bool, error) + + // GetType returns the captcha type + GetType() CaptchaType +} + +// Config holds the configuration for captcha service +type Config struct { + Type CaptchaType + RedisClient *redis.Client + TurnstileSecret string +} + +// NewService creates a new captcha service based on the config +func NewService(config Config) Service { + switch config.Type { + case CaptchaTypeTurnstile: + return newTurnstileService(config.TurnstileSecret) + case CaptchaTypeLocal: + fallthrough + default: + return newLocalService(config.RedisClient) + } +} diff --git a/pkg/captcha/turnstile.go b/pkg/captcha/turnstile.go new file mode 100644 index 0000000..52e5bca --- /dev/null +++ b/pkg/captcha/turnstile.go @@ -0,0 +1,37 @@ +package captcha + +import ( + "context" + + "github.com/perfect-panel/server/pkg/turnstile" +) + +type turnstileService struct { + service turnstile.Service +} + +func newTurnstileService(secret string) Service { + return &turnstileService{ + service: turnstile.New(turnstile.Config{ + Secret: secret, + }), + } +} + +func (s *turnstileService) Generate(ctx context.Context) (id string, image string, err error) { + // Turnstile doesn't need server-side generation + return "", "", nil +} + +func (s *turnstileService) Verify(ctx context.Context, token string, code string, ip string) (bool, error) { + if token == "" { + return false, nil + } + + // Verify with Cloudflare Turnstile + return s.service.Verify(ctx, token, ip) +} + +func (s *turnstileService) GetType() CaptchaType { + return CaptchaTypeTurnstile +} From 36119b842c96218e6233de98eaad766dcb816bd0 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:53:34 +0800 Subject: [PATCH 04/18] build(deps): add base64Captcha library for local captcha generation - Add github.com/mojocn/base64Captcha v1.3.6 - Add github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 (indirect) - Add golang.org/x/image v0.23.0 (indirect) --- go.mod | 5 ++++- go.sum | 31 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index aa17478..7e34be3 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/jinzhu/copier v0.4.0 github.com/klauspost/compress v1.17.7 github.com/nyaruka/phonenumbers v1.5.0 - github.com/pkg/errors v0.9.1 + github.com/pkg/errors v0.9.1 github.com/redis/go-redis/v9 v9.7.2 github.com/smartwalle/alipay/v3 v3.2.23 github.com/spf13/cast v1.7.0 // indirect @@ -94,6 +94,7 @@ require ( github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/glog v1.2.0 // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -117,6 +118,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mojocn/base64Captcha v1.3.8 // indirect github.com/openzipkin/zipkin-go v0.4.2 // indirect github.com/oschwald/maxminddb-golang v1.13.0 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect @@ -139,6 +141,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.13.0 // indirect golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d // indirect + golang.org/x/image v0.23.0 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index aa84f0e..2b62dca 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -274,6 +276,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mojocn/base64Captcha v1.3.8 h1:rrN9BhCwXKS8ht1e21kvR3iTaMgf4qPC9sRoV52bqEg= +github.com/mojocn/base64Captcha v1.3.8/go.mod h1:QFZy927L8HVP3+VV5z2b1EAEiv1KxVJKZbAucVgLUy4= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -405,12 +409,17 @@ golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d h1:N0hmiNbwsSNwHBAvR3QB5w25pUwH4tK0Y/RltD1j1h4= golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68= +golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -419,6 +428,9 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -434,7 +446,10 @@ golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -448,6 +463,10 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -466,14 +485,21 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -481,7 +507,10 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -499,6 +528,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From 2fd22c97e0df244f86712f85e7450892a7cab0e2 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:53:45 +0800 Subject: [PATCH 05/18] feat(migration): add captcha configuration migration - Add CaptchaType field for selecting captcha implementation - Add EnableUserLoginCaptcha for user login verification - Add EnableUserRegisterCaptcha for user registration verification - Add EnableAdminLoginCaptcha for admin login verification - Add EnableUserResetPasswordCaptcha for password reset verification - Remove deprecated EnableLoginVerify, EnableRegisterVerify, EnableResetPasswordVerify fields - Support rollback with down migration --- .../02132_update_verify_config.down.sql | 17 +++++++++++++++++ .../database/02132_update_verify_config.up.sql | 17 +++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 initialize/migrate/database/02132_update_verify_config.down.sql create mode 100644 initialize/migrate/database/02132_update_verify_config.up.sql diff --git a/initialize/migrate/database/02132_update_verify_config.down.sql b/initialize/migrate/database/02132_update_verify_config.down.sql new file mode 100644 index 0000000..2c66df3 --- /dev/null +++ b/initialize/migrate/database/02132_update_verify_config.down.sql @@ -0,0 +1,17 @@ +-- Rollback: restore old verify configuration fields +INSERT INTO `system` (`category`, `key`, `value`, `type`, `desc`) VALUES + ('verify', 'EnableLoginVerify', 'false', 'bool', 'is enable login verify'), + ('verify', 'EnableRegisterVerify', 'false', 'bool', 'is enable register verify'), + ('verify', 'EnableResetPasswordVerify', 'false', 'bool', 'is enable reset password verify') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); + +-- Remove new captcha configuration fields +DELETE FROM `system` WHERE `category` = 'verify' AND `key` IN ( + 'CaptchaType', + 'EnableUserLoginCaptcha', + 'EnableUserRegisterCaptcha', + 'EnableAdminLoginCaptcha', + 'EnableUserResetPasswordCaptcha' +); diff --git a/initialize/migrate/database/02132_update_verify_config.up.sql b/initialize/migrate/database/02132_update_verify_config.up.sql new file mode 100644 index 0000000..b6d5137 --- /dev/null +++ b/initialize/migrate/database/02132_update_verify_config.up.sql @@ -0,0 +1,17 @@ +-- Add new captcha configuration fields +INSERT INTO `system` (`category`, `key`, `value`, `type`, `desc`) VALUES + ('verify', 'CaptchaType', 'local', 'string', 'Captcha type: local or turnstile'), + ('verify', 'EnableUserLoginCaptcha', 'false', 'bool', 'Enable captcha for user login'), + ('verify', 'EnableUserRegisterCaptcha', 'false', 'bool', 'Enable captcha for user registration'), + ('verify', 'EnableAdminLoginCaptcha', 'false', 'bool', 'Enable captcha for admin login'), + ('verify', 'EnableUserResetPasswordCaptcha', 'false', 'bool', 'Enable captcha for user reset password') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); + +-- Remove old verify configuration fields +DELETE FROM `system` WHERE `category` = 'verify' AND `key` IN ( + 'EnableLoginVerify', + 'EnableRegisterVerify', + 'EnableResetPasswordVerify' +); From 0f6fddc36d64810653cbebd22d8edcee52f0b682 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:53:59 +0800 Subject: [PATCH 06/18] feat(error): add PermissionDenied error code - Add error code 40008 for permission denied scenarios - Add corresponding error message for admin permission checks --- pkg/xerr/errCode.go | 1 + pkg/xerr/errMsg.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/xerr/errCode.go b/pkg/xerr/errCode.go index c37e9b9..f5fb4b5 100644 --- a/pkg/xerr/errCode.go +++ b/pkg/xerr/errCode.go @@ -50,6 +50,7 @@ const ( InvalidAccess uint32 = 40005 InvalidCiphertext uint32 = 40006 SecretIsEmpty uint32 = 40007 + PermissionDenied uint32 = 40008 ) //coupon error diff --git a/pkg/xerr/errMsg.go b/pkg/xerr/errMsg.go index ed054ae..6987e30 100644 --- a/pkg/xerr/errMsg.go +++ b/pkg/xerr/errMsg.go @@ -17,6 +17,7 @@ func init() { SecretIsEmpty: "Secret is empty", InvalidAccess: "Invalid access", InvalidCiphertext: "Invalid ciphertext", + PermissionDenied: "Permission denied", // Database error DatabaseQueryError: "Database query error", DatabaseUpdateError: "Database update error", From eb327b26b9712e14f27a3e672d7ba9b6c9fac9b6 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:54:08 +0800 Subject: [PATCH 07/18] feat(api): add captcha fields and admin authentication endpoints - Add CaptchaId and CaptchaCode fields to login/register/reset requests - Add /v1/auth/captcha/generate endpoint for user captcha generation - Add /v1/auth/admin/login endpoint for admin authentication - Add /v1/auth/admin/reset-password endpoint for admin password reset - Add /v1/auth/admin/captcha/generate endpoint for admin captcha generation - Update GlobalConfigResponse with new verify configuration fields - Add GenerateCaptchaResponse type for captcha generation --- apis/auth/auth.api | 90 +++++++++++++++++++++++++++++++++------------- apis/common.api | 10 +++--- apis/types.api | 12 ++++--- 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/apis/auth/auth.api b/apis/auth/auth.api index 6e6d03c..c1d3c8b 100644 --- a/apis/auth/auth.api +++ b/apis/auth/auth.api @@ -11,13 +11,15 @@ info ( type ( // User login request UserLoginRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // Check user is exist request CheckUserRequest { @@ -29,26 +31,30 @@ type ( } // User login response UserRegisterRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Invite string `json:"invite,optional"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Invite string `json:"invite,optional"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } - // User login response + // User reset password request ResetPasswordRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } LoginResponse { Token string `json:"token"` @@ -75,6 +81,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // Check user is exist request TelephoneCheckUserRequest { @@ -97,6 +105,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // User login response TelephoneResetPasswordRequest { @@ -109,6 +119,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } AppleLoginCallbackRequest { Code string `form:"code"` @@ -126,6 +138,11 @@ type ( CfToken string `json:"cf_token,optional"` ShortCode string `json:"short_code,optional"` } + GenerateCaptchaResponse { + Id string `json:"id"` + Image string `json:"image"` + Type string `json:"type"` + } ) @server ( @@ -166,11 +183,34 @@ service ppanel { @handler TelephoneResetPassword post /reset/telephone (TelephoneResetPasswordRequest) returns (LoginResponse) + @doc "Generate captcha" + @handler GenerateCaptcha + post /captcha/generate returns (GenerateCaptchaResponse) + @doc "Device Login" @handler DeviceLogin post /login/device (DeviceLoginRequest) returns (LoginResponse) } +@server ( + prefix: v1/auth/admin + group: auth/admin + middleware: DeviceMiddleware +) +service ppanel { + @doc "Admin login" + @handler AdminLogin + post /login (UserLoginRequest) returns (LoginResponse) + + @doc "Admin reset password" + @handler AdminResetPassword + post /reset (ResetPasswordRequest) returns (LoginResponse) + + @doc "Generate captcha" + @handler AdminGenerateCaptcha + post /captcha/generate returns (GenerateCaptchaResponse) +} + @server ( prefix: v1/auth/oauth group: auth/oauth diff --git a/apis/common.api b/apis/common.api index db935f4..d0ecb42 100644 --- a/apis/common.api +++ b/apis/common.api @@ -12,10 +12,12 @@ import "./types.api" type ( VeifyConfig { - TurnstileSiteKey string `json:"turnstile_site_key"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` } GetGlobalConfigResponse { Site SiteConfig `json:"site"` diff --git a/apis/types.api b/apis/types.api index e5ffbd1..4b9c7f6 100644 --- a/apis/types.api +++ b/apis/types.api @@ -154,11 +154,13 @@ type ( DeviceLimit int64 `json:"device_limit"` } VerifyConfig { - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecret string `json:"turnstile_secret"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` // local or turnstile + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` // User login captcha + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` // User register captcha + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` // Admin login captcha + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` // User reset password captcha } NodeConfig { NodeSecret string `json:"node_secret"` From f224d09d09abb76b4f3dbd754cc51ac932c0319b Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:54:21 +0800 Subject: [PATCH 08/18] feat(types): update request types with captcha fields - Add CaptchaId and CaptchaCode to UserLoginRequest - Add CaptchaId and CaptchaCode to UserRegisterRequest - Add CaptchaId and CaptchaCode to ResetPasswordRequest - Add CaptchaId and CaptchaCode to TelephoneLoginRequest - Add CaptchaId and CaptchaCode to TelephoneUserRegisterRequest - Add CaptchaId and CaptchaCode to TelephoneResetPasswordRequest - Add GenerateCaptchaResponse type - Add AdminLoginRequest and AdminResetPasswordRequest types --- internal/types/types.go | 92 ++++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 33 deletions(-) diff --git a/internal/types/types.go b/internal/types/types.go index 219ec50..f1ee1fc 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -774,6 +774,12 @@ type Follow struct { CreatedAt int64 `json:"created_at"` } +type GenerateCaptchaResponse struct { + Id string `json:"id"` + Image string `json:"image"` + Type string `json:"type"` +} + type GetAdsDetailRequest struct { Id int64 `form:"id"` } @@ -2027,14 +2033,16 @@ type ResetGroupsRequest struct { } type ResetPasswordRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type ResetSortRequest struct { @@ -2403,6 +2411,8 @@ type TelephoneLoginRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TelephoneRegisterRequest struct { @@ -2416,6 +2426,8 @@ type TelephoneRegisterRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TelephoneResetPasswordRequest struct { @@ -2428,6 +2440,8 @@ type TelephoneResetPasswordRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TestEmailSendRequest struct { @@ -2852,25 +2866,29 @@ type UserLoginLog struct { } type UserLoginRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type UserRegisterRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Invite string `json:"invite,optional"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Invite string `json:"invite,optional"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type UserStatistics struct { @@ -2893,6 +2911,8 @@ type UserSubscribe struct { OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` FinishedAt int64 `json:"finished_at"` @@ -2914,6 +2934,8 @@ type UserSubscribeDetail struct { OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` ResetTime int64 `json:"reset_time"` @@ -2997,10 +3019,12 @@ type UserTrafficData struct { } type VeifyConfig struct { - TurnstileSiteKey string `json:"turnstile_site_key"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` } type VerifyCodeConfig struct { @@ -3010,11 +3034,13 @@ type VerifyCodeConfig struct { } type VerifyConfig struct { - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecret string `json:"turnstile_secret"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` // local or turnstile + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` // User login captcha + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` // User register captcha + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` // Admin login captcha + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` // User reset password captcha } type VerifyEmailRequest struct { From 5727708bbd3155b38577e90221dd09728b084ed1 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:54:33 +0800 Subject: [PATCH 09/18] feat(config): add captcha configuration to global config response - Add CaptchaType field to verify config - Add EnableUserLoginCaptcha field - Add EnableUserRegisterCaptcha field - Add EnableAdminLoginCaptcha field - Add EnableUserResetPasswordCaptcha field - Expose captcha configuration to frontend --- internal/logic/common/getGlobalConfigLogic.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/internal/logic/common/getGlobalConfigLogic.go b/internal/logic/common/getGlobalConfigLogic.go index 393371b..e3c3b90 100644 --- a/internal/logic/common/getGlobalConfigLogic.go +++ b/internal/logic/common/getGlobalConfigLogic.go @@ -41,27 +41,26 @@ func (l *GetGlobalConfigLogic) GetGlobalConfig() (resp *types.GetGlobalConfigRes l.Logger.Error("[GetGlobalConfigLogic] GetVerifyCodeConfig error: ", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyCodeConfig error: %v", err.Error()) } + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[GetGlobalConfigLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } tool.DeepCopy(&resp.Site, l.svcCtx.Config.Site) tool.DeepCopy(&resp.Subscribe, l.svcCtx.Config.Subscribe) tool.DeepCopy(&resp.Auth.Email, l.svcCtx.Config.Email) tool.DeepCopy(&resp.Auth.Mobile, l.svcCtx.Config.Mobile) tool.DeepCopy(&resp.Auth.Register, l.svcCtx.Config.Register) - tool.DeepCopy(&resp.Verify, l.svcCtx.Config.Verify) tool.DeepCopy(&resp.Invite, l.svcCtx.Config.Invite) tool.SystemConfigSliceReflectToStruct(currencyCfg, &resp.Currency) tool.SystemConfigSliceReflectToStruct(verifyCodeCfg, &resp.VerifyCode) + tool.SystemConfigSliceReflectToStruct(verifyCfg, &resp.Verify) if report.IsGatewayMode() { resp.Subscribe.SubscribePath = "/sub" + l.svcCtx.Config.Subscribe.SubscribePath } - resp.Verify = types.VeifyConfig{ - TurnstileSiteKey: l.svcCtx.Config.Verify.TurnstileSiteKey, - EnableLoginVerify: l.svcCtx.Config.Verify.LoginVerify, - EnableRegisterVerify: l.svcCtx.Config.Verify.RegisterVerify, - EnableResetPasswordVerify: l.svcCtx.Config.Verify.ResetPasswordVerify, - } var methods []string // auth methods From 2afb86f97351826f3d2ce0c843a1a1b6c4cf90bc Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:54:47 +0800 Subject: [PATCH 10/18] feat(auth): add user captcha generation endpoint - Add handler for /v1/auth/captcha/generate endpoint - Implement captcha generation logic based on configuration - Support local image captcha generation with Redis storage - Return Turnstile site key for Turnstile mode - Check EnableUserLoginCaptcha configuration --- .../handler/auth/generateCaptchaHandler.go | 18 +++++ internal/logic/auth/generateCaptchaLogic.go | 70 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 internal/handler/auth/generateCaptchaHandler.go create mode 100644 internal/logic/auth/generateCaptchaLogic.go diff --git a/internal/handler/auth/generateCaptchaHandler.go b/internal/handler/auth/generateCaptchaHandler.go new file mode 100644 index 0000000..d263da7 --- /dev/null +++ b/internal/handler/auth/generateCaptchaHandler.go @@ -0,0 +1,18 @@ +package auth + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Generate captcha +func GenerateCaptchaHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := auth.NewGenerateCaptchaLogic(c.Request.Context(), svcCtx) + resp, err := l.GenerateCaptcha() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/logic/auth/generateCaptchaLogic.go b/internal/logic/auth/generateCaptchaLogic.go new file mode 100644 index 0000000..068eb2b --- /dev/null +++ b/internal/logic/auth/generateCaptchaLogic.go @@ -0,0 +1,70 @@ +package auth + +import ( + "context" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type GenerateCaptchaLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Generate captcha +func NewGenerateCaptchaLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GenerateCaptchaLogic { + return &GenerateCaptchaLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GenerateCaptchaLogic) GenerateCaptcha() (resp *types.GenerateCaptchaResponse, err error) { + resp = &types.GenerateCaptchaResponse{} + + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[GenerateCaptchaLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + resp.Type = config.CaptchaType + + // If captcha type is local, generate captcha image + if config.CaptchaType == "local" { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + id, image, err := captchaService.Generate(l.ctx) + if err != nil { + l.Logger.Error("[GenerateCaptchaLogic] Generate captcha error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Generate captcha error: %v", err.Error()) + } + + resp.Id = id + resp.Image = image + } else if config.CaptchaType == "turnstile" { + // For Turnstile, just return the site key + resp.Id = config.TurnstileSiteKey + } + + return resp, nil +} From 9aaffec61dd064e5f9c35560e8ed494f0b501de7 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:54:59 +0800 Subject: [PATCH 11/18] feat(auth): add admin authentication with permission checks - Add admin login handler and logic with IsAdmin verification - Add admin password reset handler and logic - Add admin captcha generation handler and logic - Implement device binding for admin login - Add login logging for admin authentication - Check EnableAdminLoginCaptcha configuration - Separate admin authentication from user authentication - Verify admin permission before allowing access --- .../auth/admin/adminGenerateCaptchaHandler.go | 18 ++ .../handler/auth/admin/adminLoginHandler.go | 30 +++ .../auth/admin/adminResetPasswordHandler.go | 29 +++ .../auth/admin/adminGenerateCaptchaLogic.go | 70 ++++++ internal/logic/auth/admin/adminLoginLogic.go | 201 +++++++++++++++ .../auth/admin/adminResetPasswordLogic.go | 229 ++++++++++++++++++ 6 files changed, 577 insertions(+) create mode 100644 internal/handler/auth/admin/adminGenerateCaptchaHandler.go create mode 100644 internal/handler/auth/admin/adminLoginHandler.go create mode 100644 internal/handler/auth/admin/adminResetPasswordHandler.go create mode 100644 internal/logic/auth/admin/adminGenerateCaptchaLogic.go create mode 100644 internal/logic/auth/admin/adminLoginLogic.go create mode 100644 internal/logic/auth/admin/adminResetPasswordLogic.go diff --git a/internal/handler/auth/admin/adminGenerateCaptchaHandler.go b/internal/handler/auth/admin/adminGenerateCaptchaHandler.go new file mode 100644 index 0000000..caabd45 --- /dev/null +++ b/internal/handler/auth/admin/adminGenerateCaptchaHandler.go @@ -0,0 +1,18 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Generate captcha +func AdminGenerateCaptchaHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := admin.NewAdminGenerateCaptchaLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminGenerateCaptcha() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/admin/adminLoginHandler.go b/internal/handler/auth/admin/adminLoginHandler.go new file mode 100644 index 0000000..95239bd --- /dev/null +++ b/internal/handler/auth/admin/adminLoginHandler.go @@ -0,0 +1,30 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + adminLogic "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Admin login +func AdminLoginHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UserLoginRequest + _ = c.ShouldBind(&req) + // get client ip + req.IP = c.ClientIP() + req.UserAgent = c.Request.UserAgent() + + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := adminLogic.NewAdminLoginLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminLogin(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/admin/adminResetPasswordHandler.go b/internal/handler/auth/admin/adminResetPasswordHandler.go new file mode 100644 index 0000000..9fb909c --- /dev/null +++ b/internal/handler/auth/admin/adminResetPasswordHandler.go @@ -0,0 +1,29 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + adminLogic "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Admin reset password +func AdminResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.ResetPasswordRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + // get client ip + req.IP = c.ClientIP() + req.UserAgent = c.Request.UserAgent() + + l := adminLogic.NewAdminResetPasswordLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminResetPassword(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/logic/auth/admin/adminGenerateCaptchaLogic.go b/internal/logic/auth/admin/adminGenerateCaptchaLogic.go new file mode 100644 index 0000000..4772855 --- /dev/null +++ b/internal/logic/auth/admin/adminGenerateCaptchaLogic.go @@ -0,0 +1,70 @@ +package admin + +import ( + "context" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type AdminGenerateCaptchaLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Generate captcha +func NewAdminGenerateCaptchaLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminGenerateCaptchaLogic { + return &AdminGenerateCaptchaLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminGenerateCaptchaLogic) AdminGenerateCaptcha() (resp *types.GenerateCaptchaResponse, err error) { + resp = &types.GenerateCaptchaResponse{} + + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminGenerateCaptchaLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + resp.Type = config.CaptchaType + + // If captcha type is local, generate captcha image + if config.CaptchaType == "local" { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + id, image, err := captchaService.Generate(l.ctx) + if err != nil { + l.Logger.Error("[AdminGenerateCaptchaLogic] Generate captcha error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Generate captcha error: %v", err.Error()) + } + + resp.Id = id + resp.Image = image + } else if config.CaptchaType == "turnstile" { + // For Turnstile, just return the site key + resp.Id = config.TurnstileSiteKey + } + + return resp, nil +} diff --git a/internal/logic/auth/admin/adminLoginLogic.go b/internal/logic/auth/admin/adminLoginLogic.go new file mode 100644 index 0000000..2167eaf --- /dev/null +++ b/internal/logic/auth/admin/adminLoginLogic.go @@ -0,0 +1,201 @@ +package admin + +import ( + "context" + "fmt" + "time" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/jwt" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/uuidx" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type AdminLoginLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Admin login +func NewAdminLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminLoginLogic { + return &AdminLoginLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminLoginLogic) AdminLogin(req *types.UserLoginRequest) (resp *types.LoginResponse, err error) { + loginStatus := false + var userInfo *user.User + // Record login status + defer func(svcCtx *svc.ServiceContext) { + if userInfo != nil && userInfo.Id != 0 { + loginLog := log.Login{ + Method: "email", + LoginIP: req.IP, + UserAgent: req.UserAgent, + Success: loginStatus, + Timestamp: time.Now().UnixMilli(), + } + content, _ := loginLog.Marshal() + if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{ + Type: log.TypeLogin.Uint8(), + Date: time.Now().Format("2006-01-02"), + ObjectID: userInfo.Id, + Content: string(content), + }); err != nil { + l.Errorw("failed to insert login log", + logger.Field("user_id", userInfo.Id), + logger.Field("ip", req.IP), + logger.Field("error", err.Error()), + ) + } + } + }(l.svcCtx) + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + + userInfo, err = l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) + + if userInfo.DeletedAt.Valid { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email deleted: %v", req.Email) + } + + if err != nil { + if errors.As(err, &gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + logger.WithContext(l.ctx).Error(err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user info failed: %v", err.Error()) + } + + // Check if user is admin + if userInfo.IsAdmin == nil || !*userInfo.IsAdmin { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.PermissionDenied), "user is not admin") + } + + // Verify password + if !tool.MultiPasswordVerify(userInfo.Algo, userInfo.Salt, req.Password, userInfo.Password) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserPasswordError), "user password") + } + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := auth.NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail login if device binding fails, just log the error + } + } + if l.ctx.Value(constant.CtxLoginType) != nil { + req.LoginType = l.ctx.Value(constant.CtxLoginType).(string) + } + // Generate session id + sessionId := uuidx.NewUUID().String() + // Generate token + token, err := jwt.NewJwtToken( + l.svcCtx.Config.JwtAuth.AccessSecret, + time.Now().Unix(), + l.svcCtx.Config.JwtAuth.AccessExpire, + jwt.WithOption("UserId", userInfo.Id), + jwt.WithOption("SessionId", sessionId), + jwt.WithOption("identifier", req.Identifier), + jwt.WithOption("CtxLoginType", req.LoginType), + ) + if err != nil { + l.Logger.Error("[AdminLogin] token generate error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) + } + sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) + if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) + } + loginStatus = true + return &types.LoginResponse{ + Token: token, + }, nil +} + +func (l *AdminLoginLogic) verifyCaptcha(req *types.UserLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for admin login + if !config.EnableAdminLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[AdminLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[AdminLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/admin/adminResetPasswordLogic.go b/internal/logic/auth/admin/adminResetPasswordLogic.go new file mode 100644 index 0000000..544b56b --- /dev/null +++ b/internal/logic/auth/admin/adminResetPasswordLogic.go @@ -0,0 +1,229 @@ +package admin + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/jwt" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/uuidx" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type AdminResetPasswordLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +type CacheKeyPayload struct { + Code string `json:"code"` +} + +// Admin reset password +func NewAdminResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminResetPasswordLogic { + return &AdminResetPasswordLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminResetPasswordLogic) AdminResetPassword(req *types.ResetPasswordRequest) (resp *types.LoginResponse, err error) { + var userInfo *user.User + loginStatus := false + + defer func() { + if userInfo != nil && userInfo.Id != 0 && loginStatus { + loginLog := log.Login{ + Method: "email", + LoginIP: req.IP, + UserAgent: req.UserAgent, + Success: loginStatus, + Timestamp: time.Now().UnixMilli(), + } + content, _ := loginLog.Marshal() + if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{ + Id: 0, + Type: log.TypeLogin.Uint8(), + Date: time.Now().Format("2006-01-02"), + ObjectID: userInfo.Id, + Content: string(content), + }); err != nil { + l.Errorw("failed to insert login log", + logger.Field("user_id", userInfo.Id), + logger.Field("ip", req.IP), + logger.Field("error", err.Error()), + ) + } + } + }() + + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Security, req.Email) + // Check the verification code + if value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result(); err != nil { + l.Errorw("Verification code error", logger.Field("cacheKey", cacheKey), logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } else { + var payload CacheKeyPayload + if err := json.Unmarshal([]byte(value), &payload); err != nil { + l.Errorw("Unmarshal errors", logger.Field("cacheKey", cacheKey), logger.Field("error", err.Error()), logger.Field("value", value)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } + if payload.Code != req.Code { + l.Errorw("Verification code error", logger.Field("cacheKey", cacheKey), logger.Field("error", "Verification code error"), logger.Field("reqCode", req.Code), logger.Field("payloadCode", payload.Code)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } + } + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + + // Check user + authMethod, err := l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "email", req.Email) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find user by email error: %v", err.Error()) + } + + userInfo, err = l.svcCtx.UserModel.FindOne(l.ctx, authMethod.UserId) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user info failed: %v", err.Error()) + } + + // Check if user is admin + if userInfo.IsAdmin == nil || !*userInfo.IsAdmin { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.PermissionDenied), "user is not admin") + } + + // Update password + userInfo.Password = tool.EncodePassWord(req.Password) + userInfo.Algo = "default" + if err = l.svcCtx.UserModel.Update(l.ctx, userInfo); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update user info failed: %v", err.Error()) + } + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := auth.NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail register if device binding fails, just log the error + } + } + if l.ctx.Value(constant.CtxLoginType) != nil { + req.LoginType = l.ctx.Value(constant.CtxLoginType).(string) + } + // Generate session id + sessionId := uuidx.NewUUID().String() + // Generate token + token, err := jwt.NewJwtToken( + l.svcCtx.Config.JwtAuth.AccessSecret, + time.Now().Unix(), + l.svcCtx.Config.JwtAuth.AccessExpire, + jwt.WithOption("UserId", userInfo.Id), + jwt.WithOption("SessionId", sessionId), + jwt.WithOption("identifier", req.Identifier), + jwt.WithOption("CtxLoginType", req.LoginType), + ) + if err != nil { + l.Logger.Error("[AdminResetPassword] token generate error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) + } + sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) + if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) + } + loginStatus = true + return &types.LoginResponse{ + Token: token, + }, nil +} + +func (l *AdminResetPasswordLogic) verifyCaptcha(req *types.ResetPasswordRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if admin login captcha is enabled (use admin login captcha for reset password) + if !config.EnableAdminLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} From cea3e31f3af12c611297b6ab7775ffff4f1935c0 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:55:08 +0800 Subject: [PATCH 12/18] feat(auth): add captcha verification to user email authentication - Add verifyCaptcha method to user login logic - Add verifyCaptcha method to user registration logic - Add verifyCaptcha method to password reset logic - Support both local and Turnstile captcha verification - Check respective configuration flags before verification - Validate captcha code and ID for local captcha - Validate Turnstile token for Turnstile mode --- internal/logic/auth/resetPasswordLogic.go | 73 ++++++++++++++++++++++- internal/logic/auth/userLoginLogic.go | 72 +++++++++++++++++++++- internal/logic/auth/userRegisterLogic.go | 73 ++++++++++++++++++++++- 3 files changed, 215 insertions(+), 3 deletions(-) diff --git a/internal/logic/auth/resetPasswordLogic.go b/internal/logic/auth/resetPasswordLogic.go index 22db2c9..7ccc6d0 100644 --- a/internal/logic/auth/resetPasswordLogic.go +++ b/internal/logic/auth/resetPasswordLogic.go @@ -8,6 +8,7 @@ import ( "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/uuidx" @@ -43,7 +44,7 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res loginStatus := false defer func() { - if userInfo.Id != 0 && loginStatus { + if userInfo != nil && userInfo.Id != 0 && loginStatus { loginLog := log.Login{ Method: "email", LoginIP: req.IP, @@ -85,6 +86,11 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res } } + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check user authMethod, err := l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "email", req.Email) if err != nil { @@ -149,3 +155,68 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res Token: token, }, nil } + +func (l *ResetPasswordLogic) verifyCaptcha(req *types.ResetPasswordRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if user reset password captcha is enabled + if !config.EnableUserResetPasswordCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} + diff --git a/internal/logic/auth/userLoginLogic.go b/internal/logic/auth/userLoginLogic.go index 4e6fac2..9bd5d59 100644 --- a/internal/logic/auth/userLoginLogic.go +++ b/internal/logic/auth/userLoginLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" @@ -42,7 +43,7 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log var userInfo *user.User // Record login status defer func(svcCtx *svc.ServiceContext) { - if userInfo.Id != 0 { + if userInfo != nil && userInfo.Id != 0 { loginLog := log.Login{ Method: "email", LoginIP: req.IP, @@ -66,6 +67,11 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log } }(l.svcCtx) + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + userInfo, err = l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) if userInfo.DeletedAt.Valid { @@ -125,3 +131,67 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log Token: token, }, nil } + +func (l *UserLoginLogic) verifyCaptcha(req *types.UserLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[UserLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user login + if !config.EnableUserLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[UserLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[UserLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index 287a39e..423a19c 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -13,6 +13,7 @@ import ( "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" @@ -80,6 +81,12 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error") } } + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check if the user exists u, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -250,7 +257,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } loginStatus := true defer func() { - if token != "" && userInfo.Id != 0 { + if token != "" && userInfo != nil && userInfo.Id != 0 { loginLog := log.Login{ Method: "email", LoginIP: req.IP, @@ -323,3 +330,67 @@ func (l *UserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) { } return userSub, nil } + +func (l *UserRegisterLogic) verifyCaptcha(req *types.UserRegisterRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[UserRegisterLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if user register captcha is enabled + if !config.EnableUserRegisterCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[UserRegisterLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[UserRegisterLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} From fae77a8954b062012d17a380f3824c256c1c2a95 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:55:23 +0800 Subject: [PATCH 13/18] feat(auth): add captcha verification to phone authentication - Add verifyCaptcha method to phone login logic - Add verifyCaptcha method to phone registration logic - Support both local and Turnstile captcha verification - Check EnableUserLoginCaptcha for phone login - Check EnableUserRegisterCaptcha for phone registration - Validate captcha before processing phone authentication --- internal/logic/auth/telephoneLoginLogic.go | 70 ++++++++++++++++++ .../logic/auth/telephoneUserRegisterLogic.go | 71 +++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/internal/logic/auth/telephoneLoginLogic.go b/internal/logic/auth/telephoneLoginLogic.go index 3a8655c..772e006 100644 --- a/internal/logic/auth/telephoneLoginLogic.go +++ b/internal/logic/auth/telephoneLoginLogic.go @@ -12,6 +12,7 @@ import ( "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" @@ -94,6 +95,11 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r return nil, xerr.NewErrCodeMsg(xerr.InvalidParams, "password and telephone code is empty") } + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + if req.TelephoneCode == "" { // Verify password if !tool.MultiPasswordVerify(userInfo.Algo, userInfo.Salt, req.Password, userInfo.Password) { @@ -164,3 +170,67 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r Token: token, }, nil } + +func (l *TelephoneLoginLogic) verifyCaptcha(req *types.TelephoneLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user login + if !config.EnableUserLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/telephoneUserRegisterLogic.go b/internal/logic/auth/telephoneUserRegisterLogic.go index 006e956..1a5df07 100644 --- a/internal/logic/auth/telephoneUserRegisterLogic.go +++ b/internal/logic/auth/telephoneUserRegisterLogic.go @@ -13,6 +13,7 @@ import ( "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/phone" @@ -81,6 +82,12 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error") } l.svcCtx.Redis.Del(l.ctx, cacheKey) + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check if the user exists _, err = l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "mobile", phoneNumber) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -280,3 +287,67 @@ func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, er return userSub, nil } +func (l *TelephoneUserRegisterLogic) verifyCaptcha(req *types.TelephoneRegisterRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user register + if !config.EnableUserRegisterCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} + From 3ca471f58c53ef76c291304d15cd5a5aed12aad9 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:56:07 +0800 Subject: [PATCH 14/18] refactor(auth): move captcha verification from handler to logic layer - Remove duplicate captcha verification from user login handler - Remove duplicate captcha verification from user register handler - Remove duplicate captcha verification from password reset handler - Remove duplicate captcha verification from phone login handler - Remove duplicate captcha verification from phone register handler - Update phone reset password handler structure - Improve separation of concerns between handler and logic layers - Handlers now only handle HTTP request/response, logic handles business rules --- internal/handler/auth/resetPasswordHandler.go | 18 +------- .../handler/auth/telephoneLoginHandler.go | 18 +------- .../auth/telephoneResetPasswordHandler.go | 46 +++++++++++++++---- .../auth/telephoneUserRegisterHandler.go | 17 +------ internal/handler/auth/userLoginHandler.go | 17 +------ internal/handler/auth/userRegisterHandler.go | 16 +------ 6 files changed, 43 insertions(+), 89 deletions(-) diff --git a/internal/handler/auth/resetPasswordHandler.go b/internal/handler/auth/resetPasswordHandler.go index d4edc9b..8de4ca6 100644 --- a/internal/handler/auth/resetPasswordHandler.go +++ b/internal/handler/auth/resetPasswordHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // Reset password @@ -25,17 +20,8 @@ func ResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { } // get client ip req.IP = c.ClientIP() - if svcCtx.Config.Verify.ResetPasswordVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + req.UserAgent = c.Request.UserAgent() + l := auth.NewResetPasswordLogic(c.Request.Context(), svcCtx) resp, err := l.ResetPassword(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/telephoneLoginHandler.go b/internal/handler/auth/telephoneLoginHandler.go index 44c1b53..c7e3bab 100644 --- a/internal/handler/auth/telephoneLoginHandler.go +++ b/internal/handler/auth/telephoneLoginHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User Telephone login @@ -25,17 +20,8 @@ func TelephoneLoginHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { } // get client ip req.IP = c.ClientIP() - if svcCtx.Config.Verify.LoginVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + req.UserAgent = c.Request.UserAgent() + l := auth.NewTelephoneLoginLogic(c, svcCtx) resp, err := l.TelephoneLogin(&req, c.Request, c.ClientIP()) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/telephoneResetPasswordHandler.go b/internal/handler/auth/telephoneResetPasswordHandler.go index 16a5105..bc0f8a6 100644 --- a/internal/handler/auth/telephoneResetPasswordHandler.go +++ b/internal/handler/auth/telephoneResetPasswordHandler.go @@ -1,14 +1,13 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" + "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" ) @@ -25,17 +24,44 @@ func TelephoneResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Conte } // get client ip req.IP = c.ClientIP() - if svcCtx.Config.Verify.ResetPasswordVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, + + // Get verify config from database + verifyCfg, err := svcCtx.SystemModel.GetVerifyConfig(c.Request.Context()) + if err != nil { + result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get verify config failed: %v", err)) + return + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Verify captcha if enabled + if config.EnableUserResetPasswordCaptcha { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaType(config.CaptchaType), + TurnstileSecret: config.TurnstileSecret, + RedisClient: svcCtx.Redis, }) - if verify, err := verifyTurns.Verify(c.Request.Context(), req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) + + var token, code string + if config.CaptchaType == "turnstile" { + token = req.CfToken + } else { + token = req.CaptchaId + code = req.CaptchaCode + } + + verified, err := captchaService.Verify(c.Request.Context(), token, code, req.IP) + if err != nil || !verified { + result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "captcha verification failed: %v", err)) return } } + l := auth.NewTelephoneResetPasswordLogic(c, svcCtx) resp, err := l.TelephoneResetPassword(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/telephoneUserRegisterHandler.go b/internal/handler/auth/telephoneUserRegisterHandler.go index 45a7ba8..306777d 100644 --- a/internal/handler/auth/telephoneUserRegisterHandler.go +++ b/internal/handler/auth/telephoneUserRegisterHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User Telephone register @@ -26,17 +21,7 @@ func TelephoneUserRegisterHandler(svcCtx *svc.ServiceContext) func(c *gin.Contex // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.RegisterVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + l := auth.NewTelephoneUserRegisterLogic(c.Request.Context(), svcCtx) resp, err := l.TelephoneUserRegister(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/userLoginHandler.go b/internal/handler/auth/userLoginHandler.go index 20eff59..a188876 100644 --- a/internal/handler/auth/userLoginHandler.go +++ b/internal/handler/auth/userLoginHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User login @@ -21,17 +16,7 @@ func UserLoginHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.LoginVerify && !svcCtx.Config.Debug { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + validateErr := svcCtx.Validate(&req) if validateErr != nil { result.ParamErrorResult(c, validateErr) diff --git a/internal/handler/auth/userRegisterHandler.go b/internal/handler/auth/userRegisterHandler.go index ea40223..a11c743 100644 --- a/internal/handler/auth/userRegisterHandler.go +++ b/internal/handler/auth/userRegisterHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User register @@ -21,16 +16,7 @@ func UserRegisterHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.RegisterVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "verify error: %v", err.Error())) - return - } - } + validateErr := svcCtx.Validate(&req) if validateErr != nil { result.ParamErrorResult(c, validateErr) From 884310d9510b80ef7e5a933c077dda55cf79bfe3 Mon Sep 17 00:00:00 2001 From: EUForest Date: Mon, 9 Mar 2026 22:56:20 +0800 Subject: [PATCH 15/18] feat(routes): register admin authentication and captcha endpoints - Register /v1/auth/captcha/generate route for user captcha - Register /v1/auth/admin/login route for admin login - Register /v1/auth/admin/reset-password route for admin password reset - Register /v1/auth/admin/captcha/generate route for admin captcha - Add admin authentication route group --- internal/handler/routes.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/handler/routes.go b/internal/handler/routes.go index d811058..dbb6c06 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -25,6 +25,7 @@ import ( adminTool "github.com/perfect-panel/server/internal/handler/admin/tool" adminUser "github.com/perfect-panel/server/internal/handler/admin/user" auth "github.com/perfect-panel/server/internal/handler/auth" + authAdmin "github.com/perfect-panel/server/internal/handler/auth/admin" authOauth "github.com/perfect-panel/server/internal/handler/auth/oauth" common "github.com/perfect-panel/server/internal/handler/common" publicAnnouncement "github.com/perfect-panel/server/internal/handler/public/announcement" @@ -670,6 +671,9 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { authGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) { + // Generate captcha + authGroupRouter.POST("/captcha/generate", auth.GenerateCaptchaHandler(serverCtx)) + // Check user is exist authGroupRouter.GET("/check", auth.CheckUserHandler(serverCtx)) @@ -698,6 +702,20 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { authGroupRouter.POST("/reset/telephone", auth.TelephoneResetPasswordHandler(serverCtx)) } + authAdminGroupRouter := router.Group("/v1/auth/admin") + authAdminGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) + + { + // Generate captcha + authAdminGroupRouter.POST("/captcha/generate", authAdmin.AdminGenerateCaptchaHandler(serverCtx)) + + // Admin login + authAdminGroupRouter.POST("/login", authAdmin.AdminLoginHandler(serverCtx)) + + // Admin reset password + authAdminGroupRouter.POST("/reset", authAdmin.AdminResetPasswordHandler(serverCtx)) + } + authOauthGroupRouter := router.Group("/v1/auth/oauth") { From 17163486f6294563a959234460ec58e841a4dd0e Mon Sep 17 00:00:00 2001 From: EUForest Date: Tue, 10 Mar 2026 18:29:19 +0800 Subject: [PATCH 16/18] fix(subscribe): fix user subscription node retrieval logic to support directly assigned nodes --- .../admin/group/previewUserNodesLogic.go | 568 +++++++++++------- .../queryUserSubscribeNodeListLogic.go | 89 ++- 2 files changed, 405 insertions(+), 252 deletions(-) diff --git a/internal/logic/admin/group/previewUserNodesLogic.go b/internal/logic/admin/group/previewUserNodesLogic.go index ba91f4e..3b889df 100644 --- a/internal/logic/admin/group/previewUserNodesLogic.go +++ b/internal/logic/admin/group/previewUserNodesLogic.go @@ -70,10 +70,12 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ Id int64 NodeGroupId int64 NodeGroupIds string // JSON string + Nodes string // JSON string - 直接分配的节点ID + NodeTags string // 节点标签 } var subscribeInfos []SubscribeInfo err = l.svcCtx.DB.Table("subscribe"). - Select("id, node_group_id, node_group_ids"). + Select("id, node_group_id, node_group_ids, nodes, node_tags"). Where("id IN ?", subscribeIds). Find(&subscribeInfos).Error if err != nil { @@ -124,6 +126,28 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ logger.Infof("[PreviewUserNodes] collected node_group_ids with priority: %v", allNodeGroupIds) + // 3. 收集所有订阅中直接分配的节点ID + var allDirectNodeIds []int64 + for _, subInfo := range subscribeInfos { + if subInfo.Nodes != "" && subInfo.Nodes != "null" { + // nodes 是逗号分隔的字符串,如 "1,2,3" + nodeIdStrs := strings.Split(subInfo.Nodes, ",") + for _, idStr := range nodeIdStrs { + idStr = strings.TrimSpace(idStr) + if idStr != "" { + var nodeId int64 + if _, err := fmt.Sscanf(idStr, "%d", &nodeId); err == nil { + allDirectNodeIds = append(allDirectNodeIds, nodeId) + } + } + } + logger.Debugf("[PreviewUserNodes] subscribe_id=%d has direct nodes: %s", subInfo.Id, subInfo.Nodes) + } + } + // 去重 + allDirectNodeIds = removeDuplicateInt64(allDirectNodeIds) + logger.Infof("[PreviewUserNodes] collected direct node_ids: %v", allDirectNodeIds) + // 4. 判断分组功能是否启用 var groupEnabled string l.svcCtx.DB.Table("system"). @@ -141,8 +165,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ // === 启用分组功能:通过用户订阅的 node_group_id 查询节点 === logger.Infof("[PreviewUserNodes] using group-based node filtering") - if len(allNodeGroupIds) == 0 { - logger.Infof("[PreviewUserNodes] no node groups found in user subscribes") + if len(allNodeGroupIds) == 0 && len(allDirectNodeIds) == 0 { + logger.Infof("[PreviewUserNodes] no node groups and no direct nodes found in user subscribes") resp = &types.PreviewUserNodesResponse{ UserId: req.UserId, NodeGroups: []types.NodeGroupItem{}, @@ -150,67 +174,48 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ return resp, nil } - // 5. 查询所有启用的节点 - var dbNodes []node.Node - err = l.svcCtx.DB.Table("nodes"). - Where("enabled = ?", true). - Find(&dbNodes).Error - if err != nil { - logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) - return nil, err - } - - // 6. 过滤出包含至少一个匹配节点组的节点 - // node_group_ids 为空 = 公共节点,所有人可见 - // node_group_ids 与订阅的 node_group_id 匹配 = 该节点可见 - for _, n := range dbNodes { - // 公共节点(node_group_ids 为空),所有人可见 - if len(n.NodeGroupIds) == 0 { - filteredNodes = append(filteredNodes, n) - continue + // 5. 查询所有启用的节点(只有当有节点组时才查询) + if len(allNodeGroupIds) > 0 { + var dbNodes []node.Node + err = l.svcCtx.DB.Table("nodes"). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err } - // 检查节点的 node_group_ids 是否与订阅的 node_group_id 有交集 - for _, nodeGroupId := range n.NodeGroupIds { - if tool.Contains(allNodeGroupIds, nodeGroupId) { + // 6. 过滤出包含至少一个匹配节点组的节点 + // node_group_ids 为空 = 公共节点,所有人可见 + // node_group_ids 与订阅的 node_group_id 匹配 = 该节点可见 + for _, n := range dbNodes { + // 公共节点(node_group_ids 为空),所有人可见 + if len(n.NodeGroupIds) == 0 { filteredNodes = append(filteredNodes, n) - break + continue + } + + // 检查节点的 node_group_ids 是否与订阅的 node_group_id 有交集 + for _, nodeGroupId := range n.NodeGroupIds { + if tool.Contains(allNodeGroupIds, nodeGroupId) { + filteredNodes = append(filteredNodes, n) + break + } } } - } - logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes)) + logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes)) + } } else { // === 未启用分组功能:通过订阅的 node_tags 查询节点 === logger.Infof("[PreviewUserNodes] using tag-based node filtering") - // 5. 获取所有订阅的 subscribeId 列表 - subscribeIds := make([]int64, len(userSubscribes)) - for i, us := range userSubscribes { - subscribeIds[i] = us.SubscribeId - } - - // 6. 查询这些订阅的 node_tags - type SubscribeNodeTags struct { - Id int64 - NodeTags string - } - var subscribeNodeTagsList []SubscribeNodeTags - err = l.svcCtx.DB.Table("subscribe"). - Where("id IN ?", subscribeIds). - Select("id, node_tags"). - Find(&subscribeNodeTagsList).Error - if err != nil { - logger.Errorf("[PreviewUserNodes] failed to get subscribe node tags: %v", err) - return nil, err - } - - // 7. 合并所有标签 + // 从已查询的 subscribeInfos 中获取 node_tags var allTags []string - for _, snt := range subscribeNodeTagsList { - if snt.NodeTags != "" { - tags := strings.Split(snt.NodeTags, ",") + for _, subInfo := range subscribeInfos { + if subInfo.NodeTags != "" { + tags := strings.Split(subInfo.NodeTags, ",") allTags = append(allTags, tags...) } } @@ -221,8 +226,8 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ logger.Infof("[PreviewUserNodes] merged tags from subscribes: %v", allTags) - if len(allTags) == 0 { - logger.Infof("[PreviewUserNodes] no tags found in subscribes") + if len(allTags) == 0 && len(allDirectNodeIds) == 0 { + logger.Infof("[PreviewUserNodes] no tags and no direct nodes found in subscribes") resp = &types.PreviewUserNodesResponse{ UserId: req.UserId, NodeGroups: []types.NodeGroupItem{}, @@ -230,218 +235,321 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ return resp, nil } - // 8. 查询所有启用的节点 - var dbNodes []node.Node - err = l.svcCtx.DB.Table("nodes"). - Where("enabled = ?", true). - Find(&dbNodes).Error - if err != nil { - logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) - return nil, err + // 8. 查询所有启用的节点(只有当有 tags 时才查询) + if len(allTags) > 0 { + var dbNodes []node.Node + err = l.svcCtx.DB.Table("nodes"). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err + } + + // 9. 过滤出包含至少一个匹配标签的节点 + for _, n := range dbNodes { + if n.Tags == "" { + continue + } + nodeTags := strings.Split(n.Tags, ",") + // 检查是否有交集 + for _, tag := range nodeTags { + if tag != "" && tool.Contains(allTags, tag) { + filteredNodes = append(filteredNodes, n) + break + } + } + } + + logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes)) + } + } + + // 10. 根据是否启用分组功能,选择不同的分组方式 + nodeGroupItems := make([]types.NodeGroupItem, 0) + + if isGroupEnabled { + // === 启用分组:按节点组分组 === + // 转换为 types.Node 并按节点组分组 + type NodeWithGroup struct { + Node node.Node + NodeGroupIds []int64 } - // 9. 过滤出包含至少一个匹配标签的节点 - for _, n := range dbNodes { - if n.Tags == "" { + nodesWithGroup := make([]NodeWithGroup, 0, len(filteredNodes)) + for _, n := range filteredNodes { + nodesWithGroup = append(nodesWithGroup, NodeWithGroup{ + Node: n, + NodeGroupIds: n.NodeGroupIds, + }) + } + + // 按节点组分组节点 + type NodeGroupMap struct { + Id int64 + Nodes []types.Node + } + + // 创建节点组映射:group_id -> nodes + groupMap := make(map[int64]*NodeGroupMap) + + // 获取所有涉及的节点组ID + allGroupIds := make([]int64, 0) + for _, ng := range nodesWithGroup { + if len(ng.NodeGroupIds) > 0 { + // 如果节点属于节点组,按第一个节点组分组 + firstGroupId := ng.NodeGroupIds[0] + if _, exists := groupMap[firstGroupId]; !exists { + groupMap[firstGroupId] = &NodeGroupMap{ + Id: firstGroupId, + Nodes: []types.Node{}, + } + allGroupIds = append(allGroupIds, firstGroupId) + } + + // 转换节点 + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[firstGroupId].Nodes = append(groupMap[firstGroupId].Nodes, node) + } else { + // 没有节点组的节点,使用 group_id = 0 作为"无节点组"分组 + if _, exists := groupMap[0]; !exists { + groupMap[0] = &NodeGroupMap{ + Id: 0, + Nodes: []types.Node{}, + } + } + + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[0].Nodes = append(groupMap[0].Nodes, node) + } + } + + // 查询节点组信息并构建响应 + nodeGroupInfoMap := make(map[int64]string) + validGroupIds := make([]int64, 0) + + if len(allGroupIds) > 0 { + type NodeGroupInfo struct { + Id int64 + Name string + } + var nodeGroupInfos []NodeGroupInfo + err = l.svcCtx.DB.Table("node_group"). + Select("id, name"). + Where("id IN ?", allGroupIds). + Find(&nodeGroupInfos).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get node group infos: %v", err) + return nil, err + } + + logger.Infof("[PreviewUserNodes] found %v node group infos from %v requested", len(nodeGroupInfos), len(allGroupIds)) + + // 创建节点组信息映射和有效节点组ID列表 + for _, ngInfo := range nodeGroupInfos { + nodeGroupInfoMap[ngInfo.Id] = ngInfo.Name + validGroupIds = append(validGroupIds, ngInfo.Id) + logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name) + } + + // 记录无效的节点组ID + for _, requestedId := range allGroupIds { + found := false + for _, validId := range validGroupIds { + if requestedId == validId { + found = true + break + } + } + if !found { + logger.Infof("[PreviewUserNodes] node_group_id %d not found in database, treating as public nodes", requestedId) + } + } + } + + // 构建响应:根据有效节点组ID重新分组节点 + publicNodes := make([]types.Node, 0) + + // 遍历所有分组,重新分类节点 + for groupId, gm := range groupMap { + if groupId == 0 { + // 本来就是无节点组的节点 + publicNodes = append(publicNodes, gm.Nodes...) continue } - nodeTags := strings.Split(n.Tags, ",") - // 检查是否有交集 - for _, tag := range nodeTags { - if tag != "" && tool.Contains(allTags, tag) { - filteredNodes = append(filteredNodes, n) + + // 检查这个节点组ID是否有效 + isValid := false + for _, validId := range validGroupIds { + if groupId == validId { + isValid = true break } } + + if isValid { + // 节点组有效,添加到对应的分组 + groupName := nodeGroupInfoMap[groupId] + if groupName == "" { + groupName = fmt.Sprintf("Group %d", groupId) + } + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: groupId, + Name: groupName, + Nodes: gm.Nodes, + }) + logger.Infof("[PreviewUserNodes] adding node group: id=%d, name=%s, nodes=%d", groupId, groupName, len(gm.Nodes)) + } else { + // 节点组无效,节点归入公共节点组 + logger.Infof("[PreviewUserNodes] node_group_id %d invalid, moving %d nodes to public group", groupId, len(gm.Nodes)) + publicNodes = append(publicNodes, gm.Nodes...) + } } - logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes)) - } + // 添加公共节点组(如果有) + if len(publicNodes) > 0 { + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: 0, + Name: "", + Nodes: publicNodes, + }) + logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes)) + } - // 10. 转换为 types.Node 并按节点组分组 - type NodeWithGroup struct { - Node node.Node - NodeGroupIds []int64 - } + } else { + // === 未启用分组:按 tag 分组 === + // 按 tag 分组节点 + tagGroupMap := make(map[string][]types.Node) - nodesWithGroup := make([]NodeWithGroup, 0, len(filteredNodes)) - for _, n := range filteredNodes { - nodesWithGroup = append(nodesWithGroup, NodeWithGroup{ - Node: n, - NodeGroupIds: []int64(n.NodeGroupIds), - }) - } - - // 11. 按节点组分组节点 - type NodeGroupMap struct { - Id int64 - Nodes []types.Node - } - - // 创建节点组映射:group_id -> nodes - groupMap := make(map[int64]*NodeGroupMap) - - // 获取所有涉及的节点组ID - allGroupIds := make([]int64, 0) - for _, ng := range nodesWithGroup { - if len(ng.NodeGroupIds) > 0 { - // 如果节点属于节点组,按第一个节点组分组(或者可以按所有节点组) - // 这里使用节点的第一个节点组 - firstGroupId := ng.NodeGroupIds[0] - if _, exists := groupMap[firstGroupId]; !exists { - groupMap[firstGroupId] = &NodeGroupMap{ - Id: firstGroupId, - Nodes: []types.Node{}, - } - allGroupIds = append(allGroupIds, firstGroupId) + for _, n := range filteredNodes { + tags := []string{} + if n.Tags != "" { + tags = strings.Split(n.Tags, ",") } // 转换节点 - tags := []string{} - if ng.Node.Tags != "" { - tags = strings.Split(ng.Node.Tags, ",") - } node := types.Node{ - Id: ng.Node.Id, - Name: ng.Node.Name, + Id: n.Id, + Name: n.Name, Tags: tags, - Port: ng.Node.Port, - Address: ng.Node.Address, - ServerId: ng.Node.ServerId, - Protocol: ng.Node.Protocol, - Enabled: ng.Node.Enabled, - Sort: ng.Node.Sort, - NodeGroupIds: []int64(ng.Node.NodeGroupIds), - CreatedAt: ng.Node.CreatedAt.Unix(), - UpdatedAt: ng.Node.UpdatedAt.Unix(), + Port: n.Port, + Address: n.Address, + ServerId: n.ServerId, + Protocol: n.Protocol, + Enabled: n.Enabled, + Sort: n.Sort, + NodeGroupIds: []int64(n.NodeGroupIds), + CreatedAt: n.CreatedAt.Unix(), + UpdatedAt: n.UpdatedAt.Unix(), } - groupMap[firstGroupId].Nodes = append(groupMap[firstGroupId].Nodes, node) - } else { - // 没有节点组的节点,使用 group_id = 0 作为"无节点组"分组 - if _, exists := groupMap[0]; !exists { - groupMap[0] = &NodeGroupMap{ - Id: 0, - Nodes: []types.Node{}, + // 将节点添加到每个匹配的 tag 分组中 + if len(tags) > 0 { + for _, tag := range tags { + tag = strings.TrimSpace(tag) + if tag != "" { + tagGroupMap[tag] = append(tagGroupMap[tag], node) + } } + } else { + // 没有 tag 的节点放入特殊分组 + tagGroupMap[""] = append(tagGroupMap[""], node) } + } - tags := []string{} - if ng.Node.Tags != "" { - tags = strings.Split(ng.Node.Tags, ",") - } - node := types.Node{ - Id: ng.Node.Id, - Name: ng.Node.Name, - Tags: tags, - Port: ng.Node.Port, - Address: ng.Node.Address, - ServerId: ng.Node.ServerId, - Protocol: ng.Node.Protocol, - Enabled: ng.Node.Enabled, - Sort: ng.Node.Sort, - NodeGroupIds: []int64(ng.Node.NodeGroupIds), - CreatedAt: ng.Node.CreatedAt.Unix(), - UpdatedAt: ng.Node.UpdatedAt.Unix(), - } - - groupMap[0].Nodes = append(groupMap[0].Nodes, node) + // 构建响应:按 tag 分组 + for tag, nodes := range tagGroupMap { + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: 0, // tag 分组使用 ID 0 + Name: tag, + Nodes: nodes, + }) + logger.Infof("[PreviewUserNodes] adding tag group: tag=%s, nodes=%d", tag, len(nodes)) } } - // 12. 查询节点组信息并构建响应 - nodeGroupInfoMap := make(map[int64]string) - validGroupIds := make([]int64, 0) // 存储在数据库中实际存在的节点组ID - - if len(allGroupIds) > 0 { - type NodeGroupInfo struct { - Id int64 - Name string - } - var nodeGroupInfos []NodeGroupInfo - err = l.svcCtx.DB.Table("node_group"). - Select("id, name"). - Where("id IN ?", allGroupIds). - Find(&nodeGroupInfos).Error + // 添加套餐节点组(直接分配的节点) + if len(allDirectNodeIds) > 0 { + // 查询直接分配的节点详情 + var directNodes []node.Node + err = l.svcCtx.DB.Table("nodes"). + Where("id IN ? AND enabled = ?", allDirectNodeIds, true). + Find(&directNodes).Error if err != nil { - logger.Errorf("[PreviewUserNodes] failed to get node group infos: %v", err) + logger.Errorf("[PreviewUserNodes] failed to get direct nodes: %v", err) return nil, err } - logger.Infof("[PreviewUserNodes] found %v node group infos from %v requested", len(nodeGroupInfos), len(allGroupIds)) - - // 创建节点组信息映射和有效节点组ID列表 - for _, ngInfo := range nodeGroupInfos { - nodeGroupInfoMap[ngInfo.Id] = ngInfo.Name - validGroupIds = append(validGroupIds, ngInfo.Id) - logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name) - } - - // 记录无效的节点组ID(节点有这个ID但数据库中不存在) - for _, requestedId := range allGroupIds { - found := false - for _, validId := range validGroupIds { - if requestedId == validId { - found = true - break + if len(directNodes) > 0 { + // 转换为 types.Node + directNodeItems := make([]types.Node, 0, len(directNodes)) + for _, n := range directNodes { + tags := []string{} + if n.Tags != "" { + tags = strings.Split(n.Tags, ",") } + directNodeItems = append(directNodeItems, types.Node{ + Id: n.Id, + Name: n.Name, + Tags: tags, + Port: n.Port, + Address: n.Address, + ServerId: n.ServerId, + Protocol: n.Protocol, + Enabled: n.Enabled, + Sort: n.Sort, + NodeGroupIds: []int64(n.NodeGroupIds), + CreatedAt: n.CreatedAt.Unix(), + UpdatedAt: n.UpdatedAt.Unix(), + }) } - if !found { - logger.Infof("[PreviewUserNodes] node_group_id %d not found in database, treating as public nodes", requestedId) - } - } - } - // 13. 构建响应:根据有效节点组ID重新分组节点 - nodeGroupItems := make([]types.NodeGroupItem, 0) - publicNodes := make([]types.Node, 0) // 公共节点(包括无效节点组和无节点组的节点) - - // 遍历所有分组,重新分类节点 - for groupId, gm := range groupMap { - if groupId == 0 { - // 本来就是无节点组的节点 - publicNodes = append(publicNodes, gm.Nodes...) - continue - } - - // 检查这个节点组ID是否有效(在数据库中存在) - isValid := false - for _, validId := range validGroupIds { - if groupId == validId { - isValid = true - break - } - } - - if isValid { - // 节点组有效,添加到对应的分组 - groupName := nodeGroupInfoMap[groupId] - if groupName == "" { - groupName = fmt.Sprintf("Group %d", groupId) - } + // 添加套餐节点组(使用特殊ID -1,Name 为空字符串,前端根据 ID -1 进行国际化) nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ - Id: groupId, - Name: groupName, - Nodes: gm.Nodes, + Id: -1, + Name: "", // 空字符串,前端根据 ID -1 识别并国际化 + Nodes: directNodeItems, }) - logger.Infof("[PreviewUserNodes] adding node group: id=%d, name=%s, nodes=%d", groupId, groupName, len(gm.Nodes)) - } else { - // 节点组无效,节点归入公共节点组 - logger.Infof("[PreviewUserNodes] node_group_id %d invalid, moving %d nodes to public group", groupId, len(gm.Nodes)) - publicNodes = append(publicNodes, gm.Nodes...) + logger.Infof("[PreviewUserNodes] adding subscription nodes group: nodes=%d", len(directNodeItems)) } } - // 最后添加公共节点组(如果有) - if len(publicNodes) > 0 { - nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ - Id: 0, - Name: "", - Nodes: publicNodes, - }) - logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes)) - } - // 14. 返回结果 resp = &types.PreviewUserNodesResponse{ UserId: req.UserId, diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go index a619c88..b98f90d 100644 --- a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go @@ -177,19 +177,27 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscrib // 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] nodeGroupId := int64(0) source := "" + var directNodeIds []int64 // 优先级1: user_subscribe.node_group_id if userSub.NodeGroupId != 0 { nodeGroupId = userSub.NodeGroupId source = "user_subscribe.node_group_id" - } else { - // 优先级2 & 3: 从 subscribe 表获取 - subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) - if err != nil { - l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error())) - return nil, err - } + } + // 获取 subscribe 详情(用于获取 node_group_id 和直接分配的节点) + subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + if err != nil { + l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error())) + return nil, err + } + + // 获取直接分配的节点ID + directNodeIds = tool.StringToInt64Slice(subDetails.Nodes) + l.Debugf("[GetNodesByGroup] direct nodes: %v", directNodeIds) + + // 如果 user_subscribe 没有 node_group_id,从 subscribe 获取 + if nodeGroupId == 0 { // 优先级2: subscribe.node_group_id if subDetails.NodeGroupId != 0 { nodeGroupId = subDetails.NodeGroupId @@ -201,29 +209,60 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscrib } } - // 如果所有优先级都没有获取到,返回空节点列表 - if nodeGroupId == 0 { - l.Debugw("[GetNodesByGroup] no node_group_id found in any priority, returning no nodes") - return []*node.Node{}, nil - } - l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId) - // Filter nodes by node_group_id + // 查询所有启用的节点 enable := true - _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ - Page: 0, - Size: 1000, - NodeGroupIds: []int64{nodeGroupId}, // Filter by node_group_ids - Enabled: &enable, + _, allNodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 10000, + Enabled: &enable, }) if err != nil { l.Errorw("[GetNodesByGroup] FilterNodeList error", logger.Field("error", err.Error())) return nil, err } - l.Debugf("[GetNodesByGroup] Found %d nodes for node_group_id=%d", len(nodes), nodeGroupId) - return nodes, nil + // 过滤节点 + var resultNodes []*node.Node + nodeIdMap := make(map[int64]bool) + + for _, n := range allNodes { + // 1. 公共节点(node_group_ids 为空),所有人可见 + if len(n.NodeGroupIds) == 0 { + if !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + continue + } + + // 2. 如果有节点组,检查节点是否属于该节点组 + if nodeGroupId != 0 { + for _, gid := range n.NodeGroupIds { + if gid == nodeGroupId { + if !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + break + } + } + } + } + + // 3. 添加直接分配的节点 + if len(directNodeIds) > 0 { + for _, n := range allNodes { + if tool.Contains(directNodeIds, n.Id) && !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + } + } + + l.Debugf("[GetNodesByGroup] Found %d nodes (group=%d, direct=%d)", len(resultNodes), nodeGroupId, len(directNodeIds)) + return resultNodes, nil } // getNodesByTag gets nodes based on subscribe node_ids and tags @@ -236,7 +275,13 @@ func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe) nodeIds := tool.StringToInt64Slice(subDetails.Nodes) tags := strings.Split(subDetails.NodeTags, ",") - + newTags := make([]string, 0) + for _, t := range tags { + if t != "" { + newTags = append(newTags, t) + } + } + tags = newTags l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) enable := true From 06a2425474a8c600d77b7940c94d71a9da13b6f6 Mon Sep 17 00:00:00 2001 From: EUForest Date: Sat, 14 Mar 2026 12:41:52 +0800 Subject: [PATCH 17/18] feat(subscribe): add traffic limit rules and user traffic stats - Add subscribe traffic_limit schema and migration\n- Support traffic_limit in admin create/update and list/details\n- Apply traffic_limit when building server user list speed limits\n- Add public user traffic stats API --- apis/admin/group.api | 34 +-- apis/admin/subscribe.api | 2 + apis/public/user.api | 20 ++ apis/types.api | 32 ++- .../02133_add_expired_node_group.down.sql | 12 ++ .../02133_add_expired_node_group.up.sql | 14 ++ .../02134_subscribe_traffic_limit.down.sql | 6 + .../02134_subscribe_traffic_limit.up.sql | 22 ++ .../public/user/getUserTrafficStatsHandler.go | 26 +++ internal/handler/routes.go | 3 + .../logic/admin/group/createNodeGroupLogic.go | 51 ++++- .../logic/admin/group/deleteNodeGroupLogic.go | 5 +- .../admin/group/exportGroupResultLogic.go | 3 +- .../admin/group/getNodeGroupListLogic.go | 40 ++-- .../group/getSubscribeGroupMappingLogic.go | 4 +- .../admin/group/previewUserNodesLogic.go | 33 ++- .../admin/group/recalculateGroupLogic.go | 40 ++-- .../logic/admin/group/updateNodeGroupLogic.go | 45 ++++ .../admin/subscribe/createSubscribeLogic.go | 7 + .../subscribe/getSubscribeDetailsLogic.go | 6 + .../admin/subscribe/getSubscribeListLogic.go | 6 + .../admin/subscribe/updateSubscribeLogic.go | 7 + .../queryUserSubscribeNodeListLogic.go | 97 ++++++++- .../public/user/getUserTrafficStatsLogic.go | 138 ++++++++++++ .../public/user/queryUserSubscribeLogic.go | 4 + .../logic/server/getServerUserListLogic.go | 198 +++++++++++++++++- internal/logic/subscribe/subscribeLogic.go | 63 ++++++ internal/model/group/node_group.go | 22 +- internal/model/subscribe/subscribe.go | 1 + internal/model/user/model.go | 23 +- internal/model/user/user.go | 40 ++-- internal/types/types.go | 88 ++++++-- queue/logic/order/activateOrderLogic.go | 31 +-- queue/logic/traffic/trafficStatisticsLogic.go | 4 +- 34 files changed, 974 insertions(+), 153 deletions(-) create mode 100644 initialize/migrate/database/02133_add_expired_node_group.down.sql create mode 100644 initialize/migrate/database/02133_add_expired_node_group.up.sql create mode 100644 initialize/migrate/database/02134_subscribe_traffic_limit.down.sql create mode 100644 initialize/migrate/database/02134_subscribe_traffic_limit.up.sql create mode 100644 internal/handler/public/user/getUserTrafficStatsHandler.go create mode 100644 internal/logic/public/user/getUserTrafficStatsLogic.go diff --git a/apis/admin/group.api b/apis/admin/group.api index e229fc4..de4aad9 100644 --- a/apis/admin/group.api +++ b/apis/admin/group.api @@ -28,22 +28,30 @@ type ( } // CreateNodeGroupRequest CreateNodeGroupRequest { - Name string `json:"name" validate:"required"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation *bool `json:"for_calculation"` - MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` } // UpdateNodeGroupRequest UpdateNodeGroupRequest { - Id int64 `json:"id" validate:"required"` - Name string `json:"name"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation *bool `json:"for_calculation"` - MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + Id int64 `json:"id" validate:"required"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` } // DeleteNodeGroupRequest DeleteNodeGroupRequest { diff --git a/apis/admin/subscribe.api b/apis/admin/subscribe.api index 881f021..8a662b8 100644 --- a/apis/admin/subscribe.api +++ b/apis/admin/subscribe.api @@ -50,6 +50,7 @@ type ( NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -77,6 +78,7 @@ type ( NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` diff --git a/apis/public/user.api b/apis/public/user.api index a6eb50f..426cf93 100644 --- a/apis/public/user.api +++ b/apis/public/user.api @@ -144,6 +144,22 @@ type ( HistoryContinuousDays int64 `json:"history_continuous_days"` LongestSingleConnection int64 `json:"longest_single_connection"` } + GetUserTrafficStatsRequest { + UserSubscribeId string `form:"user_subscribe_id" validate:"required"` + Days int `form:"days" validate:"required,oneof=7 30"` + } + DailyTrafficStats { + Date string `json:"date"` + Upload int64 `json:"upload"` + Download int64 `json:"download"` + Total int64 `json:"total"` + } + GetUserTrafficStatsResponse { + List []DailyTrafficStats `json:"list"` + TotalUpload int64 `json:"total_upload"` + TotalDownload int64 `json:"total_download"` + TotalTraffic int64 `json:"total_traffic"` + } ) @server ( @@ -271,6 +287,10 @@ service ppanel { @doc "Delete Current User Account" @handler DeleteCurrentUserAccount delete /current_user_account + + @doc "Get User Traffic Statistics" + @handler GetUserTrafficStats + get /traffic_stats (GetUserTrafficStatsRequest) returns (GetUserTrafficStatsResponse) } @server ( diff --git a/apis/types.api b/apis/types.api index 4b9c7f6..827e038 100644 --- a/apis/types.api +++ b/apis/types.api @@ -211,6 +211,12 @@ type ( Quantity int64 `json:"quantity"` Discount float64 `json:"discount"` } + TrafficLimit { + StatType string `json:"stat_type"` + StatValue int64 `json:"stat_value"` + TrafficUsage int64 `json:"traffic_usage"` + SpeedLimit int64 `json:"speed_limit"` + } Subscribe { Id int64 `json:"id"` Name string `json:"name"` @@ -229,6 +235,7 @@ type ( NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -486,6 +493,7 @@ type ( } UserSubscribe { Id int64 `json:"id"` + IdStr string `json:"id_str"` UserId int64 `json:"user_id"` OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` @@ -882,16 +890,20 @@ type ( // ===== 分组功能类型定义 ===== // NodeGroup 节点组 NodeGroup { - Id int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation bool `json:"for_calculation"` - MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` - NodeCount int64 `json:"node_count,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + IsExpiredGroup bool `json:"is_expired_group"` + ExpiredDaysLimit int `json:"expired_days_limit"` + MaxTrafficGBExpired int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit int `json:"speed_limit"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } // GroupHistory 分组历史记录 GroupHistory { diff --git a/initialize/migrate/database/02133_add_expired_node_group.down.sql b/initialize/migrate/database/02133_add_expired_node_group.down.sql new file mode 100644 index 0000000..aeacd0e --- /dev/null +++ b/initialize/migrate/database/02133_add_expired_node_group.down.sql @@ -0,0 +1,12 @@ +-- 回滚 user_subscribe 表的过期流量字段 +ALTER TABLE `user_subscribe` +DROP COLUMN `expired_upload`, +DROP COLUMN `expired_download`; + +-- 回滚 node_group 表的过期节点组字段 +ALTER TABLE `node_group` +DROP INDEX `idx_is_expired_group`, +DROP COLUMN `speed_limit`, +DROP COLUMN `max_traffic_gb_expired`, +DROP COLUMN `expired_days_limit`, +DROP COLUMN `is_expired_group`; diff --git a/initialize/migrate/database/02133_add_expired_node_group.up.sql b/initialize/migrate/database/02133_add_expired_node_group.up.sql new file mode 100644 index 0000000..283f9a5 --- /dev/null +++ b/initialize/migrate/database/02133_add_expired_node_group.up.sql @@ -0,0 +1,14 @@ +-- 为 node_group 表添加过期节点组相关字段 +ALTER TABLE `node_group` +ADD COLUMN `is_expired_group` tinyint(1) NOT NULL DEFAULT 0 COMMENT 'Is Expired Group: 0=normal, 1=expired group' AFTER `for_calculation`, +ADD COLUMN `expired_days_limit` int NOT NULL DEFAULT 7 COMMENT 'Expired days limit (days)' AFTER `is_expired_group`, +ADD COLUMN `max_traffic_gb_expired` bigint DEFAULT 0 COMMENT 'Max traffic for expired users (GB)' AFTER `expired_days_limit`, +ADD COLUMN `speed_limit` int NOT NULL DEFAULT 0 COMMENT 'Speed limit (KB/s)' AFTER `max_traffic_gb_expired`; + +-- 添加索引 +ALTER TABLE `node_group` ADD INDEX `idx_is_expired_group` (`is_expired_group`); + +-- 为 user_subscribe 表添加过期流量统计字段 +ALTER TABLE `user_subscribe` +ADD COLUMN `expired_download` bigint NOT NULL DEFAULT 0 COMMENT 'Expired period download traffic (bytes)' AFTER `upload`, +ADD COLUMN `expired_upload` bigint NOT NULL DEFAULT 0 COMMENT 'Expired period upload traffic (bytes)' AFTER `expired_download`; diff --git a/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql b/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql new file mode 100644 index 0000000..4fdf319 --- /dev/null +++ b/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql @@ -0,0 +1,6 @@ +-- Purpose: Rollback traffic_limit rules from subscribe +-- Author: Claude Code +-- Date: 2026-03-12 + +-- ===== Remove traffic_limit column from subscribe table ===== +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `traffic_limit`; diff --git a/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql b/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql new file mode 100644 index 0000000..18a9f30 --- /dev/null +++ b/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql @@ -0,0 +1,22 @@ +-- Purpose: Add traffic_limit rules to subscribe +-- Author: Claude Code +-- Date: 2026-03-12 + +-- ===== Add traffic_limit column to subscribe table ===== +SET @column_exists = ( + SELECT COUNT(*) + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'subscribe' + AND COLUMN_NAME = 'traffic_limit' +); + +SET @sql = IF( + @column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `traffic_limit` TEXT NULL COMMENT ''Traffic Limit Rules (JSON)'' AFTER `node_group_id`', + 'SELECT ''Column traffic_limit already exists in subscribe table''' +); + +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; diff --git a/internal/handler/public/user/getUserTrafficStatsHandler.go b/internal/handler/public/user/getUserTrafficStatsHandler.go new file mode 100644 index 0000000..5c3c7dd --- /dev/null +++ b/internal/handler/public/user/getUserTrafficStatsHandler.go @@ -0,0 +1,26 @@ +package user + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/public/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get User Traffic Statistics +func GetUserTrafficStatsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetUserTrafficStatsRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := user.NewGetUserTrafficStatsLogic(c.Request.Context(), svcCtx) + resp, err := l.GetUserTrafficStats(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/routes.go b/internal/handler/routes.go index dbb6c06..e3995be 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -955,6 +955,9 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { // Reset User Subscribe Token publicUserGroupRouter.PUT("/subscribe_token", publicUser.ResetUserSubscribeTokenHandler(serverCtx)) + // Get User Traffic Statistics + publicUserGroupRouter.GET("/traffic_stats", publicUser.GetUserTrafficStatsHandler(serverCtx)) + // Unbind Device publicUserGroupRouter.PUT("/unbind_device", publicUser.UnbindDeviceHandler(serverCtx)) diff --git a/internal/logic/admin/group/createNodeGroupLogic.go b/internal/logic/admin/group/createNodeGroupLogic.go index 2d361d6..9e68c10 100644 --- a/internal/logic/admin/group/createNodeGroupLogic.go +++ b/internal/logic/admin/group/createNodeGroupLogic.go @@ -2,6 +2,7 @@ package group import ( "context" + "errors" "time" "github.com/perfect-panel/server/internal/model/group" @@ -25,17 +26,51 @@ func NewCreateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *C } func (l *CreateNodeGroupLogic) CreateNodeGroup(req *types.CreateNodeGroupRequest) error { + // 验证:系统中只能有一个过期节点组 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + var count int64 + err := l.svcCtx.DB.Model(&group.NodeGroup{}). + Where("is_expired_group = ?", true). + Count(&count).Error + if err != nil { + logger.Errorf("failed to check expired group count: %v", err) + return err + } + if count > 0 { + return errors.New("system already has an expired node group, cannot create multiple") + } + } + // 创建节点组 nodeGroup := &group.NodeGroup{ - Name: req.Name, - Description: req.Description, - Sort: req.Sort, - ForCalculation: req.ForCalculation, - MinTrafficGB: req.MinTrafficGB, - MaxTrafficGB: req.MaxTrafficGB, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + Name: req.Name, + Description: req.Description, + Sort: req.Sort, + ForCalculation: req.ForCalculation, + IsExpiredGroup: req.IsExpiredGroup, + MaxTrafficGBExpired: req.MaxTrafficGBExpired, + MinTrafficGB: req.MinTrafficGB, + MaxTrafficGB: req.MaxTrafficGB, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), } + + // 设置过期节点组的默认值 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + // 过期节点组不参与分组计算 + falseValue := false + nodeGroup.ForCalculation = &falseValue + + if req.ExpiredDaysLimit != nil { + nodeGroup.ExpiredDaysLimit = *req.ExpiredDaysLimit + } else { + nodeGroup.ExpiredDaysLimit = 7 // 默认7天 + } + if req.SpeedLimit != nil { + nodeGroup.SpeedLimit = *req.SpeedLimit + } + } + if err := l.svcCtx.DB.Create(nodeGroup).Error; err != nil { logger.Errorf("failed to create node group: %v", err) return err diff --git a/internal/logic/admin/group/deleteNodeGroupLogic.go b/internal/logic/admin/group/deleteNodeGroupLogic.go index 947dc49..16c89d4 100644 --- a/internal/logic/admin/group/deleteNodeGroupLogic.go +++ b/internal/logic/admin/group/deleteNodeGroupLogic.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -37,9 +38,9 @@ func (l *DeleteNodeGroupLogic) DeleteNodeGroup(req *types.DeleteNodeGroupRequest return err } - // 检查是否有关联节点 + // 检查是否有关联节点(使用JSON_CONTAINS查询node_group_ids数组) var nodeCount int64 - if err := l.svcCtx.DB.Table("nodes").Where("node_group_id = ?", nodeGroup.Id).Count(&nodeCount).Error; err != nil { + if err := l.svcCtx.DB.Model(&node.Node{}).Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroup.Id)).Count(&nodeCount).Error; err != nil { logger.Errorf("failed to count nodes in group: %v", err) return err } diff --git a/internal/logic/admin/group/exportGroupResultLogic.go b/internal/logic/admin/group/exportGroupResultLogic.go index ef2183f..a84befa 100644 --- a/internal/logic/admin/group/exportGroupResultLogic.go +++ b/internal/logic/admin/group/exportGroupResultLogic.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -77,7 +78,7 @@ func (l *ExportGroupResultLogic) ExportGroupResult(req *types.ExportGroupResultR NodeGroupId int64 `json:"node_group_id"` } var userSubscribes []UserNodeGroupInfo - if err := l.svcCtx.DB.Table("user_subscribe"). + if err := l.svcCtx.DB.Model(&user.Subscribe{}). Select("DISTINCT user_id as id, node_group_id"). Where("node_group_id > ?", 0). Find(&userSubscribes).Error; err != nil { diff --git a/internal/logic/admin/group/getNodeGroupListLogic.go b/internal/logic/admin/group/getNodeGroupListLogic.go index abd1a6a..9595393 100644 --- a/internal/logic/admin/group/getNodeGroupListLogic.go +++ b/internal/logic/admin/group/getNodeGroupListLogic.go @@ -2,8 +2,10 @@ package group import ( "context" + "fmt" "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -46,9 +48,9 @@ func (l *GetNodeGroupListLogic) GetNodeGroupList(req *types.GetNodeGroupListRequ // 转换为响应格式 var list []types.NodeGroup for _, ng := range nodeGroups { - // 统计该组的节点数 + // 统计该组的节点数(JSON数组查询) var nodeCount int64 - l.svcCtx.DB.Table("nodes").Where("node_group_id = ?", ng.Id).Count(&nodeCount) + l.svcCtx.DB.Model(&node.Node{}).Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", ng.Id)).Count(&nodeCount) // 处理指针类型的字段 var forCalculation bool @@ -58,25 +60,37 @@ func (l *GetNodeGroupListLogic) GetNodeGroupList(req *types.GetNodeGroupListRequ forCalculation = true // 默认值 } - var minTrafficGB, maxTrafficGB int64 + var isExpiredGroup bool + if ng.IsExpiredGroup != nil { + isExpiredGroup = *ng.IsExpiredGroup + } + + var minTrafficGB, maxTrafficGB, maxTrafficGBExpired int64 if ng.MinTrafficGB != nil { minTrafficGB = *ng.MinTrafficGB } if ng.MaxTrafficGB != nil { maxTrafficGB = *ng.MaxTrafficGB } + if ng.MaxTrafficGBExpired != nil { + maxTrafficGBExpired = *ng.MaxTrafficGBExpired + } list = append(list, types.NodeGroup{ - Id: ng.Id, - Name: ng.Name, - Description: ng.Description, - Sort: ng.Sort, - ForCalculation: forCalculation, - MinTrafficGB: minTrafficGB, - MaxTrafficGB: maxTrafficGB, - NodeCount: nodeCount, - CreatedAt: ng.CreatedAt.Unix(), - UpdatedAt: ng.UpdatedAt.Unix(), + Id: ng.Id, + Name: ng.Name, + Description: ng.Description, + Sort: ng.Sort, + ForCalculation: forCalculation, + IsExpiredGroup: isExpiredGroup, + ExpiredDaysLimit: ng.ExpiredDaysLimit, + MaxTrafficGBExpired: maxTrafficGBExpired, + SpeedLimit: ng.SpeedLimit, + MinTrafficGB: minTrafficGB, + MaxTrafficGB: maxTrafficGB, + NodeCount: nodeCount, + CreatedAt: ng.CreatedAt.Unix(), + UpdatedAt: ng.UpdatedAt.Unix(), }) } diff --git a/internal/logic/admin/group/getSubscribeGroupMappingLogic.go b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go index fb3ed90..cd26305 100644 --- a/internal/logic/admin/group/getSubscribeGroupMappingLogic.go +++ b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go @@ -28,14 +28,14 @@ func NewGetSubscribeGroupMappingLogic(ctx context.Context, svcCtx *svc.ServiceCo func (l *GetSubscribeGroupMappingLogic) GetSubscribeGroupMapping(req *types.GetSubscribeGroupMappingRequest) (resp *types.GetSubscribeGroupMappingResponse, err error) { // 1. 查询所有订阅套餐 var subscribes []subscribe.Subscribe - if err := l.svcCtx.DB.Table("subscribe").Find(&subscribes).Error; err != nil { + if err := l.svcCtx.DB.Model(&subscribe.Subscribe{}).Find(&subscribes).Error; err != nil { l.Errorw("[GetSubscribeGroupMapping] failed to query subscribes", logger.Field("error", err.Error())) return nil, err } // 2. 查询所有节点组 var nodeGroups []group.NodeGroup - if err := l.svcCtx.DB.Table("node_group").Find(&nodeGroups).Error; err != nil { + if err := l.svcCtx.DB.Model(&group.NodeGroup{}).Find(&nodeGroups).Error; err != nil { l.Errorw("[GetSubscribeGroupMapping] failed to query node groups", logger.Field("error", err.Error())) return nil, err } diff --git a/internal/logic/admin/group/previewUserNodesLogic.go b/internal/logic/admin/group/previewUserNodesLogic.go index 3b889df..2122da5 100644 --- a/internal/logic/admin/group/previewUserNodesLogic.go +++ b/internal/logic/admin/group/previewUserNodesLogic.go @@ -6,7 +6,10 @@ import ( "fmt" "strings" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -38,7 +41,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ NodeGroupId int64 // 用户订阅的 node_group_id(单个ID) } var userSubscribes []UserSubscribe - err = l.svcCtx.DB.Table("user_subscribe"). + err = l.svcCtx.DB.Model(&user.Subscribe{}). Select("id, user_id, subscribe_id, node_group_id"). Where("user_id = ? AND status IN ?", req.UserId, []int8{0, 1}). Find(&userSubscribes).Error @@ -74,7 +77,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ NodeTags string // 节点标签 } var subscribeInfos []SubscribeInfo - err = l.svcCtx.DB.Table("subscribe"). + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}). Select("id, node_group_id, node_group_ids, nodes, node_tags"). Where("id IN ?", subscribeIds). Find(&subscribeInfos).Error @@ -149,15 +152,23 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ logger.Infof("[PreviewUserNodes] collected direct node_ids: %v", allDirectNodeIds) // 4. 判断分组功能是否启用 - var groupEnabled string - l.svcCtx.DB.Table("system"). + type SystemConfig struct { + Value string + } + var config SystemConfig + l.svcCtx.DB.Model(&struct { + Category string `gorm:"column:category"` + Key string `gorm:"column:key"` + Value string `gorm:"column:value"` + }{}). + Table("system"). Where("`category` = ? AND `key` = ?", "group", "enabled"). Select("value"). - Scan(&groupEnabled) + Scan(&config) - logger.Infof("[PreviewUserNodes] groupEnabled: %v", groupEnabled) + logger.Infof("[PreviewUserNodes] groupEnabled: %v", config.Value) - isGroupEnabled := groupEnabled == "true" || groupEnabled == "1" + isGroupEnabled := config.Value == "true" || config.Value == "1" var filteredNodes []node.Node @@ -177,7 +188,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ // 5. 查询所有启用的节点(只有当有节点组时才查询) if len(allNodeGroupIds) > 0 { var dbNodes []node.Node - err = l.svcCtx.DB.Table("nodes"). + err = l.svcCtx.DB.Model(&node.Node{}). Where("enabled = ?", true). Find(&dbNodes).Error if err != nil { @@ -238,7 +249,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ // 8. 查询所有启用的节点(只有当有 tags 时才查询) if len(allTags) > 0 { var dbNodes []node.Node - err = l.svcCtx.DB.Table("nodes"). + err = l.svcCtx.DB.Model(&node.Node{}). Where("enabled = ?", true). Find(&dbNodes).Error if err != nil { @@ -370,7 +381,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ Name string } var nodeGroupInfos []NodeGroupInfo - err = l.svcCtx.DB.Table("node_group"). + err = l.svcCtx.DB.Model(&group.NodeGroup{}). Select("id, name"). Where("id IN ?", allGroupIds). Find(&nodeGroupInfos).Error @@ -508,7 +519,7 @@ func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequ if len(allDirectNodeIds) > 0 { // 查询直接分配的节点详情 var directNodes []node.Node - err = l.svcCtx.DB.Table("nodes"). + err = l.svcCtx.DB.Model(&node.Node{}). Where("id IN ? AND enabled = ?", allDirectNodeIds, true). Find(&directNodes).Error if err != nil { diff --git a/internal/logic/admin/group/recalculateGroupLogic.go b/internal/logic/admin/group/recalculateGroupLogic.go index d43557c..cb16188 100644 --- a/internal/logic/admin/group/recalculateGroupLogic.go +++ b/internal/logic/admin/group/recalculateGroupLogic.go @@ -3,9 +3,13 @@ package group import ( "context" "encoding/json" + "fmt" "time" "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -131,7 +135,7 @@ func (l *RecalculateGroupLogic) getUserEmail(tx *gorm.DB, userId int64) string { } var authMethod UserAuthMethod - if err := tx.Table("user_auth_methods"). + if err := tx.Model(&user.AuthMethods{}). Select("auth_identifier"). Where("user_id = ? AND (auth_type = ? OR auth_type = ?)", userId, "email", "6"). First(&authMethod).Error; err != nil { @@ -152,7 +156,7 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in } var userSubscribes []UserSubscribeInfo - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Select("id, user_id, subscribe_id"). Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询未锁定且有效的用户订阅 Scan(&userSubscribes).Error; err != nil { @@ -168,7 +172,7 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in // 1.5 查询所有参与计算的节点组ID var calculationNodeGroups []group.NodeGroup - if err := tx.Table("node_group"). + if err := tx.Model(&group.NodeGroup{}). Select("id"). Where("for_calculation = ?", true). Scan(&calculationNodeGroups).Error; err != nil { @@ -195,7 +199,7 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in NodeGroupIds string `json:"node_group_ids"` // JSON string } var subscribeInfos []SubscribeInfo - if err := tx.Table("subscribe"). + if err := tx.Model(&subscribe.Subscribe{}). Select("id, node_group_ids"). Where("id IN ?", subscribeIds). Find(&subscribeInfos).Error; err != nil { @@ -261,10 +265,10 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in } } - // 如果没有节点组ID,跳过 + // 如果没有节点组ID,跳过 if len(nodeGroupIds) == 0 { l.Debugf("no valid node_group_ids for subscribe_id=%d, setting to 0", subInfo.Id) - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Where("id = ?", us.Id). Update("node_group_id", 0).Error; err != nil { l.Errorw("failed to update user_subscribe node_group_id", @@ -290,7 +294,7 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in } // 更新 user_subscribe 的 node_group_id 字段(单个ID) - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Where("id = ?", us.Id). Update("node_group_id", selectedNodeGroupId).Error; err != nil { l.Errorw("failed to update user_subscribe node_group_id", @@ -329,8 +333,8 @@ func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId in // 统计该节点组的节点数 var nodeCount int64 = 0 if nodeGroupId > 0 { - if err := tx.Table("nodes"). - Where("JSON_CONTAINS(node_group_ids, ?)", nodeGroupId). + if err := tx.Model(&node.Node{}). + Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroupId)). Count(&nodeCount).Error; err != nil { l.Errorw("failed to count nodes", logger.Field("node_group_id", nodeGroupId), @@ -383,7 +387,7 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId } var userSubscribes []UserSubscribeInfo - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Select("id, user_id, subscribe_id"). Where("group_locked = ? AND status IN (0, 1)", 0). Scan(&userSubscribes).Error; err != nil { @@ -400,7 +404,7 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId // 1.5 查询所有参与计算的节点组ID var calculationNodeGroups []group.NodeGroup - if err := tx.Table("node_group"). + if err := tx.Model(&group.NodeGroup{}). Select("id"). Where("for_calculation = ?", true). Scan(&calculationNodeGroups).Error; err != nil { @@ -427,7 +431,7 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId NodeGroupIds string `json:"node_group_ids"` // JSON string } var subscribeInfos []SubscribeInfo - if err := tx.Table("subscribe"). + if err := tx.Model(&subscribe.Subscribe{}). Select("id, node_group_ids"). Where("id IN ?", subscribeIds). Find(&subscribeInfos).Error; err != nil { @@ -501,7 +505,7 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId us.Id, us.SubscribeId, selectedNodeGroupId, len(nodeGroupIds)) // 更新 user_subscribe 的 node_group_id 字段 - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Where("id = ?", us.Id). Update("node_group_id", selectedNodeGroupId).Error; err != nil { l.Errorw("failed to update user_subscribe node_group_id", @@ -548,7 +552,7 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId expiredAffectedCount := 0 for _, eu := range expiredUserSubscribes { // 更新 user_subscribe 表的 node_group_id 字段到 0 - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Where("id = ?", eu.Id). Update("node_group_id", 0).Error; err != nil { l.Errorw("failed to update expired user subscribe node_group_id", @@ -573,8 +577,8 @@ func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId // 统计该节点组的节点数 var nodeCount int64 = 0 if nodeGroupId > 0 { - if err := tx.Table("nodes"). - Where("JSON_CONTAINS(node_group_ids, ?)", nodeGroupId). + if err := tx.Model(&node.Node{}). + Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroupId)). Count(&nodeCount).Error; err != nil { l.Errorw("failed to count nodes", logger.Field("node_group_id", nodeGroupId), @@ -652,7 +656,7 @@ func (l *RecalculateGroupLogic) executeTrafficGrouping(tx *gorm.DB, historyId in } var userSubscribes []UserSubscribeInfo - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Select("id, user_id, upload, download, (upload + download) as used_traffic"). Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询有效且未锁定的用户订阅 Scan(&userSubscribes).Error; err != nil { @@ -694,7 +698,7 @@ func (l *RecalculateGroupLogic) executeTrafficGrouping(tx *gorm.DB, historyId in // 如果没有匹配到任何范围,targetNodeGroupId 保持为 0(不分配节点组) // 更新 user_subscribe 的 node_group_id 字段 - if err := tx.Table("user_subscribe"). + if err := tx.Model(&user.Subscribe{}). Where("id = ?", us.Id). Update("node_group_id", targetNodeGroupId).Error; err != nil { l.Errorw("failed to update user subscribe node_group_id", diff --git a/internal/logic/admin/group/updateNodeGroupLogic.go b/internal/logic/admin/group/updateNodeGroupLogic.go index eb299d5..b7d6fa4 100644 --- a/internal/logic/admin/group/updateNodeGroupLogic.go +++ b/internal/logic/admin/group/updateNodeGroupLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/subscribe" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/logger" @@ -37,6 +38,34 @@ func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest return err } + // 验证:系统中只能有一个过期节点组 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + var count int64 + err := l.svcCtx.DB.Model(&group.NodeGroup{}). + Where("is_expired_group = ? AND id != ?", true, req.Id). + Count(&count).Error + if err != nil { + logger.Errorf("failed to check expired group count: %v", err) + return err + } + if count > 0 { + return errors.New("system already has an expired node group, cannot create multiple") + } + + // 验证:被订阅商品设置为默认节点组的不能设置为过期节点组 + var subscribeCount int64 + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}). + Where("node_group_id = ?", req.Id). + Count(&subscribeCount).Error + if err != nil { + logger.Errorf("failed to check subscribe usage: %v", err) + return err + } + if subscribeCount > 0 { + return errors.New("this node group is used as default node group in subscription products, cannot set as expired group") + } + } + // 构建更新数据 updates := map[string]interface{}{ "updated_at": time.Now(), @@ -53,6 +82,22 @@ func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest if req.ForCalculation != nil { updates["for_calculation"] = *req.ForCalculation } + if req.IsExpiredGroup != nil { + updates["is_expired_group"] = *req.IsExpiredGroup + // 过期节点组不参与分组计算 + if *req.IsExpiredGroup { + updates["for_calculation"] = false + } + } + if req.ExpiredDaysLimit != nil { + updates["expired_days_limit"] = *req.ExpiredDaysLimit + } + if req.MaxTrafficGBExpired != nil { + updates["max_traffic_gb_expired"] = *req.MaxTrafficGBExpired + } + if req.SpeedLimit != nil { + updates["speed_limit"] = *req.SpeedLimit + } // 获取新的流量区间值 newMinTraffic := nodeGroup.MinTrafficGB diff --git a/internal/logic/admin/subscribe/createSubscribeLogic.go b/internal/logic/admin/subscribe/createSubscribeLogic.go index 3aeb751..999b3a3 100644 --- a/internal/logic/admin/subscribe/createSubscribeLogic.go +++ b/internal/logic/admin/subscribe/createSubscribeLogic.go @@ -34,6 +34,12 @@ func (l *CreateSubscribeLogic) CreateSubscribe(req *types.CreateSubscribeRequest val, _ := json.Marshal(req.Discount) discount = string(val) } + + trafficLimit := "" + if len(req.TrafficLimit) > 0 { + val, _ := json.Marshal(req.TrafficLimit) + trafficLimit = string(val) + } sub := &subscribe.Subscribe{ Id: 0, Name: req.Name, @@ -52,6 +58,7 @@ func (l *CreateSubscribeLogic) CreateSubscribe(req *types.CreateSubscribeRequest NodeTags: tool.StringSliceToString(req.NodeTags), NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), NodeGroupId: req.NodeGroupId, + TrafficLimit: trafficLimit, Show: req.Show, Sell: req.Sell, Sort: 0, diff --git a/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go b/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go index 6defdf1..fc29938 100644 --- a/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go +++ b/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go @@ -42,6 +42,12 @@ func (l *GetSubscribeDetailsLogic) GetSubscribeDetails(req *types.GetSubscribeDe l.Logger.Error("[GetSubscribeDetailsLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("discount", sub.Discount)) } } + if sub.TrafficLimit != "" { + err = json.Unmarshal([]byte(sub.TrafficLimit), &resp.TrafficLimit) + if err != nil { + l.Logger.Error("[GetSubscribeDetailsLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("traffic_limit", sub.TrafficLimit)) + } + } resp.Nodes = tool.StringToInt64Slice(sub.Nodes) resp.NodeTags = strings.Split(sub.NodeTags, ",") return resp, nil diff --git a/internal/logic/admin/subscribe/getSubscribeListLogic.go b/internal/logic/admin/subscribe/getSubscribeListLogic.go index 130d682..6cf6ba6 100644 --- a/internal/logic/admin/subscribe/getSubscribeListLogic.go +++ b/internal/logic/admin/subscribe/getSubscribeListLogic.go @@ -62,6 +62,12 @@ func (l *GetSubscribeListLogic) GetSubscribeList(req *types.GetSubscribeListRequ l.Logger.Error("[GetSubscribeListLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("discount", item.Discount)) } } + if item.TrafficLimit != "" { + err = json.Unmarshal([]byte(item.TrafficLimit), &sub.TrafficLimit) + if err != nil { + l.Logger.Error("[GetSubscribeListLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("traffic_limit", item.TrafficLimit)) + } + } sub.Nodes = tool.StringToInt64Slice(item.Nodes) sub.NodeTags = strings.Split(item.NodeTags, ",") // Handle NodeGroupIds - convert from JSONInt64Slice to []int64 diff --git a/internal/logic/admin/subscribe/updateSubscribeLogic.go b/internal/logic/admin/subscribe/updateSubscribeLogic.go index a60a6a0..4aef109 100644 --- a/internal/logic/admin/subscribe/updateSubscribeLogic.go +++ b/internal/logic/admin/subscribe/updateSubscribeLogic.go @@ -42,6 +42,12 @@ func (l *UpdateSubscribeLogic) UpdateSubscribe(req *types.UpdateSubscribeRequest val, _ := json.Marshal(req.Discount) discount = string(val) } + + trafficLimit := "" + if len(req.TrafficLimit) > 0 { + val, _ := json.Marshal(req.TrafficLimit) + trafficLimit = string(val) + } sub := &subscribe.Subscribe{ Id: req.Id, Name: req.Name, @@ -60,6 +66,7 @@ func (l *UpdateSubscribeLogic) UpdateSubscribe(req *types.UpdateSubscribeRequest NodeTags: tool.StringSliceToString(req.NodeTags), NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), NodeGroupId: req.NodeGroupId, + TrafficLimit: trafficLimit, Show: req.Show, Sell: req.Sell, Sort: req.Sort, diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go index b98f90d..493c43d 100644 --- a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" @@ -88,7 +89,7 @@ func (l *QueryUserSubscribeNodeListLogic) QueryUserSubscribeNodeList() (resp *ty func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (userSubscribeNodes []*types.UserSubscribeNodeInfo, err error) { userSubscribeNodes = make([]*types.UserSubscribeNodeInfo, 0) if l.isSubscriptionExpired(userSub) { - return l.createExpiredServers(), nil + return l.createExpiredServers(userSub), nil } // Check if group management is enabled @@ -312,8 +313,98 @@ func (l *QueryUserSubscribeNodeListLogic) isSubscriptionExpired(userSub *user.Su return userSub.ExpireTime.Unix() < time.Now().Unix() && userSub.ExpireTime.Unix() != 0 } -func (l *QueryUserSubscribeNodeListLogic) createExpiredServers() []*types.UserSubscribeNodeInfo { - return nil +func (l *QueryUserSubscribeNodeListLogic) createExpiredServers(userSub *user.Subscribe) []*types.UserSubscribeNodeInfo { + // 1. 查询过期节点组 + var expiredGroup group.NodeGroup + err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error + if err != nil { + l.Debugw("no expired node group configured", logger.Field("error", err)) + return nil + } + + // 2. 检查用户是否在过期天数限制内 + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + l.Debugf("user subscription expired %d days, exceeds limit %d days", expiredDays, expiredGroup.ExpiredDaysLimit) + return nil + } + + // 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量) + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + l.Debugf("user expired traffic %d GB, exceeds expired group limit %d GB", usedTrafficGB, *expiredGroup.MaxTrafficGBExpired) + return nil + } + } + + // 4. 查询过期节点组的节点 + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{expiredGroup.Id}, + Enabled: &enable, + }) + if err != nil { + l.Errorw("failed to query expired group nodes", logger.Field("error", err)) + return nil + } + + if len(nodes) == 0 { + l.Debug("no nodes found in expired group") + return nil + } + + // 5. 查询服务器信息 + var serverMapIds = make(map[int64]*node.Server) + for _, n := range nodes { + serverMapIds[n.ServerId] = nil + } + var serverIds []int64 + for k := range serverMapIds { + serverIds = append(serverIds, k) + } + + servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds) + if err != nil { + l.Errorw("failed to query servers", logger.Field("error", err)) + return nil + } + + for _, s := range servers { + serverMapIds[s.Id] = s + } + + // 6. 构建节点列表 + userSubscribeNodes := make([]*types.UserSubscribeNodeInfo, 0, len(nodes)) + for _, n := range nodes { + server := serverMapIds[n.ServerId] + if server == nil { + continue + } + userSubscribeNode := &types.UserSubscribeNodeInfo{ + Id: n.Id, + Name: n.Name, + Uuid: userSub.UUID, + Protocol: n.Protocol, + Protocols: server.Protocols, + Port: n.Port, + Address: n.Address, + Tags: strings.Split(n.Tags, ","), + Country: server.Country, + City: server.City, + Latitude: server.Latitude, + Longitude: server.Longitude, + LongitudeCenter: server.LongitudeCenter, + LatitudeCenter: server.LatitudeCenter, + CreatedAt: n.CreatedAt.Unix(), + } + userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode) + } + + l.Infof("returned %d nodes from expired group for user %d (expired %d days)", len(userSubscribeNodes), userSub.UserId, expiredDays) + return userSubscribeNodes } func (l *QueryUserSubscribeNodeListLogic) getFirstHostLine() string { diff --git a/internal/logic/public/user/getUserTrafficStatsLogic.go b/internal/logic/public/user/getUserTrafficStatsLogic.go new file mode 100644 index 0000000..e46cb4b --- /dev/null +++ b/internal/logic/public/user/getUserTrafficStatsLogic.go @@ -0,0 +1,138 @@ +package user + +import ( + "context" + "strconv" + "time" + + "gorm.io/gorm" + + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type GetUserTrafficStatsLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get User Traffic Statistics +func NewGetUserTrafficStatsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTrafficStatsLogic { + return &GetUserTrafficStatsLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetUserTrafficStatsLogic) GetUserTrafficStats(req *types.GetUserTrafficStatsRequest) (resp *types.GetUserTrafficStatsResponse, err error) { + // 获取当前用户 + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok { + logger.Error("current user is not found in context") + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") + } + + // 将字符串 ID 转换为 int64 + userSubscribeId, err := strconv.ParseInt(req.UserSubscribeId, 10, 64) + if err != nil { + l.Errorw("[GetUserTrafficStats] Invalid User Subscribe ID:", + logger.Field("user_subscribe_id", req.UserSubscribeId), + logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid subscription ID") + } + + // 验证订阅归属权 - 直接查询 user_subscribe 表 + var userSubscribe struct { + Id int64 + UserId int64 + } + err = l.svcCtx.DB.WithContext(l.ctx). + Table("user_subscribe"). + Select("id, user_id"). + Where("id = ?", userSubscribeId). + First(&userSubscribe).Error + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("[GetUserTrafficStats] User Subscribe Not Found:", + logger.Field("user_subscribe_id", userSubscribeId), + logger.Field("user_id", u.Id)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Subscription not found") + } + l.Errorw("[GetUserTrafficStats] Query User Subscribe Error:", logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Query User Subscribe Error") + } + + if userSubscribe.UserId != u.Id { + l.Errorw("[GetUserTrafficStats] User Subscribe Access Denied:", + logger.Field("user_subscribe_id", userSubscribeId), + logger.Field("subscribe_user_id", userSubscribe.UserId), + logger.Field("current_user_id", u.Id)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") + } + + // 计算时间范围 + now := time.Now() + startDate := now.AddDate(0, 0, -req.Days+1) + startDate = time.Date(startDate.Year(), startDate.Month(), startDate.Day(), 0, 0, 0, 0, time.Local) + + // 初始化响应 + resp = &types.GetUserTrafficStatsResponse{ + List: make([]types.DailyTrafficStats, 0, req.Days), + TotalUpload: 0, + TotalDownload: 0, + TotalTraffic: 0, + } + + // 按天查询流量数据 + for i := 0; i < req.Days; i++ { + currentDate := startDate.AddDate(0, 0, i) + dayStart := time.Date(currentDate.Year(), currentDate.Month(), currentDate.Day(), 0, 0, 0, 0, time.Local) + dayEnd := dayStart.Add(24 * time.Hour).Add(-time.Nanosecond) + + // 查询当天流量 + var dailyTraffic struct { + Upload int64 + Download int64 + } + + // 直接使用 model 的查询方法 + err := l.svcCtx.DB.WithContext(l.ctx). + Table("traffic_log"). + Select("COALESCE(SUM(upload), 0) as upload, COALESCE(SUM(download), 0) as download"). + Where("user_id = ? AND subscribe_id = ? AND timestamp BETWEEN ? AND ?", + u.Id, userSubscribeId, dayStart, dayEnd). + Scan(&dailyTraffic).Error + + if err != nil { + l.Errorw("[GetUserTrafficStats] Query Daily Traffic Error:", + logger.Field("date", currentDate.Format("2006-01-02")), + logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Query Traffic Error") + } + + // 添加到结果列表 + total := dailyTraffic.Upload + dailyTraffic.Download + resp.List = append(resp.List, types.DailyTrafficStats{ + Date: currentDate.Format("2006-01-02"), + Upload: dailyTraffic.Upload, + Download: dailyTraffic.Download, + Total: total, + }) + + // 累加总计 + resp.TotalUpload += dailyTraffic.Upload + resp.TotalDownload += dailyTraffic.Download + } + + resp.TotalTraffic = resp.TotalUpload + resp.TotalDownload + + return resp, nil +} diff --git a/internal/logic/public/user/queryUserSubscribeLogic.go b/internal/logic/public/user/queryUserSubscribeLogic.go index 55e3770..218f851 100644 --- a/internal/logic/public/user/queryUserSubscribeLogic.go +++ b/internal/logic/public/user/queryUserSubscribeLogic.go @@ -3,6 +3,7 @@ package user import ( "context" "encoding/json" + "strconv" "time" "github.com/perfect-panel/server/pkg/constant" @@ -52,6 +53,9 @@ func (l *QueryUserSubscribeLogic) QueryUserSubscribe() (resp *types.QueryUserSub var sub types.UserSubscribe tool.DeepCopy(&sub, item) + // 填充 IdStr 字段,避免前端精度丢失 + sub.IdStr = strconv.FormatInt(item.Id, 10) + // 解析Discount字段 避免在续订时只能续订一个月 if item.Subscribe != nil && item.Subscribe.Discount != "" { var discounts []types.SubscribeDiscount diff --git a/internal/logic/server/getServerUserListLogic.go b/internal/logic/server/getServerUserListLogic.go index 6d51326..8a871d0 100644 --- a/internal/logic/server/getServerUserListLogic.go +++ b/internal/logic/server/getServerUserListLogic.go @@ -4,10 +4,13 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -133,7 +136,7 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR return nil, err } } - + if len(subs) == 0 { return &types.GetServerUserListResponse{ Users: []types.ServerUser{ @@ -151,14 +154,33 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR return nil, err } for _, datum := range data { + if !l.shouldIncludeServerUser(datum, nodeGroupIds) { + continue + } + + // 计算该用户的实际限速值(考虑按量限速规则) + effectiveSpeedLimit := l.calculateEffectiveSpeedLimit(sub, datum) + users = append(users, types.ServerUser{ Id: datum.Id, UUID: datum.UUID, - SpeedLimit: sub.SpeedLimit, + SpeedLimit: effectiveSpeedLimit, DeviceLimit: sub.DeviceLimit, }) } } + + // 处理过期订阅用户:如果当前节点属于过期节点组,添加符合条件的过期用户 + if len(nodeGroupIds) > 0 { + expiredUsers, expiredSpeedLimit := l.getExpiredUsers(nodeGroupIds) + for i := range expiredUsers { + if expiredSpeedLimit > 0 { + expiredUsers[i].SpeedLimit = expiredSpeedLimit + } + } + users = append(users, expiredUsers...) + } + if len(users) == 0 { users = append(users, types.ServerUser{ Id: 1, @@ -181,3 +203,175 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR } return resp, nil } + +func (l *GetServerUserListLogic) shouldIncludeServerUser(userSub *user.Subscribe, serverNodeGroupIds []int64) bool { + if userSub == nil { + return false + } + + if userSub.ExpireTime.Unix() == 0 || userSub.ExpireTime.After(time.Now()) { + return true + } + + return l.canUseExpiredNodeGroup(userSub, serverNodeGroupIds) +} + +func (l *GetServerUserListLogic) getExpiredUsers(serverNodeGroupIds []int64) ([]types.ServerUser, int64) { + var expiredGroup group.NodeGroup + if err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error; err != nil { + return nil, 0 + } + + if !tool.Contains(serverNodeGroupIds, expiredGroup.Id) { + return nil, 0 + } + + var expiredSubs []*user.Subscribe + if err := l.svcCtx.DB.Where("status = ?", 3).Find(&expiredSubs).Error; err != nil { + l.Errorw("query expired subscriptions failed", logger.Field("error", err.Error())) + return nil, 0 + } + + users := make([]types.ServerUser, 0) + seen := make(map[int64]bool) + for _, userSub := range expiredSubs { + if !l.checkExpiredUserEligibility(userSub, &expiredGroup) { + continue + } + if seen[userSub.Id] { + continue + } + seen[userSub.Id] = true + users = append(users, types.ServerUser{ + Id: userSub.Id, + UUID: userSub.UUID, + }) + } + + return users, int64(expiredGroup.SpeedLimit) +} + +func (l *GetServerUserListLogic) checkExpiredUserEligibility(userSub *user.Subscribe, expiredGroup *group.NodeGroup) bool { + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + return false + } + + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + return false + } + } + + return true +} + +func (l *GetServerUserListLogic) canUseExpiredNodeGroup(userSub *user.Subscribe, serverNodeGroupIds []int64) bool { + var expiredGroup group.NodeGroup + if err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error; err != nil { + return false + } + + if !tool.Contains(serverNodeGroupIds, expiredGroup.Id) { + return false + } + + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + return false + } + + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + return false + } + } + + return true +} + +// calculateEffectiveSpeedLimit 计算用户的实际限速值(考虑按量限速规则) +func (l *GetServerUserListLogic) calculateEffectiveSpeedLimit(sub *subscribe.Subscribe, userSub *user.Subscribe) int64 { + baseSpeedLimit := sub.SpeedLimit + + // 解析 traffic_limit 规则 + if sub.TrafficLimit == "" { + return baseSpeedLimit + } + + var trafficLimitRules []types.TrafficLimit + if err := json.Unmarshal([]byte(sub.TrafficLimit), &trafficLimitRules); err != nil { + l.Errorw("[calculateEffectiveSpeedLimit] Failed to unmarshal traffic_limit", + logger.Field("error", err.Error()), + logger.Field("traffic_limit", sub.TrafficLimit)) + return baseSpeedLimit + } + + if len(trafficLimitRules) == 0 { + return baseSpeedLimit + } + + // 查询用户指定时段的流量使用情况 + now := time.Now() + for _, rule := range trafficLimitRules { + var startTime, endTime time.Time + + if rule.StatType == "hour" { + // 按小时统计:根据 StatValue 计算时间范围(往前推 N 小时) + if rule.StatValue <= 0 { + continue + } + // 从当前时间往前推 StatValue 小时 + startTime = now.Add(-time.Duration(rule.StatValue) * time.Hour) + endTime = now + } else if rule.StatType == "day" { + // 按天统计:根据 StatValue 计算时间范围(往前推 N 天) + if rule.StatValue <= 0 { + continue + } + // 从当前时间往前推 StatValue 天 + startTime = now.AddDate(0, 0, -int(rule.StatValue)) + endTime = now + } else { + continue + } + + // 查询该时段的流量使用 + var usedTraffic struct { + Upload int64 + Download int64 + } + err := l.svcCtx.DB.WithContext(l.ctx.Request.Context()). + Table("traffic_log"). + Select("COALESCE(SUM(upload), 0) as upload, COALESCE(SUM(download), 0) as download"). + Where("user_id = ? AND subscribe_id = ? AND timestamp >= ? AND timestamp < ?", + userSub.UserId, userSub.Id, startTime, endTime). + Scan(&usedTraffic).Error + + if err != nil { + l.Errorw("[calculateEffectiveSpeedLimit] Failed to query traffic usage", + logger.Field("error", err.Error()), + logger.Field("user_id", userSub.UserId), + logger.Field("subscribe_id", userSub.Id)) + continue + } + + // 计算已使用流量(GB) + usedGB := float64(usedTraffic.Upload+usedTraffic.Download) / (1024 * 1024 * 1024) + + // 如果已使用流量达到或超过阈值,应用限速 + if usedGB >= float64(rule.TrafficUsage) { + // 如果规则限速大于0,应用该限速 + if rule.SpeedLimit > 0 { + // 如果基础限速为0(无限速)或规则限速更严格,使用规则限速 + if baseSpeedLimit == 0 || rule.SpeedLimit < baseSpeedLimit { + return rule.SpeedLimit + } + } + } + } + + return baseSpeedLimit +} diff --git a/internal/logic/subscribe/subscribeLogic.go b/internal/logic/subscribe/subscribeLogic.go index 823f377..7b88160 100644 --- a/internal/logic/subscribe/subscribeLogic.go +++ b/internal/logic/subscribe/subscribeLogic.go @@ -8,6 +8,7 @@ import ( "github.com/perfect-panel/server/adapter" "github.com/perfect-panel/server/internal/model/client" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/report" @@ -206,6 +207,19 @@ func (l *SubscribeLogic) logSubscribeActivity(subscribeStatus bool, userSub *use func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, error) { if l.isSubscriptionExpired(userSub) { + // 尝试获取过期节点组的节点 + expiredNodes, err := l.getExpiredGroupNodes(userSub) + if err != nil { + l.Errorw("[Generate Subscribe]get expired group nodes error", logger.Field("error", err.Error())) + return l.createExpiredServers(), nil + } + // 如果有符合条件的过期节点组节点,返回它们 + if len(expiredNodes) > 0 { + l.Debugf("[Generate Subscribe]user %d can use expired node group, nodes count: %d", userSub.UserId, len(expiredNodes)) + return expiredNodes, nil + } + // 否则返回假的过期节点 + l.Debugf("[Generate Subscribe]user %d cannot use expired node group, return fake expired nodes", userSub.UserId) return l.createExpiredServers(), nil } @@ -422,3 +436,52 @@ func (l *SubscribeLogic) isGroupEnabled() bool { } return value == "true" || value == "1" } + +// getExpiredGroupNodes 获取过期节点组的节点 +func (l *SubscribeLogic) getExpiredGroupNodes(userSub *user.Subscribe) ([]*node.Node, error) { + // 1. 查询过期节点组 + var expiredGroup group.NodeGroup + err := l.svc.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error + if err != nil { + l.Debugw("[SubscribeLogic]no expired node group configured", logger.Field("error", err.Error())) + return nil, err + } + + // 2. 检查用户是否在过期天数限制内 + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + l.Debugf("[SubscribeLogic]user %d subscription expired %d days, exceeds limit %d days", userSub.UserId, expiredDays, expiredGroup.ExpiredDaysLimit) + return nil, nil + } + + // 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量) + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + l.Debugf("[SubscribeLogic]user %d expired traffic %d GB, exceeds expired group limit %d GB", userSub.UserId, usedTrafficGB, *expiredGroup.MaxTrafficGBExpired) + return nil, nil + } + } + + // 4. 查询过期节点组的节点 + enable := true + _, nodes, err := l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{expiredGroup.Id}, + Enabled: &enable, + Preload: true, + }) + if err != nil { + l.Errorw("[SubscribeLogic]failed to query expired group nodes", logger.Field("error", err.Error())) + return nil, err + } + + if len(nodes) == 0 { + l.Debug("[SubscribeLogic]no nodes found in expired group") + return nil, nil + } + + l.Infof("[SubscribeLogic]returned %d nodes from expired group for user %d (expired %d days)", len(nodes), userSub.UserId, expiredDays) + return nodes, nil +} diff --git a/internal/model/group/node_group.go b/internal/model/group/node_group.go index 644580a..a2fe3ee 100644 --- a/internal/model/group/node_group.go +++ b/internal/model/group/node_group.go @@ -8,15 +8,19 @@ import ( // NodeGroup 节点组模型 type NodeGroup struct { - Id int64 `gorm:"primaryKey"` - Name string `gorm:"type:varchar(255);not null;comment:Name"` - Description string `gorm:"type:varchar(500);comment:Description"` - Sort int `gorm:"default:0;index:idx_sort;comment:Sort Order"` - ForCalculation *bool `gorm:"default:true;not null;comment:For Calculation: whether this node group participates in grouping calculation"` - MinTrafficGB *int64 `gorm:"default:0;comment:Minimum Traffic (GB) for this node group"` - MaxTrafficGB *int64 `gorm:"default:0;comment:Maximum Traffic (GB) for this node group"` - CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primaryKey"` + Name string `gorm:"type:varchar(255);not null;comment:Name"` + Description string `gorm:"type:varchar(500);comment:Description"` + Sort int `gorm:"default:0;index:idx_sort;comment:Sort Order"` + ForCalculation *bool `gorm:"default:true;not null;comment:For Calculation: whether this node group participates in grouping calculation"` + IsExpiredGroup *bool `gorm:"default:false;not null;index:idx_is_expired_group;comment:Is Expired Group"` + ExpiredDaysLimit int `gorm:"default:7;not null;comment:Expired days limit (days)"` + MaxTrafficGBExpired *int64 `gorm:"default:0;comment:Max traffic for expired users (GB)"` + SpeedLimit int `gorm:"default:0;not null;comment:Speed limit (KB/s)"` + MinTrafficGB *int64 `gorm:"default:0;comment:Minimum Traffic (GB) for this node group"` + MaxTrafficGB *int64 `gorm:"default:0;comment:Maximum Traffic (GB) for this node group"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } // TableName 指定表名 diff --git a/internal/model/subscribe/subscribe.go b/internal/model/subscribe/subscribe.go index af58598..d689d16 100644 --- a/internal/model/subscribe/subscribe.go +++ b/internal/model/subscribe/subscribe.go @@ -71,6 +71,7 @@ type Subscribe struct { NodeTags string `gorm:"type:varchar(255);comment:Node Tags"` NodeGroupIds JSONInt64Slice `gorm:"type:json;comment:Node Group IDs (JSON array, multiple groups)"` NodeGroupId int64 `gorm:"default:0;index:idx_node_group_id;comment:Default Node Group ID (single ID)"` + TrafficLimit string `gorm:"type:text;comment:Traffic Limit Rules"` Show *bool `gorm:"type:tinyint(1);not null;default:0;comment:Show portal page"` Sell *bool `gorm:"type:tinyint(1);not null;default:0;comment:Sell"` Sort int64 `gorm:"type:int;not null;default:0;comment:Sort"` diff --git a/internal/model/user/model.go b/internal/model/user/model.go index 0cf3502..7de4c0c 100644 --- a/internal/model/user/model.go +++ b/internal/model/user/model.go @@ -82,7 +82,7 @@ type customUserLogicModel interface { FindOneSubscribeDetailsById(ctx context.Context, id int64) (*SubscribeDetails, error) FindOneUserSubscribe(ctx context.Context, id int64) (*SubscribeDetails, error) FindUsersSubscribeBySubscribeId(ctx context.Context, subscribeId int64) ([]*Subscribe, error) - UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, tx ...*gorm.DB) error + UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, isExpired bool, tx ...*gorm.DB) error QueryResisterUserTotalByDate(ctx context.Context, date time.Time) (int64, error) QueryResisterUserTotalByMonthly(ctx context.Context, date time.Time) (int64, error) QueryResisterUserTotal(ctx context.Context) (int64, error) @@ -181,7 +181,7 @@ func (m *customUserModel) BatchDeleteUser(ctx context.Context, ids []int64, tx . }, m.batchGetCacheKeys(users...)...) } -func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, tx ...*gorm.DB) error { +func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, isExpired bool, tx ...*gorm.DB) error { sub, err := m.FindOneSubscribe(ctx, id) if err != nil { return err @@ -198,10 +198,21 @@ func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id if len(tx) > 0 { conn = tx[0] } - return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ - "download": gorm.Expr("download + ?", download), - "upload": gorm.Expr("upload + ?", upload), - }).Error + + // 根据订阅状态更新对应的流量字段 + if isExpired { + // 过期期间,更新过期流量字段 + return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ + "expired_download": gorm.Expr("expired_download + ?", download), + "expired_upload": gorm.Expr("expired_upload + ?", upload), + }).Error + } else { + // 正常期间,更新正常流量字段 + return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ + "download": gorm.Expr("download + ?", download), + "upload": gorm.Expr("upload + ?", upload), + }).Error + } }) } diff --git a/internal/model/user/user.go b/internal/model/user/user.go index 3976468..cbc659d 100644 --- a/internal/model/user/user.go +++ b/internal/model/user/user.go @@ -85,25 +85,27 @@ func (*User) TableName() string { } type Subscribe struct { - Id int64 `gorm:"primaryKey"` - UserId int64 `gorm:"index:idx_user_id;not null;comment:User ID"` - User User `gorm:"foreignKey:UserId;references:Id"` - OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` - SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` - NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` - GroupLocked *bool `gorm:"type:tinyint(1);not null;default:0;comment:Group Locked"` - StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` - ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` - FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` - Traffic int64 `gorm:"default:0;comment:Traffic"` - Download int64 `gorm:"default:0;comment:Download Traffic"` - Upload int64 `gorm:"default:0;comment:Upload Traffic"` - Token string `gorm:"index:idx_token;unique;type:varchar(255);default:'';comment:Token"` - UUID string `gorm:"type:varchar(255);unique;index:idx_uuid;default:'';comment:UUID"` - Status uint8 `gorm:"type:tinyint(1);default:0;comment:Subscription Status: 0: Pending 1: Active 2: Finished 3: Expired 4: Deducted 5: stopped"` - Note string `gorm:"type:varchar(500);default:'';comment:User note for subscription"` - CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primaryKey"` + UserId int64 `gorm:"index:idx_user_id;not null;comment:User ID"` + User User `gorm:"foreignKey:UserId;references:Id"` + OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` + SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` + NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` + GroupLocked *bool `gorm:"type:tinyint(1);not null;default:0;comment:Group Locked"` + StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` + ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` + FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` + Traffic int64 `gorm:"default:0;comment:Traffic"` + Download int64 `gorm:"default:0;comment:Download Traffic"` + Upload int64 `gorm:"default:0;comment:Upload Traffic"` + ExpiredDownload int64 `gorm:"default:0;comment:Expired period download traffic (bytes)"` + ExpiredUpload int64 `gorm:"default:0;comment:Expired period upload traffic (bytes)"` + Token string `gorm:"index:idx_token;unique;type:varchar(255);default:'';comment:Token"` + UUID string `gorm:"type:varchar(255);unique;index:idx_uuid;default:'';comment:UUID"` + Status uint8 `gorm:"type:tinyint(1);default:0;comment:Subscription Status: 0: Pending 1: Active 2: Finished 3: Expired 4: Deducted 5: stopped"` + Note string `gorm:"type:varchar(500);default:'';comment:User note for subscription"` + CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } func (*Subscribe) TableName() string { diff --git a/internal/types/types.go b/internal/types/types.go index f1ee1fc..88fa181 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -321,12 +321,16 @@ type CreateDocumentRequest struct { } type CreateNodeGroupRequest struct { - Name string `json:"name" validate:"required"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation *bool `json:"for_calculation"` - MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` } type CreateNodeRequest struct { @@ -432,6 +436,7 @@ type CreateSubscribeRequest struct { NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -502,6 +507,13 @@ type CurrencyConfig struct { CurrencySymbol string `json:"currency_symbol"` } +type DailyTrafficStats struct { + Date string `json:"date"` + Upload int64 `json:"upload"` + Download int64 `json:"download"` + Total int64 `json:"total"` +} + type DeleteAdsRequest struct { Id int64 `json:"id"` } @@ -1277,6 +1289,18 @@ type GetUserTicketListResponse struct { List []Ticket `json:"list"` } +type GetUserTrafficStatsRequest struct { + UserSubscribeId string `form:"user_subscribe_id" validate:"required"` + Days int `form:"days" validate:"required,oneof=7 30"` +} + +type GetUserTrafficStatsResponse struct { + List []DailyTrafficStats `json:"list"` + TotalUpload int64 `json:"total_upload"` + TotalDownload int64 `json:"total_download"` + TotalTraffic int64 `json:"total_traffic"` +} + type GiftLog struct { Type uint16 `json:"type"` UserId int64 `json:"user_id"` @@ -1425,16 +1449,20 @@ type NodeDNS struct { } type NodeGroup struct { - Id int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation bool `json:"for_calculation"` - MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` - NodeCount int64 `json:"node_count,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + IsExpiredGroup bool `json:"is_expired_group"` + ExpiredDaysLimit int `json:"expired_days_limit"` + MaxTrafficGBExpired int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit int `json:"speed_limit"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type NodeGroupItem struct { @@ -2299,6 +2327,7 @@ type Subscribe struct { NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -2492,6 +2521,13 @@ type TosConfig struct { TosContent string `json:"tos_content"` } +type TrafficLimit struct { + StatType string `json:"stat_type"` + StatValue int64 `json:"stat_value"` + TrafficUsage int64 `json:"traffic_usage"` + SpeedLimit int64 `json:"speed_limit"` +} + type TrafficLog struct { Id int64 `json:"id"` ServerId int64 `json:"server_id"` @@ -2629,13 +2665,17 @@ type UpdateGroupConfigRequest struct { } type UpdateNodeGroupRequest struct { - Id int64 `json:"id" validate:"required"` - Name string `json:"name"` - Description string `json:"description"` - Sort int `json:"sort"` - ForCalculation *bool `json:"for_calculation"` - MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` - MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + Id int64 `json:"id" validate:"required"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` } type UpdateNodeRequest struct { @@ -2727,6 +2767,7 @@ type UpdateSubscribeRequest struct { NodeTags []string `json:"node_tags"` NodeGroupIds []int64 `json:"node_group_ids,omitempty"` NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` @@ -2907,6 +2948,7 @@ type UserStatisticsResponse struct { type UserSubscribe struct { Id int64 `json:"id"` + IdStr string `json:"id_str"` UserId int64 `json:"user_id"` OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` diff --git a/queue/logic/order/activateOrderLogic.go b/queue/logic/order/activateOrderLogic.go index 24a03e2..8cae8fa 100644 --- a/queue/logic/order/activateOrderLogic.go +++ b/queue/logic/order/activateOrderLogic.go @@ -351,18 +351,20 @@ func (l *ActivateOrderLogic) getSubscribeInfo(ctx context.Context, subscribeId i func (l *ActivateOrderLogic) createUserSubscription(ctx context.Context, orderInfo *order.Order, sub *subscribe.Subscribe) (*user.Subscribe, error) { now := time.Now() userSub := &user.Subscribe{ - UserId: orderInfo.UserId, - OrderId: orderInfo.Id, - SubscribeId: orderInfo.SubscribeId, - StartTime: now, - ExpireTime: tool.AddTime(sub.UnitTime, orderInfo.Quantity, now), - Traffic: sub.Traffic, - Download: 0, - Upload: 0, - Token: uuidx.SubscribeToken(orderInfo.OrderNo), - UUID: uuid.New().String(), - Status: 1, - NodeGroupId: sub.NodeGroupId, // Inherit node_group_id from subscription plan + UserId: orderInfo.UserId, + OrderId: orderInfo.Id, + SubscribeId: orderInfo.SubscribeId, + StartTime: now, + ExpireTime: tool.AddTime(sub.UnitTime, orderInfo.Quantity, now), + Traffic: sub.Traffic, + Download: 0, + Upload: 0, + ExpiredDownload: 0, + ExpiredUpload: 0, + Token: uuidx.SubscribeToken(orderInfo.OrderNo), + UUID: uuid.New().String(), + Status: 1, + NodeGroupId: sub.NodeGroupId, // Inherit node_group_id from subscription plan } // Check quota limit before creating subscription (final safeguard) @@ -650,6 +652,9 @@ func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, u userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, userSub.ExpireTime) userSub.Status = 1 + // 续费时重置过期流量字段 + userSub.ExpiredDownload = 0 + userSub.ExpiredUpload = 0 if err := l.svc.UserModel.UpdateSubscribe(ctx, userSub); err != nil { logger.WithContext(ctx).Error("Update user subscribe failed", logger.Field("error", err.Error())) @@ -674,6 +679,8 @@ func (l *ActivateOrderLogic) ResetTraffic(ctx context.Context, orderInfo *order. // Reset traffic userSub.Download = 0 userSub.Upload = 0 + userSub.ExpiredDownload = 0 + userSub.ExpiredUpload = 0 userSub.Status = 1 if err := l.svc.UserModel.UpdateSubscribe(ctx, userSub); err != nil { diff --git a/queue/logic/traffic/trafficStatisticsLogic.go b/queue/logic/traffic/trafficStatisticsLogic.go index 37614cb..a89df57 100644 --- a/queue/logic/traffic/trafficStatisticsLogic.go +++ b/queue/logic/traffic/trafficStatisticsLogic.go @@ -98,11 +98,13 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta // update user subscribe with log d := int64(float32(log.Download) * ratio * realTimeMultiplier) u := int64(float32(log.Upload) * ratio * realTimeMultiplier) - if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, sub.Id, d, u); err != nil { + isExpired := now.After(sub.ExpireTime) + if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, sub.Id, d, u, isExpired); err != nil { logger.WithContext(ctx).Error("[TrafficStatistics] Update user subscribe with log failed", logger.Field("sid", log.SID), logger.Field("download", float32(log.Download)*ratio), logger.Field("upload", float32(log.Upload)*ratio), + logger.Field("is_expired", isExpired), logger.Field("error", err.Error()), ) continue From bc721b0ba6b765d77f054a24839a93e983c5a2ee Mon Sep 17 00:00:00 2001 From: EUForest Date: Wed, 18 Mar 2026 12:45:09 +0800 Subject: [PATCH 18/18] update: Adding interference to CAPTCHA --- pkg/captcha/local.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/pkg/captcha/local.go b/pkg/captcha/local.go index ba6b917..86b9af6 100644 --- a/pkg/captcha/local.go +++ b/pkg/captcha/local.go @@ -3,6 +3,7 @@ package captcha import ( "context" "fmt" + "strings" "time" "github.com/mojocn/base64Captcha" @@ -15,8 +16,18 @@ type localService struct { } func newLocalService(redisClient *redis.Client) Service { - // Configure captcha driver - driver := base64Captcha.NewDriverDigit(80, 240, 5, 0.7, 80) + // Configure captcha driver - alphanumeric with visual effects (letters + numbers) + driver := base64Captcha.NewDriverString( + 80, // height + 240, // width + 20, // noise count (more interference) + base64Captcha.OptionShowSlimeLine|base64Captcha.OptionShowSineLine, // show curved lines + 5, // length (5 characters) + "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789", // source (exclude confusing chars) + nil, // bg color (use default) + nil, // fonts (use default) + nil, // fonts storage (use default) + ) return &localService{ redis: redisClient, driver: driver, @@ -61,8 +72,8 @@ func (s *localService) Verify(ctx context.Context, id string, code string, ip st // Delete captcha after verification (one-time use) s.redis.Del(ctx, key) - // Verify code - return answer == code, nil + // Verify code (case-insensitive) + return strings.EqualFold(answer, code), nil } func (s *localService) GetType() CaptchaType { @@ -94,5 +105,5 @@ func (r *redisStore) Get(id string, clear bool) string { func (r *redisStore) Verify(id, answer string, clear bool) bool { v := r.Get(id, clear) - return v == answer + return strings.EqualFold(v, answer) }