diff --git a/apis/admin/group.api b/apis/admin/group.api new file mode 100644 index 0000000..de4aad9 --- /dev/null +++ b/apis/admin/group.api @@ -0,0 +1,215 @@ +syntax = "v1" + +info ( + title: "Group API" + desc: "API for user group and node group management" + author: "Tension" + email: "tension@ppanel.com" + version: "0.0.1" +) + +import ( + "../types.api" + "./server.api" +) + +type ( + // ===== 节点组管理 ===== + // GetNodeGroupListRequest + GetNodeGroupListRequest { + Page int `form:"page"` + Size int `form:"size"` + GroupId string `form:"group_id,omitempty"` + } + // GetNodeGroupListResponse + GetNodeGroupListResponse { + Total int64 `json:"total"` + List []NodeGroup `json:"list"` + } + // CreateNodeGroupRequest + CreateNodeGroupRequest { + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + } + // UpdateNodeGroupRequest + UpdateNodeGroupRequest { + Id int64 `json:"id" validate:"required"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` + } + // DeleteNodeGroupRequest + DeleteNodeGroupRequest { + Id int64 `json:"id" validate:"required"` + } + // ===== 分组配置管理 ===== + // GetGroupConfigRequest + GetGroupConfigRequest { + Keys []string `form:"keys,omitempty"` + } + // GetGroupConfigResponse + GetGroupConfigResponse { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` + State RecalculationState `json:"state"` + } + // UpdateGroupConfigRequest + UpdateGroupConfigRequest { + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + Config map[string]interface{} `json:"config"` + } + // RecalculationState + RecalculationState { + State string `json:"state"` + Progress int `json:"progress"` + Total int `json:"total"` + } + // ===== 分组操作 ===== + // RecalculateGroupRequest + RecalculateGroupRequest { + Mode string `json:"mode" validate:"required"` + TriggerType string `json:"trigger_type"` // "manual" or "scheduled" + } + // GetGroupHistoryRequest + GetGroupHistoryRequest { + Page int `form:"page"` + Size int `form:"size"` + GroupMode string `form:"group_mode,omitempty"` + TriggerType string `form:"trigger_type,omitempty"` + } + // GetGroupHistoryResponse + GetGroupHistoryResponse { + Total int64 `json:"total"` + List []GroupHistory `json:"list"` + } + // GetGroupHistoryDetailRequest + GetGroupHistoryDetailRequest { + Id int64 `form:"id" validate:"required"` + } + // GetGroupHistoryDetailResponse + GetGroupHistoryDetailResponse { + GroupHistoryDetail + } + // PreviewUserNodesRequest + PreviewUserNodesRequest { + UserId int64 `form:"user_id" validate:"required"` + } + // PreviewUserNodesResponse + PreviewUserNodesResponse { + UserId int64 `json:"user_id"` + NodeGroups []NodeGroupItem `json:"node_groups"` + } + // NodeGroupItem + NodeGroupItem { + Id int64 `json:"id"` + Name string `json:"name"` + Nodes []Node `json:"nodes"` + } + // ExportGroupResultRequest + ExportGroupResultRequest { + HistoryId *int64 `form:"history_id,omitempty"` + } + // ===== Reset Groups ===== + // ResetGroupsRequest + ResetGroupsRequest { + Confirm bool `json:"confirm" validate:"required"` + } + // ===== 套餐分组映射 ===== + // SubscribeGroupMappingItem + SubscribeGroupMappingItem { + SubscribeName string `json:"subscribe_name"` + NodeGroupName string `json:"node_group_name"` + } + // GetSubscribeGroupMappingRequest + GetSubscribeGroupMappingRequest {} + // GetSubscribeGroupMappingResponse + GetSubscribeGroupMappingResponse { + List []SubscribeGroupMappingItem `json:"list"` + } +) + +@server ( + prefix: v1/admin/group + group: admin/group + jwt: JwtAuth + middleware: AuthMiddleware +) +service ppanel { + // ===== 节点组管理 ===== + @doc "Get node group list" + @handler GetNodeGroupList + get /node/list (GetNodeGroupListRequest) returns (GetNodeGroupListResponse) + + @doc "Create node group" + @handler CreateNodeGroup + post /node (CreateNodeGroupRequest) + + @doc "Update node group" + @handler UpdateNodeGroup + put /node (UpdateNodeGroupRequest) + + @doc "Delete node group" + @handler DeleteNodeGroup + delete /node (DeleteNodeGroupRequest) + + // ===== 分组配置管理 ===== + @doc "Get group config" + @handler GetGroupConfig + get /config (GetGroupConfigRequest) returns (GetGroupConfigResponse) + + @doc "Update group config" + @handler UpdateGroupConfig + put /config (UpdateGroupConfigRequest) + + // ===== 分组操作 ===== + @doc "Recalculate group" + @handler RecalculateGroup + post /recalculate (RecalculateGroupRequest) + + @doc "Get recalculation status" + @handler GetRecalculationStatus + get /recalculation/status returns (RecalculationState) + + @doc "Get group history" + @handler GetGroupHistory + get /history (GetGroupHistoryRequest) returns (GetGroupHistoryResponse) + + @doc "Export group result" + @handler ExportGroupResult + get /export (ExportGroupResultRequest) + + // Routes with query parameters + @doc "Get group history detail" + @handler GetGroupHistoryDetail + get /history/detail (GetGroupHistoryDetailRequest) returns (GetGroupHistoryDetailResponse) + + @doc "Preview user nodes" + @handler PreviewUserNodes + get /preview (PreviewUserNodesRequest) returns (PreviewUserNodesResponse) + + @doc "Reset all groups" + @handler ResetGroups + post /reset (ResetGroupsRequest) + + @doc "Get subscribe group mapping" + @handler GetSubscribeGroupMapping + get /subscribe/mapping (GetSubscribeGroupMappingRequest) returns (GetSubscribeGroupMappingResponse) +} + diff --git a/apis/admin/server.api b/apis/admin/server.api index 87d1f7e..5f5479d 100644 --- a/apis/admin/server.api +++ b/apis/admin/server.api @@ -80,36 +80,40 @@ type ( Protocols []Protocol `json:"protocols"` } Node { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` - Sort int `json:"sort,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + Sort int `json:"sort,omitempty"` + NodeGroupId int64 `json:"node_group_id,omitempty"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } CreateNodeRequest { - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } UpdateNodeRequest { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } ToggleNodeStatusRequest { Id int64 `json:"id"` @@ -119,9 +123,10 @@ type ( Id int64 `json:"id"` } FilterNodeListRequest { - Page int `form:"page"` - Size int `form:"size"` - Search string `form:"search,omitempty"` + Page int `form:"page"` + Size int `form:"size"` + Search string `form:"search,omitempty"` + NodeGroupId *int64 `form:"node_group_id,omitempty"` } FilterNodeListResponse { Total int64 `json:"total"` diff --git a/apis/admin/subscribe.api b/apis/admin/subscribe.api index a832b3a..8a662b8 100644 --- a/apis/admin/subscribe.api +++ b/apis/admin/subscribe.api @@ -48,6 +48,9 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -55,6 +58,7 @@ type ( ResetCycle int64 `json:"reset_cycle"` RenewalReset *bool `json:"renewal_reset"` ShowOriginalPrice bool `json:"show_original_price"` + AutoCreateGroup bool `json:"auto_create_group"` } UpdateSubscribeRequest { Id int64 `json:"id" validate:"required"` @@ -72,6 +76,9 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` @@ -85,10 +92,11 @@ type ( Sort []SortItem `json:"sort"` } GetSubscribeListRequest { - Page int64 `form:"page" validate:"required"` - Size int64 `form:"size" validate:"required"` - Language string `form:"language,omitempty"` - Search string `form:"search,omitempty"` + Page int64 `form:"page" validate:"required"` + Size int64 `form:"size" validate:"required"` + Language string `form:"language,omitempty"` + Search string `form:"search,omitempty"` + NodeGroupId int64 `form:"node_group_id,omitempty"` } SubscribeItem { Subscribe diff --git a/apis/admin/user.api b/apis/admin/user.api index b2822c6..e75149f 100644 --- a/apis/admin/user.api +++ b/apis/admin/user.api @@ -83,6 +83,8 @@ type ( OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` ResetTime int64 `json:"reset_time"` diff --git a/apis/auth/auth.api b/apis/auth/auth.api index 84f0f02..fa5d8dd 100644 --- a/apis/auth/auth.api +++ b/apis/auth/auth.api @@ -11,13 +11,15 @@ info ( type ( // User login request UserLoginRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // Check user is exist request CheckUserRequest { @@ -29,26 +31,30 @@ type ( } // User login response UserRegisterRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Invite string `json:"invite,optional"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Invite string `json:"invite,optional"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } - // User login response + // User reset password request ResetPasswordRequest { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // Email login request EmailLoginRequest { @@ -86,6 +92,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // Check user is exist request TelephoneCheckUserRequest { @@ -108,6 +116,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } // User login response TelephoneResetPasswordRequest { @@ -120,6 +130,8 @@ type ( UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } AppleLoginCallbackRequest { Code string `form:"code"` @@ -137,6 +149,11 @@ type ( CfToken string `json:"cf_token,optional"` ShortCode string `json:"short_code,optional"` } + GenerateCaptchaResponse { + Id string `json:"id"` + Image string `json:"image"` + Type string `json:"type"` + } ) @server ( @@ -181,11 +198,34 @@ service ppanel { @handler TelephoneResetPassword post /reset/telephone (TelephoneResetPasswordRequest) returns (LoginResponse) + @doc "Generate captcha" + @handler GenerateCaptcha + post /captcha/generate returns (GenerateCaptchaResponse) + @doc "Device Login" @handler DeviceLogin post /login/device (DeviceLoginRequest) returns (LoginResponse) } +@server ( + prefix: v1/auth/admin + group: auth/admin + middleware: DeviceMiddleware +) +service ppanel { + @doc "Admin login" + @handler AdminLogin + post /login (UserLoginRequest) returns (LoginResponse) + + @doc "Admin reset password" + @handler AdminResetPassword + post /reset (ResetPasswordRequest) returns (LoginResponse) + + @doc "Generate captcha" + @handler AdminGenerateCaptcha + post /captcha/generate returns (GenerateCaptchaResponse) +} + @server ( prefix: v1/auth/oauth group: auth/oauth diff --git a/apis/common.api b/apis/common.api index d264b43..3ece8ac 100644 --- a/apis/common.api +++ b/apis/common.api @@ -12,10 +12,12 @@ import "./types.api" type ( VeifyConfig { - TurnstileSiteKey string `json:"turnstile_site_key"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` } GetGlobalConfigResponse { Site SiteConfig `json:"site"` diff --git a/apis/types.api b/apis/types.api index 53e4b3f..8249711 100644 --- a/apis/types.api +++ b/apis/types.api @@ -170,11 +170,13 @@ type ( DeviceLimit int64 `json:"device_limit"` } VerifyConfig { - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecret string `json:"turnstile_secret"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` // local or turnstile + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` // User login captcha + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` // User register captcha + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` // Admin login captcha + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` // User reset password captcha } NodeConfig { NodeSecret string `json:"node_secret"` @@ -226,6 +228,12 @@ type ( Quantity int64 `json:"quantity"` Discount float64 `json:"discount"` } + TrafficLimit { + StatType string `json:"stat_type"` + StatValue int64 `json:"stat_value"` + TrafficUsage int64 `json:"traffic_usage"` + SpeedLimit int64 `json:"speed_limit"` + } Subscribe { Id int64 `json:"id"` Name string `json:"name"` @@ -243,6 +251,9 @@ type ( Quota int64 `json:"quota"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -951,5 +962,42 @@ type ( ResetUserSubscribeTokenRequest { UserSubscribeId int64 `json:"user_subscribe_id"` } + // ===== 分组功能类型定义 ===== + // NodeGroup 节点组 + NodeGroup { + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + IsExpiredGroup bool `json:"is_expired_group"` + ExpiredDaysLimit int `json:"expired_days_limit"` + MaxTrafficGBExpired int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit int `json:"speed_limit"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + } + // GroupHistory 分组历史记录 + GroupHistory { + Id int64 `json:"id"` + GroupMode string `json:"group_mode"` + TriggerType string `json:"trigger_type"` + TotalUsers int `json:"total_users"` + SuccessCount int `json:"success_count"` + FailedCount int `json:"failed_count"` + StartTime *int64 `json:"start_time,omitempty"` + EndTime *int64 `json:"end_time,omitempty"` + Operator string `json:"operator,omitempty"` + ErrorLog string `json:"error_log,omitempty"` + CreatedAt int64 `json:"created_at"` + } + // GroupHistoryDetail 分组历史详情 + GroupHistoryDetail { + GroupHistory + ConfigSnapshot map[string]interface{} `json:"config_snapshot,omitempty"` + } ) diff --git a/go.mod b/go.mod index 88d8ca4..222d463 100644 --- a/go.mod +++ b/go.mod @@ -94,6 +94,7 @@ require ( github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/glog v1.2.0 // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -118,6 +119,7 @@ require ( github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/mojocn/base64Captcha v1.3.8 // indirect github.com/openzipkin/zipkin-go v0.4.2 // indirect github.com/oschwald/maxminddb-golang v1.13.0 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect @@ -140,6 +142,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.13.0 // indirect golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d // indirect + golang.org/x/image v0.23.0 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index bec641c..ad1c08e 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -274,6 +276,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mojocn/base64Captcha v1.3.8 h1:rrN9BhCwXKS8ht1e21kvR3iTaMgf4qPC9sRoV52bqEg= +github.com/mojocn/base64Captcha v1.3.8/go.mod h1:QFZy927L8HVP3+VV5z2b1EAEiv1KxVJKZbAucVgLUy4= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -405,12 +409,17 @@ golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d h1:N0hmiNbwsSNwHBAvR3QB5w25pUwH4tK0Y/RltD1j1h4= golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68= +golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -419,6 +428,9 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -434,7 +446,10 @@ golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -448,6 +463,10 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -466,14 +485,21 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -481,7 +507,10 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -499,6 +528,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/initialize/migrate/database/02131_add_groups.down.sql b/initialize/migrate/database/02131_add_groups.down.sql new file mode 100644 index 0000000..6765acc --- /dev/null +++ b/initialize/migrate/database/02131_add_groups.down.sql @@ -0,0 +1,28 @@ +-- Purpose: Rollback node group management tables +-- Author: Tension +-- Date: 2025-02-23 +-- Updated: 2025-03-06 + +-- ===== Remove system configuration entries ===== +DELETE FROM `system` WHERE `category` = 'group' AND `key` IN ('enabled', 'mode', 'auto_create_group'); + +-- ===== Remove columns and indexes from subscribe table ===== +ALTER TABLE `subscribe` DROP INDEX IF EXISTS `idx_node_group_id`; +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `node_group_id`; +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `node_group_ids`; + +-- ===== Remove columns and indexes from user_subscribe table ===== +ALTER TABLE `user_subscribe` DROP INDEX IF EXISTS `idx_node_group_id`; +ALTER TABLE `user_subscribe` DROP COLUMN IF EXISTS `node_group_id`; + +-- ===== Remove columns and indexes from nodes table ===== +ALTER TABLE `nodes` DROP COLUMN IF EXISTS `node_group_ids`; + +-- ===== Drop group_history_detail table ===== +DROP TABLE IF EXISTS `group_history_detail`; + +-- ===== Drop group_history table ===== +DROP TABLE IF EXISTS `group_history`; + +-- ===== Drop node_group table ===== +DROP TABLE IF EXISTS `node_group`; diff --git a/initialize/migrate/database/02131_add_groups.up.sql b/initialize/migrate/database/02131_add_groups.up.sql new file mode 100644 index 0000000..a4150ce --- /dev/null +++ b/initialize/migrate/database/02131_add_groups.up.sql @@ -0,0 +1,130 @@ +-- Purpose: Add node group management tables with multi-group support +-- Author: Tension +-- Date: 2025-02-23 +-- Updated: 2025-03-06 + +-- ===== Create node_group table ===== +DROP TABLE IF EXISTS `node_group`; +CREATE TABLE IF NOT EXISTS `node_group` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Name', + `description` varchar(500) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT 'Group Description', + `sort` int NOT NULL DEFAULT '0' COMMENT 'Sort Order', + `for_calculation` tinyint(1) NOT NULL DEFAULT 1 COMMENT 'For Grouping Calculation: 0=false, 1=true', + `min_traffic_gb` bigint DEFAULT 0 COMMENT 'Minimum Traffic (GB) for this node group', + `max_traffic_gb` bigint DEFAULT 0 COMMENT 'Maximum Traffic (GB) for this node group', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + `updated_at` datetime(3) DEFAULT NULL COMMENT 'Update Time', + PRIMARY KEY (`id`), + KEY `idx_sort` (`sort`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Node Groups'; + +-- ===== Create group_history table ===== +DROP TABLE IF EXISTS `group_history`; +CREATE TABLE IF NOT EXISTS `group_history` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `group_mode` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Group Mode: average/subscribe/traffic', + `trigger_type` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'Trigger Type: manual/auto/schedule', + `state` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT 'State: pending/running/completed/failed', + `total_users` int NOT NULL DEFAULT '0' COMMENT 'Total Users', + `success_count` int NOT NULL DEFAULT '0' COMMENT 'Success Count', + `failed_count` int NOT NULL DEFAULT '0' COMMENT 'Failed Count', + `start_time` datetime(3) DEFAULT NULL COMMENT 'Start Time', + `end_time` datetime(3) DEFAULT NULL COMMENT 'End Time', + `operator` varchar(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT 'Operator', + `error_message` text CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci COMMENT 'Error Message', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + PRIMARY KEY (`id`), + KEY `idx_group_mode` (`group_mode`), + KEY `idx_trigger_type` (`trigger_type`), + KEY `idx_state` (`state`), + KEY `idx_created_at` (`created_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Group Calculation History'; + +-- ===== Create group_history_detail table ===== +-- Note: user_group_id column removed, using user_data JSON field instead +DROP TABLE IF EXISTS `group_history_detail`; +CREATE TABLE IF NOT EXISTS `group_history_detail` ( + `id` bigint NOT NULL AUTO_INCREMENT, + `history_id` bigint NOT NULL COMMENT 'History ID', + `node_group_id` bigint NOT NULL COMMENT 'Node Group ID', + `user_count` int NOT NULL DEFAULT '0' COMMENT 'User Count', + `node_count` int NOT NULL DEFAULT '0' COMMENT 'Node Count', + `user_data` TEXT COMMENT 'User data JSON (id and email/phone)', + `created_at` datetime(3) DEFAULT NULL COMMENT 'Create Time', + PRIMARY KEY (`id`), + KEY `idx_history_id` (`history_id`), + KEY `idx_node_group_id` (`node_group_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='Group History Details'; + +-- ===== Add columns to nodes table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'nodes' AND COLUMN_NAME = 'node_group_ids'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `nodes` ADD COLUMN `node_group_ids` JSON COMMENT ''Node Group IDs (JSON array, multiple groups)''', + 'SELECT ''Column node_group_ids already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add node_group_id column to user_subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND COLUMN_NAME = 'node_group_id'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `user_subscribe` ADD COLUMN `node_group_id` bigint NOT NULL DEFAULT 0 COMMENT ''Node Group ID (single ID)''', + 'SELECT ''Column node_group_id already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add index for user_subscribe.node_group_id ===== +SET @index_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND INDEX_NAME = 'idx_node_group_id'); +SET @sql = IF(@index_exists = 0, + 'ALTER TABLE `user_subscribe` ADD INDEX `idx_node_group_id` (`node_group_id`)', + 'SELECT ''Index idx_node_group_id already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add group_locked column to user_subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'user_subscribe' AND COLUMN_NAME = 'group_locked'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `user_subscribe` ADD COLUMN `group_locked` tinyint(1) NOT NULL DEFAULT 0 COMMENT ''Group Locked''', + 'SELECT ''Column group_locked already exists in user_subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add columns to subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND COLUMN_NAME = 'node_group_ids'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `node_group_ids` JSON COMMENT ''Node Group IDs (JSON array, multiple groups)''', + 'SELECT ''Column node_group_ids already exists'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add default node_group_id column to subscribe table ===== +SET @column_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND COLUMN_NAME = 'node_group_id'); +SET @sql = IF(@column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `node_group_id` bigint NOT NULL DEFAULT 0 COMMENT ''Default Node Group ID (single ID)''', + 'SELECT ''Column node_group_id already exists in subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Add index for subscribe.node_group_id ===== +SET @index_exists = (SELECT COUNT(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'subscribe' AND INDEX_NAME = 'idx_node_group_id'); +SET @sql = IF(@index_exists = 0, + 'ALTER TABLE `subscribe` ADD INDEX `idx_node_group_id` (`node_group_id`)', + 'SELECT ''Index idx_node_group_id already exists in subscribe table'''); +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; + +-- ===== Insert system configuration entries ===== +INSERT INTO `system` (`category`, `key`, `value`, `desc`) VALUES + ('group', 'enabled', 'false', 'Group Management Enabled'), + ('group', 'mode', 'average', 'Group Mode: average/subscribe/traffic'), + ('group', 'auto_create_group', 'false', 'Auto-create user group when creating subscribe product') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); diff --git a/initialize/migrate/database/02132_update_verify_config.down.sql b/initialize/migrate/database/02132_update_verify_config.down.sql new file mode 100644 index 0000000..2c66df3 --- /dev/null +++ b/initialize/migrate/database/02132_update_verify_config.down.sql @@ -0,0 +1,17 @@ +-- Rollback: restore old verify configuration fields +INSERT INTO `system` (`category`, `key`, `value`, `type`, `desc`) VALUES + ('verify', 'EnableLoginVerify', 'false', 'bool', 'is enable login verify'), + ('verify', 'EnableRegisterVerify', 'false', 'bool', 'is enable register verify'), + ('verify', 'EnableResetPasswordVerify', 'false', 'bool', 'is enable reset password verify') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); + +-- Remove new captcha configuration fields +DELETE FROM `system` WHERE `category` = 'verify' AND `key` IN ( + 'CaptchaType', + 'EnableUserLoginCaptcha', + 'EnableUserRegisterCaptcha', + 'EnableAdminLoginCaptcha', + 'EnableUserResetPasswordCaptcha' +); diff --git a/initialize/migrate/database/02132_update_verify_config.up.sql b/initialize/migrate/database/02132_update_verify_config.up.sql new file mode 100644 index 0000000..b6d5137 --- /dev/null +++ b/initialize/migrate/database/02132_update_verify_config.up.sql @@ -0,0 +1,17 @@ +-- Add new captcha configuration fields +INSERT INTO `system` (`category`, `key`, `value`, `type`, `desc`) VALUES + ('verify', 'CaptchaType', 'local', 'string', 'Captcha type: local or turnstile'), + ('verify', 'EnableUserLoginCaptcha', 'false', 'bool', 'Enable captcha for user login'), + ('verify', 'EnableUserRegisterCaptcha', 'false', 'bool', 'Enable captcha for user registration'), + ('verify', 'EnableAdminLoginCaptcha', 'false', 'bool', 'Enable captcha for admin login'), + ('verify', 'EnableUserResetPasswordCaptcha', 'false', 'bool', 'Enable captcha for user reset password') +ON DUPLICATE KEY UPDATE + `value` = VALUES(`value`), + `desc` = VALUES(`desc`); + +-- Remove old verify configuration fields +DELETE FROM `system` WHERE `category` = 'verify' AND `key` IN ( + 'EnableLoginVerify', + 'EnableRegisterVerify', + 'EnableResetPasswordVerify' +); diff --git a/initialize/migrate/database/02133_add_expired_node_group.down.sql b/initialize/migrate/database/02133_add_expired_node_group.down.sql new file mode 100644 index 0000000..aeacd0e --- /dev/null +++ b/initialize/migrate/database/02133_add_expired_node_group.down.sql @@ -0,0 +1,12 @@ +-- 回滚 user_subscribe 表的过期流量字段 +ALTER TABLE `user_subscribe` +DROP COLUMN `expired_upload`, +DROP COLUMN `expired_download`; + +-- 回滚 node_group 表的过期节点组字段 +ALTER TABLE `node_group` +DROP INDEX `idx_is_expired_group`, +DROP COLUMN `speed_limit`, +DROP COLUMN `max_traffic_gb_expired`, +DROP COLUMN `expired_days_limit`, +DROP COLUMN `is_expired_group`; diff --git a/initialize/migrate/database/02133_add_expired_node_group.up.sql b/initialize/migrate/database/02133_add_expired_node_group.up.sql new file mode 100644 index 0000000..283f9a5 --- /dev/null +++ b/initialize/migrate/database/02133_add_expired_node_group.up.sql @@ -0,0 +1,14 @@ +-- 为 node_group 表添加过期节点组相关字段 +ALTER TABLE `node_group` +ADD COLUMN `is_expired_group` tinyint(1) NOT NULL DEFAULT 0 COMMENT 'Is Expired Group: 0=normal, 1=expired group' AFTER `for_calculation`, +ADD COLUMN `expired_days_limit` int NOT NULL DEFAULT 7 COMMENT 'Expired days limit (days)' AFTER `is_expired_group`, +ADD COLUMN `max_traffic_gb_expired` bigint DEFAULT 0 COMMENT 'Max traffic for expired users (GB)' AFTER `expired_days_limit`, +ADD COLUMN `speed_limit` int NOT NULL DEFAULT 0 COMMENT 'Speed limit (KB/s)' AFTER `max_traffic_gb_expired`; + +-- 添加索引 +ALTER TABLE `node_group` ADD INDEX `idx_is_expired_group` (`is_expired_group`); + +-- 为 user_subscribe 表添加过期流量统计字段 +ALTER TABLE `user_subscribe` +ADD COLUMN `expired_download` bigint NOT NULL DEFAULT 0 COMMENT 'Expired period download traffic (bytes)' AFTER `upload`, +ADD COLUMN `expired_upload` bigint NOT NULL DEFAULT 0 COMMENT 'Expired period upload traffic (bytes)' AFTER `expired_download`; diff --git a/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql b/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql new file mode 100644 index 0000000..4fdf319 --- /dev/null +++ b/initialize/migrate/database/02134_subscribe_traffic_limit.down.sql @@ -0,0 +1,6 @@ +-- Purpose: Rollback traffic_limit rules from subscribe +-- Author: Claude Code +-- Date: 2026-03-12 + +-- ===== Remove traffic_limit column from subscribe table ===== +ALTER TABLE `subscribe` DROP COLUMN IF EXISTS `traffic_limit`; diff --git a/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql b/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql new file mode 100644 index 0000000..18a9f30 --- /dev/null +++ b/initialize/migrate/database/02134_subscribe_traffic_limit.up.sql @@ -0,0 +1,22 @@ +-- Purpose: Add traffic_limit rules to subscribe +-- Author: Claude Code +-- Date: 2026-03-12 + +-- ===== Add traffic_limit column to subscribe table ===== +SET @column_exists = ( + SELECT COUNT(*) + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() + AND TABLE_NAME = 'subscribe' + AND COLUMN_NAME = 'traffic_limit' +); + +SET @sql = IF( + @column_exists = 0, + 'ALTER TABLE `subscribe` ADD COLUMN `traffic_limit` TEXT NULL COMMENT ''Traffic Limit Rules (JSON)'' AFTER `node_group_id`', + 'SELECT ''Column traffic_limit already exists in subscribe table''' +); + +PREPARE stmt FROM @sql; +EXECUTE stmt; +DEALLOCATE PREPARE stmt; diff --git a/internal/handler/admin/group/createNodeGroupHandler.go b/internal/handler/admin/group/createNodeGroupHandler.go new file mode 100644 index 0000000..eaba8cf --- /dev/null +++ b/internal/handler/admin/group/createNodeGroupHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Create node group +func CreateNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.CreateNodeGroupRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewCreateNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.CreateNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/deleteNodeGroupHandler.go b/internal/handler/admin/group/deleteNodeGroupHandler.go new file mode 100644 index 0000000..b93120d --- /dev/null +++ b/internal/handler/admin/group/deleteNodeGroupHandler.go @@ -0,0 +1,29 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Delete node group +func DeleteNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.DeleteNodeGroupRequest + if err := c.ShouldBind(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewDeleteNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.DeleteNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/exportGroupResultHandler.go b/internal/handler/admin/group/exportGroupResultHandler.go new file mode 100644 index 0000000..69b065f --- /dev/null +++ b/internal/handler/admin/group/exportGroupResultHandler.go @@ -0,0 +1,36 @@ +package group + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Export group result +func ExportGroupResultHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.ExportGroupResultRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewExportGroupResultLogic(c.Request.Context(), svcCtx) + data, filename, err := l.ExportGroupResult(&req) + if err != nil { + result.HttpResult(c, nil, err) + return + } + + // 设置响应头 + c.Header("Content-Type", "text/csv") + c.Header("Content-Disposition", "attachment; filename="+filename) + c.Data(http.StatusOK, "text/csv", data) + } +} diff --git a/internal/handler/admin/group/getGroupConfigHandler.go b/internal/handler/admin/group/getGroupConfigHandler.go new file mode 100644 index 0000000..ef24311 --- /dev/null +++ b/internal/handler/admin/group/getGroupConfigHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group config +func GetGroupConfigHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupConfigRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupConfigLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupConfig(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getGroupHistoryDetailHandler.go b/internal/handler/admin/group/getGroupHistoryDetailHandler.go new file mode 100644 index 0000000..fa58f3c --- /dev/null +++ b/internal/handler/admin/group/getGroupHistoryDetailHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group history detail +func GetGroupHistoryDetailHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupHistoryDetailRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupHistoryDetailLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupHistoryDetail(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getGroupHistoryHandler.go b/internal/handler/admin/group/getGroupHistoryHandler.go new file mode 100644 index 0000000..b6b5490 --- /dev/null +++ b/internal/handler/admin/group/getGroupHistoryHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get group history +func GetGroupHistoryHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetGroupHistoryRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetGroupHistoryLogic(c.Request.Context(), svcCtx) + resp, err := l.GetGroupHistory(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getNodeGroupListHandler.go b/internal/handler/admin/group/getNodeGroupListHandler.go new file mode 100644 index 0000000..501138f --- /dev/null +++ b/internal/handler/admin/group/getNodeGroupListHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get node group list +func GetNodeGroupListHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetNodeGroupListRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetNodeGroupListLogic(c.Request.Context(), svcCtx) + resp, err := l.GetNodeGroupList(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getRecalculationStatusHandler.go b/internal/handler/admin/group/getRecalculationStatusHandler.go new file mode 100644 index 0000000..e9b76b8 --- /dev/null +++ b/internal/handler/admin/group/getRecalculationStatusHandler.go @@ -0,0 +1,18 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Get recalculation status +func GetRecalculationStatusHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := group.NewGetRecalculationStatusLogic(c.Request.Context(), svcCtx) + resp, err := l.GetRecalculationStatus() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/getSubscribeGroupMappingHandler.go b/internal/handler/admin/group/getSubscribeGroupMappingHandler.go new file mode 100644 index 0000000..4da798c --- /dev/null +++ b/internal/handler/admin/group/getSubscribeGroupMappingHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get subscribe group mapping +func GetSubscribeGroupMappingHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetSubscribeGroupMappingRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewGetSubscribeGroupMappingLogic(c.Request.Context(), svcCtx) + resp, err := l.GetSubscribeGroupMapping(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/previewUserNodesHandler.go b/internal/handler/admin/group/previewUserNodesHandler.go new file mode 100644 index 0000000..da3560d --- /dev/null +++ b/internal/handler/admin/group/previewUserNodesHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Preview user nodes +func PreviewUserNodesHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.PreviewUserNodesRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewPreviewUserNodesLogic(c.Request.Context(), svcCtx) + resp, err := l.PreviewUserNodes(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/admin/group/recalculateGroupHandler.go b/internal/handler/admin/group/recalculateGroupHandler.go new file mode 100644 index 0000000..848363d --- /dev/null +++ b/internal/handler/admin/group/recalculateGroupHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Recalculate group +func RecalculateGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.RecalculateGroupRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewRecalculateGroupLogic(c.Request.Context(), svcCtx) + err := l.RecalculateGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/resetGroupsHandler.go b/internal/handler/admin/group/resetGroupsHandler.go new file mode 100644 index 0000000..e0af912 --- /dev/null +++ b/internal/handler/admin/group/resetGroupsHandler.go @@ -0,0 +1,17 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Reset all groups +func ResetGroupsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + l := group.NewResetGroupsLogic(c.Request.Context(), svcCtx) + err := l.ResetGroups() + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/updateGroupConfigHandler.go b/internal/handler/admin/group/updateGroupConfigHandler.go new file mode 100644 index 0000000..6f2ea1c --- /dev/null +++ b/internal/handler/admin/group/updateGroupConfigHandler.go @@ -0,0 +1,26 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Update group config +func UpdateGroupConfigHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UpdateGroupConfigRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewUpdateGroupConfigLogic(c.Request.Context(), svcCtx) + err := l.UpdateGroupConfig(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/admin/group/updateNodeGroupHandler.go b/internal/handler/admin/group/updateNodeGroupHandler.go new file mode 100644 index 0000000..e9f9058 --- /dev/null +++ b/internal/handler/admin/group/updateNodeGroupHandler.go @@ -0,0 +1,33 @@ +package group + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Update node group +func UpdateNodeGroupHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UpdateNodeGroupRequest + if err := c.ShouldBindUri(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + if err := c.ShouldBind(&req); err != nil { + result.ParamErrorResult(c, err) + return + } + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := group.NewUpdateNodeGroupLogic(c.Request.Context(), svcCtx) + err := l.UpdateNodeGroup(&req) + result.HttpResult(c, nil, err) + } +} diff --git a/internal/handler/auth/admin/adminGenerateCaptchaHandler.go b/internal/handler/auth/admin/adminGenerateCaptchaHandler.go new file mode 100644 index 0000000..caabd45 --- /dev/null +++ b/internal/handler/auth/admin/adminGenerateCaptchaHandler.go @@ -0,0 +1,18 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Generate captcha +func AdminGenerateCaptchaHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := admin.NewAdminGenerateCaptchaLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminGenerateCaptcha() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/admin/adminLoginHandler.go b/internal/handler/auth/admin/adminLoginHandler.go new file mode 100644 index 0000000..95239bd --- /dev/null +++ b/internal/handler/auth/admin/adminLoginHandler.go @@ -0,0 +1,30 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + adminLogic "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Admin login +func AdminLoginHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.UserLoginRequest + _ = c.ShouldBind(&req) + // get client ip + req.IP = c.ClientIP() + req.UserAgent = c.Request.UserAgent() + + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := adminLogic.NewAdminLoginLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminLogin(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/admin/adminResetPasswordHandler.go b/internal/handler/auth/admin/adminResetPasswordHandler.go new file mode 100644 index 0000000..9fb909c --- /dev/null +++ b/internal/handler/auth/admin/adminResetPasswordHandler.go @@ -0,0 +1,29 @@ +package admin + +import ( + "github.com/gin-gonic/gin" + adminLogic "github.com/perfect-panel/server/internal/logic/auth/admin" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Admin reset password +func AdminResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.ResetPasswordRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + // get client ip + req.IP = c.ClientIP() + req.UserAgent = c.Request.UserAgent() + + l := adminLogic.NewAdminResetPasswordLogic(c.Request.Context(), svcCtx) + resp, err := l.AdminResetPassword(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/generateCaptchaHandler.go b/internal/handler/auth/generateCaptchaHandler.go new file mode 100644 index 0000000..d263da7 --- /dev/null +++ b/internal/handler/auth/generateCaptchaHandler.go @@ -0,0 +1,18 @@ +package auth + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/result" +) + +// Generate captcha +func GenerateCaptchaHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + + l := auth.NewGenerateCaptchaLogic(c.Request.Context(), svcCtx) + resp, err := l.GenerateCaptcha() + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/auth/resetPasswordHandler.go b/internal/handler/auth/resetPasswordHandler.go index d4edc9b..8de4ca6 100644 --- a/internal/handler/auth/resetPasswordHandler.go +++ b/internal/handler/auth/resetPasswordHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // Reset password @@ -25,17 +20,8 @@ func ResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { } // get client ip req.IP = c.ClientIP() - if svcCtx.Config.Verify.ResetPasswordVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + req.UserAgent = c.Request.UserAgent() + l := auth.NewResetPasswordLogic(c.Request.Context(), svcCtx) resp, err := l.ResetPassword(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/telephoneLoginHandler.go b/internal/handler/auth/telephoneLoginHandler.go index 13352c7..8319352 100644 --- a/internal/handler/auth/telephoneLoginHandler.go +++ b/internal/handler/auth/telephoneLoginHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User Telephone login diff --git a/internal/handler/auth/telephoneResetPasswordHandler.go b/internal/handler/auth/telephoneResetPasswordHandler.go index 16a5105..bc0f8a6 100644 --- a/internal/handler/auth/telephoneResetPasswordHandler.go +++ b/internal/handler/auth/telephoneResetPasswordHandler.go @@ -1,14 +1,13 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" + "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" ) @@ -25,17 +24,44 @@ func TelephoneResetPasswordHandler(svcCtx *svc.ServiceContext) func(c *gin.Conte } // get client ip req.IP = c.ClientIP() - if svcCtx.Config.Verify.ResetPasswordVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, + + // Get verify config from database + verifyCfg, err := svcCtx.SystemModel.GetVerifyConfig(c.Request.Context()) + if err != nil { + result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get verify config failed: %v", err)) + return + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Verify captcha if enabled + if config.EnableUserResetPasswordCaptcha { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaType(config.CaptchaType), + TurnstileSecret: config.TurnstileSecret, + RedisClient: svcCtx.Redis, }) - if verify, err := verifyTurns.Verify(c.Request.Context(), req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) + + var token, code string + if config.CaptchaType == "turnstile" { + token = req.CfToken + } else { + token = req.CaptchaId + code = req.CaptchaCode + } + + verified, err := captchaService.Verify(c.Request.Context(), token, code, req.IP) + if err != nil || !verified { + result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "captcha verification failed: %v", err)) return } } + l := auth.NewTelephoneResetPasswordLogic(c, svcCtx) resp, err := l.TelephoneResetPassword(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/telephoneUserRegisterHandler.go b/internal/handler/auth/telephoneUserRegisterHandler.go index 45a7ba8..306777d 100644 --- a/internal/handler/auth/telephoneUserRegisterHandler.go +++ b/internal/handler/auth/telephoneUserRegisterHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User Telephone register @@ -26,17 +21,7 @@ func TelephoneUserRegisterHandler(svcCtx *svc.ServiceContext) func(c *gin.Contex // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.RegisterVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + l := auth.NewTelephoneUserRegisterLogic(c.Request.Context(), svcCtx) resp, err := l.TelephoneUserRegister(&req) result.HttpResult(c, resp, err) diff --git a/internal/handler/auth/userLoginHandler.go b/internal/handler/auth/userLoginHandler.go index 20eff59..a188876 100644 --- a/internal/handler/auth/userLoginHandler.go +++ b/internal/handler/auth/userLoginHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User login @@ -21,17 +16,7 @@ func UserLoginHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.LoginVerify && !svcCtx.Config.Debug { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - err = errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "error: %v, verify: %v", err, verify) - result.HttpResult(c, nil, err) - return - } - } + validateErr := svcCtx.Validate(&req) if validateErr != nil { result.ParamErrorResult(c, validateErr) diff --git a/internal/handler/auth/userRegisterHandler.go b/internal/handler/auth/userRegisterHandler.go index ea40223..a11c743 100644 --- a/internal/handler/auth/userRegisterHandler.go +++ b/internal/handler/auth/userRegisterHandler.go @@ -1,16 +1,11 @@ package auth import ( - "time" - "github.com/gin-gonic/gin" "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/result" - "github.com/perfect-panel/server/pkg/turnstile" - "github.com/perfect-panel/server/pkg/xerr" - "github.com/pkg/errors" ) // User register @@ -21,16 +16,7 @@ func UserRegisterHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { // get client ip req.IP = c.ClientIP() req.UserAgent = c.Request.UserAgent() - if svcCtx.Config.Verify.RegisterVerify { - verifyTurns := turnstile.New(turnstile.Config{ - Secret: svcCtx.Config.Verify.TurnstileSecret, - Timeout: 3 * time.Second, - }) - if verify, err := verifyTurns.Verify(c, req.CfToken, req.IP); err != nil || !verify { - result.HttpResult(c, nil, errors.Wrapf(xerr.NewErrCode(xerr.TooManyRequests), "verify error: %v", err.Error())) - return - } - } + validateErr := svcCtx.Validate(&req) if validateErr != nil { result.ParamErrorResult(c, validateErr) diff --git a/internal/handler/public/user/getUserTrafficStatsHandler.go b/internal/handler/public/user/getUserTrafficStatsHandler.go new file mode 100644 index 0000000..5c3c7dd --- /dev/null +++ b/internal/handler/public/user/getUserTrafficStatsHandler.go @@ -0,0 +1,26 @@ +package user + +import ( + "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/logic/public/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/result" +) + +// Get User Traffic Statistics +func GetUserTrafficStatsHandler(svcCtx *svc.ServiceContext) func(c *gin.Context) { + return func(c *gin.Context) { + var req types.GetUserTrafficStatsRequest + _ = c.ShouldBind(&req) + validateErr := svcCtx.Validate(&req) + if validateErr != nil { + result.ParamErrorResult(c, validateErr) + return + } + + l := user.NewGetUserTrafficStatsLogic(c.Request.Context(), svcCtx) + resp, err := l.GetUserTrafficStats(&req) + result.HttpResult(c, resp, err) + } +} diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 23bed10..9168e0f 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -12,6 +12,7 @@ import ( adminConsole "github.com/perfect-panel/server/internal/handler/admin/console" adminCoupon "github.com/perfect-panel/server/internal/handler/admin/coupon" adminDocument "github.com/perfect-panel/server/internal/handler/admin/document" + adminGroup "github.com/perfect-panel/server/internal/handler/admin/group" adminLog "github.com/perfect-panel/server/internal/handler/admin/log" adminMarketing "github.com/perfect-panel/server/internal/handler/admin/marketing" adminOrder "github.com/perfect-panel/server/internal/handler/admin/order" @@ -24,6 +25,7 @@ import ( adminTool "github.com/perfect-panel/server/internal/handler/admin/tool" adminUser "github.com/perfect-panel/server/internal/handler/admin/user" auth "github.com/perfect-panel/server/internal/handler/auth" + authAdmin "github.com/perfect-panel/server/internal/handler/auth/admin" authOauth "github.com/perfect-panel/server/internal/handler/auth/oauth" common "github.com/perfect-panel/server/internal/handler/common" publicAnnouncement "github.com/perfect-panel/server/internal/handler/public/announcement" @@ -189,6 +191,53 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { adminDocumentGroupRouter.GET("/list", adminDocument.GetDocumentListHandler(serverCtx)) } + adminGroupGroupRouter := router.Group("/v1/admin/group") + adminGroupGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) + + { + // Get group config + adminGroupGroupRouter.GET("/config", adminGroup.GetGroupConfigHandler(serverCtx)) + + // Update group config + adminGroupGroupRouter.PUT("/config", adminGroup.UpdateGroupConfigHandler(serverCtx)) + + // Export group result + adminGroupGroupRouter.GET("/export", adminGroup.ExportGroupResultHandler(serverCtx)) + + // Get group history + adminGroupGroupRouter.GET("/history", adminGroup.GetGroupHistoryHandler(serverCtx)) + + // Get group history detail + adminGroupGroupRouter.GET("/history/detail", adminGroup.GetGroupHistoryDetailHandler(serverCtx)) + + // Create node group + adminGroupGroupRouter.POST("/node", adminGroup.CreateNodeGroupHandler(serverCtx)) + + // Update node group + adminGroupGroupRouter.PUT("/node", adminGroup.UpdateNodeGroupHandler(serverCtx)) + + // Delete node group + adminGroupGroupRouter.DELETE("/node", adminGroup.DeleteNodeGroupHandler(serverCtx)) + + // Get node group list + adminGroupGroupRouter.GET("/node/list", adminGroup.GetNodeGroupListHandler(serverCtx)) + + // Preview user nodes + adminGroupGroupRouter.GET("/preview", adminGroup.PreviewUserNodesHandler(serverCtx)) + + // Recalculate group + adminGroupGroupRouter.POST("/recalculate", adminGroup.RecalculateGroupHandler(serverCtx)) + + // Get recalculation status + adminGroupGroupRouter.GET("/recalculation/status", adminGroup.GetRecalculationStatusHandler(serverCtx)) + + // Reset all groups + adminGroupGroupRouter.POST("/reset", adminGroup.ResetGroupsHandler(serverCtx)) + + // Get subscribe group mapping + adminGroupGroupRouter.GET("/subscribe/mapping", adminGroup.GetSubscribeGroupMappingHandler(serverCtx)) + } + adminLogGroupRouter := router.Group("/v1/admin/log") adminLogGroupRouter.Use(middleware.AuthMiddleware(serverCtx)) @@ -659,6 +708,9 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { authGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) { + // Generate captcha + authGroupRouter.POST("/captcha/generate", auth.GenerateCaptchaHandler(serverCtx)) + // Check user is exist authGroupRouter.GET("/check", auth.CheckUserHandler(serverCtx)) @@ -690,6 +742,20 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { authGroupRouter.POST("/reset/telephone", auth.TelephoneResetPasswordHandler(serverCtx)) } + authAdminGroupRouter := router.Group("/v1/auth/admin") + authAdminGroupRouter.Use(middleware.DeviceMiddleware(serverCtx)) + + { + // Generate captcha + authAdminGroupRouter.POST("/captcha/generate", authAdmin.AdminGenerateCaptchaHandler(serverCtx)) + + // Admin login + authAdminGroupRouter.POST("/login", authAdmin.AdminLoginHandler(serverCtx)) + + // Admin reset password + authAdminGroupRouter.POST("/reset", authAdmin.AdminResetPasswordHandler(serverCtx)) + } + authOauthGroupRouter := router.Group("/v1/auth/oauth") { @@ -980,6 +1046,9 @@ func RegisterHandlers(router *gin.Engine, serverCtx *svc.ServiceContext) { // Reset User Subscribe Token publicUserGroupRouter.PUT("/subscribe_token", publicUser.ResetUserSubscribeTokenHandler(serverCtx)) + // Get User Traffic Statistics + publicUserGroupRouter.GET("/traffic_stats", publicUser.GetUserTrafficStatsHandler(serverCtx)) + // Unbind Device publicUserGroupRouter.PUT("/unbind_device", publicUser.UnbindDeviceHandler(serverCtx)) diff --git a/internal/logic/admin/group/createNodeGroupLogic.go b/internal/logic/admin/group/createNodeGroupLogic.go new file mode 100644 index 0000000..9e68c10 --- /dev/null +++ b/internal/logic/admin/group/createNodeGroupLogic.go @@ -0,0 +1,81 @@ +package group + +import ( + "context" + "errors" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type CreateNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewCreateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateNodeGroupLogic { + return &CreateNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *CreateNodeGroupLogic) CreateNodeGroup(req *types.CreateNodeGroupRequest) error { + // 验证:系统中只能有一个过期节点组 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + var count int64 + err := l.svcCtx.DB.Model(&group.NodeGroup{}). + Where("is_expired_group = ?", true). + Count(&count).Error + if err != nil { + logger.Errorf("failed to check expired group count: %v", err) + return err + } + if count > 0 { + return errors.New("system already has an expired node group, cannot create multiple") + } + } + + // 创建节点组 + nodeGroup := &group.NodeGroup{ + Name: req.Name, + Description: req.Description, + Sort: req.Sort, + ForCalculation: req.ForCalculation, + IsExpiredGroup: req.IsExpiredGroup, + MaxTrafficGBExpired: req.MaxTrafficGBExpired, + MinTrafficGB: req.MinTrafficGB, + MaxTrafficGB: req.MaxTrafficGB, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + // 设置过期节点组的默认值 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + // 过期节点组不参与分组计算 + falseValue := false + nodeGroup.ForCalculation = &falseValue + + if req.ExpiredDaysLimit != nil { + nodeGroup.ExpiredDaysLimit = *req.ExpiredDaysLimit + } else { + nodeGroup.ExpiredDaysLimit = 7 // 默认7天 + } + if req.SpeedLimit != nil { + nodeGroup.SpeedLimit = *req.SpeedLimit + } + } + + if err := l.svcCtx.DB.Create(nodeGroup).Error; err != nil { + logger.Errorf("failed to create node group: %v", err) + return err + } + + logger.Infof("created node group: node_group_id=%d", nodeGroup.Id) + return nil +} diff --git a/internal/logic/admin/group/deleteNodeGroupLogic.go b/internal/logic/admin/group/deleteNodeGroupLogic.go new file mode 100644 index 0000000..16c89d4 --- /dev/null +++ b/internal/logic/admin/group/deleteNodeGroupLogic.go @@ -0,0 +1,62 @@ +package group + +import ( + "context" + "errors" + "fmt" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type DeleteNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewDeleteNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *DeleteNodeGroupLogic { + return &DeleteNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *DeleteNodeGroupLogic) DeleteNodeGroup(req *types.DeleteNodeGroupRequest) error { + // 查询节点组信息 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("node group not found") + } + logger.Errorf("failed to find node group: %v", err) + return err + } + + // 检查是否有关联节点(使用JSON_CONTAINS查询node_group_ids数组) + var nodeCount int64 + if err := l.svcCtx.DB.Model(&node.Node{}).Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroup.Id)).Count(&nodeCount).Error; err != nil { + logger.Errorf("failed to count nodes in group: %v", err) + return err + } + if nodeCount > 0 { + return fmt.Errorf("cannot delete group with %d associated nodes, please migrate nodes first", nodeCount) + } + + // 使用 GORM Transaction 删除节点组 + return l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 删除节点组 + if err := tx.Where("id = ?", req.Id).Delete(&group.NodeGroup{}).Error; err != nil { + logger.Errorf("failed to delete node group: %v", err) + return err // 自动回滚 + } + + logger.Infof("deleted node group: id=%d", nodeGroup.Id) + return nil // 自动提交 + }) +} diff --git a/internal/logic/admin/group/exportGroupResultLogic.go b/internal/logic/admin/group/exportGroupResultLogic.go new file mode 100644 index 0000000..a84befa --- /dev/null +++ b/internal/logic/admin/group/exportGroupResultLogic.go @@ -0,0 +1,129 @@ +package group + +import ( + "bytes" + "context" + "encoding/csv" + "fmt" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type ExportGroupResultLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewExportGroupResultLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ExportGroupResultLogic { + return &ExportGroupResultLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +// ExportGroupResult 导出分组结果为 CSV +// 返回:CSV 数据(字节切片)、文件名、错误 +func (l *ExportGroupResultLogic) ExportGroupResult(req *types.ExportGroupResultRequest) ([]byte, string, error) { + var records [][]string + + // CSV 表头 + records = append(records, []string{"用户ID", "节点组ID", "节点组名称"}) + + if req.HistoryId != nil { + // 导出指定历史的详细结果 + // 1. 查询分组历史详情 + var details []group.GroupHistoryDetail + if err := l.svcCtx.DB.Where("history_id = ?", *req.HistoryId).Find(&details).Error; err != nil { + logger.Errorf("failed to get group history details: %v", err) + return nil, "", err + } + + // 2. 为每个组生成记录 + for _, detail := range details { + // 从 UserData JSON 解析用户信息 + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + var users []UserInfo + if err := l.svcCtx.DB.Raw("SELECT * FROM JSON_ARRAY(?)", detail.UserData).Scan(&users).Error; err != nil { + // 如果解析失败,尝试用标准 JSON 解析 + logger.Errorf("failed to parse user data: %v", err) + continue + } + + // 查询节点组名称 + var nodeGroup group.NodeGroup + l.svcCtx.DB.Where("id = ?", detail.NodeGroupId).First(&nodeGroup) + + // 为每个用户生成记录 + for _, user := range users { + records = append(records, []string{ + fmt.Sprintf("%d", user.Id), + fmt.Sprintf("%d", nodeGroup.Id), + nodeGroup.Name, + }) + } + } + } else { + // 导出当前所有用户的分组情况 + type UserNodeGroupInfo struct { + Id int64 `json:"id"` + NodeGroupId int64 `json:"node_group_id"` + } + var userSubscribes []UserNodeGroupInfo + if err := l.svcCtx.DB.Model(&user.Subscribe{}). + Select("DISTINCT user_id as id, node_group_id"). + Where("node_group_id > ?", 0). + Find(&userSubscribes).Error; err != nil { + logger.Errorf("failed to get users: %v", err) + return nil, "", err + } + + // 为每个用户生成记录 + for _, us := range userSubscribes { + // 查询节点组信息 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", us.NodeGroupId).First(&nodeGroup).Error; err != nil { + logger.Errorf("failed to find node group: %v", err) + // 跳过该用户 + continue + } + + records = append(records, []string{ + fmt.Sprintf("%d", us.Id), + fmt.Sprintf("%d", nodeGroup.Id), + nodeGroup.Name, + }) + } + } + + // 生成 CSV 数据 + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + writer.WriteAll(records) + writer.Flush() + + if err := writer.Error(); err != nil { + logger.Errorf("failed to write csv: %v", err) + return nil, "", err + } + + // 添加 UTF-8 BOM + bom := []byte{0xEF, 0xBB, 0xBF} + csvData := buf.Bytes() + result := make([]byte, 0, len(bom)+len(csvData)) + result = append(result, bom...) + result = append(result, csvData...) + + // 生成文件名 + filename := fmt.Sprintf("group_result_%d.csv", req.HistoryId) + + return result, filename, nil +} diff --git a/internal/logic/admin/group/getGroupConfigLogic.go b/internal/logic/admin/group/getGroupConfigLogic.go new file mode 100644 index 0000000..2aedb10 --- /dev/null +++ b/internal/logic/admin/group/getGroupConfigLogic.go @@ -0,0 +1,125 @@ +package group + +import ( + "context" + "encoding/json" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type GetGroupConfigLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get group config +func NewGetGroupConfigLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupConfigLogic { + return &GetGroupConfigLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupConfigLogic) GetGroupConfig(req *types.GetGroupConfigRequest) (resp *types.GetGroupConfigResponse, err error) { + // 读取基础配置 + var enabledConfig system.System + var modeConfig system.System + var averageConfig system.System + var subscribeConfig system.System + var trafficConfig system.System + + // 从 system_config 表读取配置 + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "enabled").First(&enabledConfig).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("failed to get group enabled config", logger.Field("error", err.Error())) + return nil, err + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "mode").First(&modeConfig).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("failed to get group mode config", logger.Field("error", err.Error())) + return nil, err + } + + // 读取 JSON 配置 + config := make(map[string]interface{}) + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "average_config").First(&averageConfig).Error; err == nil { + var averageCfg map[string]interface{} + if err := json.Unmarshal([]byte(averageConfig.Value), &averageCfg); err == nil { + config["average_config"] = averageCfg + } + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "subscribe_config").First(&subscribeConfig).Error; err == nil { + var subscribeCfg map[string]interface{} + if err := json.Unmarshal([]byte(subscribeConfig.Value), &subscribeCfg); err == nil { + config["subscribe_config"] = subscribeCfg + } + } + + if err := l.svcCtx.DB.Where("`category` = 'group' and `key` = ?", "traffic_config").First(&trafficConfig).Error; err == nil { + var trafficCfg map[string]interface{} + if err := json.Unmarshal([]byte(trafficConfig.Value), &trafficCfg); err == nil { + config["traffic_config"] = trafficCfg + } + } + + // 解析基础配置 + enabled := enabledConfig.Value == "true" + mode := modeConfig.Value + if mode == "" { + mode = "average" // 默认模式 + } + + // 获取重算状态 + state, err := l.getRecalculationState() + if err != nil { + l.Errorw("failed to get recalculation state", logger.Field("error", err.Error())) + // 继续执行,不影响配置获取 + state = &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + } + } + + resp = &types.GetGroupConfigResponse{ + Enabled: enabled, + Mode: mode, + Config: config, + State: *state, + } + + return resp, nil +} + +// getRecalculationState 获取重算状态 +func (l *GetGroupConfigLogic) getRecalculationState() (*types.RecalculationState, error) { + var history group.GroupHistory + err := l.svcCtx.DB.Order("id desc").First(&history).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + }, nil + } + return nil, err + } + + state := &types.RecalculationState{ + State: history.State, + Progress: history.TotalUsers, + Total: history.TotalUsers, + } + + return state, nil +} diff --git a/internal/logic/admin/group/getGroupHistoryDetailLogic.go b/internal/logic/admin/group/getGroupHistoryDetailLogic.go new file mode 100644 index 0000000..d868d55 --- /dev/null +++ b/internal/logic/admin/group/getGroupHistoryDetailLogic.go @@ -0,0 +1,109 @@ +package group + +import ( + "context" + "encoding/json" + "errors" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type GetGroupHistoryDetailLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetGroupHistoryDetailLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupHistoryDetailLogic { + return &GetGroupHistoryDetailLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupHistoryDetailLogic) GetGroupHistoryDetail(req *types.GetGroupHistoryDetailRequest) (resp *types.GetGroupHistoryDetailResponse, err error) { + // 查询分组历史记录 + var history group.GroupHistory + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&history).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("group history not found") + } + logger.Errorf("failed to find group history: %v", err) + return nil, err + } + + // 查询分组历史详情 + var details []group.GroupHistoryDetail + if err := l.svcCtx.DB.Where("history_id = ?", req.Id).Find(&details).Error; err != nil { + logger.Errorf("failed to find group history details: %v", err) + return nil, err + } + + // 转换时间格式 + var startTime, endTime *int64 + if history.StartTime != nil { + t := history.StartTime.Unix() + startTime = &t + } + if history.EndTime != nil { + t := history.EndTime.Unix() + endTime = &t + } + + // 构建 GroupHistoryDetail + historyDetail := types.GroupHistoryDetail{ + GroupHistory: types.GroupHistory{ + Id: history.Id, + GroupMode: history.GroupMode, + TriggerType: history.TriggerType, + TotalUsers: history.TotalUsers, + SuccessCount: history.SuccessCount, + FailedCount: history.FailedCount, + StartTime: startTime, + EndTime: endTime, + ErrorLog: history.ErrorMessage, + CreatedAt: history.CreatedAt.Unix(), + }, + } + + // 如果有详情记录,构建 ConfigSnapshot + if len(details) > 0 { + configSnapshot := make(map[string]interface{}) + configSnapshot["group_details"] = details + + // 获取配置快照(从 system_config 读取) + var configValue string + if history.GroupMode == "average" { + l.svcCtx.DB.Table("system_config"). + Where("`key` = ?", "group.average_config"). + Select("value"). + Scan(&configValue) + } else if history.GroupMode == "traffic" { + l.svcCtx.DB.Table("system_config"). + Where("`key` = ?", "group.traffic_config"). + Select("value"). + Scan(&configValue) + } + + // 解析 JSON 配置 + if configValue != "" { + var config map[string]interface{} + if err := json.Unmarshal([]byte(configValue), &config); err == nil { + configSnapshot["config"] = config + } + } + + historyDetail.ConfigSnapshot = configSnapshot + } + + resp = &types.GetGroupHistoryDetailResponse{ + GroupHistoryDetail: historyDetail, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getGroupHistoryLogic.go b/internal/logic/admin/group/getGroupHistoryLogic.go new file mode 100644 index 0000000..6eee9c3 --- /dev/null +++ b/internal/logic/admin/group/getGroupHistoryLogic.go @@ -0,0 +1,87 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetGroupHistoryLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetGroupHistoryLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetGroupHistoryLogic { + return &GetGroupHistoryLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetGroupHistoryLogic) GetGroupHistory(req *types.GetGroupHistoryRequest) (resp *types.GetGroupHistoryResponse, err error) { + var histories []group.GroupHistory + var total int64 + + // 构建查询 + query := l.svcCtx.DB.Model(&group.GroupHistory{}) + + // 添加过滤条件 + if req.GroupMode != "" { + query = query.Where("group_mode = ?", req.GroupMode) + } + if req.TriggerType != "" { + query = query.Where("trigger_type = ?", req.TriggerType) + } + + // 获取总数 + if err := query.Count(&total).Error; err != nil { + logger.Errorf("failed to count group histories: %v", err) + return nil, err + } + + // 分页查询 + offset := (req.Page - 1) * req.Size + if err := query.Order("id DESC").Offset(offset).Limit(req.Size).Find(&histories).Error; err != nil { + logger.Errorf("failed to find group histories: %v", err) + return nil, err + } + + // 转换为响应格式 + var list []types.GroupHistory + for _, h := range histories { + var startTime, endTime *int64 + if h.StartTime != nil { + t := h.StartTime.Unix() + startTime = &t + } + if h.EndTime != nil { + t := h.EndTime.Unix() + endTime = &t + } + + list = append(list, types.GroupHistory{ + Id: h.Id, + GroupMode: h.GroupMode, + TriggerType: h.TriggerType, + TotalUsers: h.TotalUsers, + SuccessCount: h.SuccessCount, + FailedCount: h.FailedCount, + StartTime: startTime, + EndTime: endTime, + ErrorLog: h.ErrorMessage, + CreatedAt: h.CreatedAt.Unix(), + }) + } + + resp = &types.GetGroupHistoryResponse{ + Total: total, + List: list, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getNodeGroupListLogic.go b/internal/logic/admin/group/getNodeGroupListLogic.go new file mode 100644 index 0000000..9595393 --- /dev/null +++ b/internal/logic/admin/group/getNodeGroupListLogic.go @@ -0,0 +1,103 @@ +package group + +import ( + "context" + "fmt" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetNodeGroupListLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewGetNodeGroupListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetNodeGroupListLogic { + return &GetNodeGroupListLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetNodeGroupListLogic) GetNodeGroupList(req *types.GetNodeGroupListRequest) (resp *types.GetNodeGroupListResponse, err error) { + var nodeGroups []group.NodeGroup + var total int64 + + // 构建查询 + query := l.svcCtx.DB.Model(&group.NodeGroup{}) + + // 获取总数 + if err := query.Count(&total).Error; err != nil { + logger.Errorf("failed to count node groups: %v", err) + return nil, err + } + + // 分页查询 + offset := (req.Page - 1) * req.Size + if err := query.Order("sort ASC").Offset(offset).Limit(req.Size).Find(&nodeGroups).Error; err != nil { + logger.Errorf("failed to find node groups: %v", err) + return nil, err + } + + // 转换为响应格式 + var list []types.NodeGroup + for _, ng := range nodeGroups { + // 统计该组的节点数(JSON数组查询) + var nodeCount int64 + l.svcCtx.DB.Model(&node.Node{}).Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", ng.Id)).Count(&nodeCount) + + // 处理指针类型的字段 + var forCalculation bool + if ng.ForCalculation != nil { + forCalculation = *ng.ForCalculation + } else { + forCalculation = true // 默认值 + } + + var isExpiredGroup bool + if ng.IsExpiredGroup != nil { + isExpiredGroup = *ng.IsExpiredGroup + } + + var minTrafficGB, maxTrafficGB, maxTrafficGBExpired int64 + if ng.MinTrafficGB != nil { + minTrafficGB = *ng.MinTrafficGB + } + if ng.MaxTrafficGB != nil { + maxTrafficGB = *ng.MaxTrafficGB + } + if ng.MaxTrafficGBExpired != nil { + maxTrafficGBExpired = *ng.MaxTrafficGBExpired + } + + list = append(list, types.NodeGroup{ + Id: ng.Id, + Name: ng.Name, + Description: ng.Description, + Sort: ng.Sort, + ForCalculation: forCalculation, + IsExpiredGroup: isExpiredGroup, + ExpiredDaysLimit: ng.ExpiredDaysLimit, + MaxTrafficGBExpired: maxTrafficGBExpired, + SpeedLimit: ng.SpeedLimit, + MinTrafficGB: minTrafficGB, + MaxTrafficGB: maxTrafficGB, + NodeCount: nodeCount, + CreatedAt: ng.CreatedAt.Unix(), + UpdatedAt: ng.UpdatedAt.Unix(), + }) + } + + resp = &types.GetNodeGroupListResponse{ + Total: total, + List: list, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getRecalculationStatusLogic.go b/internal/logic/admin/group/getRecalculationStatusLogic.go new file mode 100644 index 0000000..9a04f80 --- /dev/null +++ b/internal/logic/admin/group/getRecalculationStatusLogic.go @@ -0,0 +1,57 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type GetRecalculationStatusLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get recalculation status +func NewGetRecalculationStatusLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetRecalculationStatusLogic { + return &GetRecalculationStatusLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetRecalculationStatusLogic) GetRecalculationStatus() (resp *types.RecalculationState, err error) { + // 返回最近的一条 GroupHistory 记录 + var history group.GroupHistory + err = l.svcCtx.DB.Order("id desc").First(&history).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // 如果没有历史记录,返回空闲状态 + resp = &types.RecalculationState{ + State: "idle", + Progress: 0, + Total: 0, + } + return resp, nil + } + l.Errorw("failed to get group history", logger.Field("error", err.Error())) + return nil, err + } + + // 转换为 RecalculationState 格式 + // Progress = 已处理的用户数(成功+失败),Total = 总用户数 + processedUsers := history.SuccessCount + history.FailedCount + resp = &types.RecalculationState{ + State: history.State, + Progress: processedUsers, + Total: history.TotalUsers, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/getSubscribeGroupMappingLogic.go b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go new file mode 100644 index 0000000..cd26305 --- /dev/null +++ b/internal/logic/admin/group/getSubscribeGroupMappingLogic.go @@ -0,0 +1,71 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" +) + +type GetSubscribeGroupMappingLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get subscribe group mapping +func NewGetSubscribeGroupMappingLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetSubscribeGroupMappingLogic { + return &GetSubscribeGroupMappingLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetSubscribeGroupMappingLogic) GetSubscribeGroupMapping(req *types.GetSubscribeGroupMappingRequest) (resp *types.GetSubscribeGroupMappingResponse, err error) { + // 1. 查询所有订阅套餐 + var subscribes []subscribe.Subscribe + if err := l.svcCtx.DB.Model(&subscribe.Subscribe{}).Find(&subscribes).Error; err != nil { + l.Errorw("[GetSubscribeGroupMapping] failed to query subscribes", logger.Field("error", err.Error())) + return nil, err + } + + // 2. 查询所有节点组 + var nodeGroups []group.NodeGroup + if err := l.svcCtx.DB.Model(&group.NodeGroup{}).Find(&nodeGroups).Error; err != nil { + l.Errorw("[GetSubscribeGroupMapping] failed to query node groups", logger.Field("error", err.Error())) + return nil, err + } + + // 创建 node_group_id -> node_group_name 的映射 + nodeGroupMap := make(map[int64]string) + for _, ng := range nodeGroups { + nodeGroupMap[ng.Id] = ng.Name + } + + // 3. 构建映射结果:套餐 -> 默认节点组(一对一) + var mappingList []types.SubscribeGroupMappingItem + + for _, sub := range subscribes { + // 获取套餐的默认节点组(node_group_ids 数组的第一个) + nodeGroupName := "" + if len(sub.NodeGroupIds) > 0 { + defaultNodeGroupId := sub.NodeGroupIds[0] + nodeGroupName = nodeGroupMap[defaultNodeGroupId] + } + + mappingList = append(mappingList, types.SubscribeGroupMappingItem{ + SubscribeName: sub.Name, + NodeGroupName: nodeGroupName, + }) + } + + resp = &types.GetSubscribeGroupMappingResponse{ + List: mappingList, + } + + return resp, nil +} diff --git a/internal/logic/admin/group/previewUserNodesLogic.go b/internal/logic/admin/group/previewUserNodesLogic.go new file mode 100644 index 0000000..2122da5 --- /dev/null +++ b/internal/logic/admin/group/previewUserNodesLogic.go @@ -0,0 +1,585 @@ +package group + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" +) + +type PreviewUserNodesLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewPreviewUserNodesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PreviewUserNodesLogic { + return &PreviewUserNodesLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *PreviewUserNodesLogic) PreviewUserNodes(req *types.PreviewUserNodesRequest) (resp *types.PreviewUserNodesResponse, err error) { + logger.Infof("[PreviewUserNodes] userId: %v", req.UserId) + + // 1. 查询用户的所有有效订阅(只查询可用状态:0-Pending, 1-Active) + type UserSubscribe struct { + Id int64 + UserId int64 + SubscribeId int64 + NodeGroupId int64 // 用户订阅的 node_group_id(单个ID) + } + var userSubscribes []UserSubscribe + err = l.svcCtx.DB.Model(&user.Subscribe{}). + Select("id, user_id, subscribe_id, node_group_id"). + Where("user_id = ? AND status IN ?", req.UserId, []int8{0, 1}). + Find(&userSubscribes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get user subscribes: %v", err) + return nil, err + } + + if len(userSubscribes) == 0 { + logger.Infof("[PreviewUserNodes] no user subscribes found") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + logger.Infof("[PreviewUserNodes] found %v user subscribes", len(userSubscribes)) + + // 2. 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + // 收集所有订阅ID以便批量查询 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + // 批量查询订阅信息 + type SubscribeInfo struct { + Id int64 + NodeGroupId int64 + NodeGroupIds string // JSON string + Nodes string // JSON string - 直接分配的节点ID + NodeTags string // 节点标签 + } + var subscribeInfos []SubscribeInfo + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}). + Select("id, node_group_id, node_group_ids, nodes, node_tags"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get subscribe infos: %v", err) + return nil, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 按优先级获取每个用户订阅的 node_group_id + var allNodeGroupIds []int64 + for _, us := range userSubscribes { + nodeGroupId := int64(0) + + // 优先级1: user_subscribe.node_group_id + if us.NodeGroupId != 0 { + nodeGroupId = us.NodeGroupId + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using node_group_id=%d", us.Id, nodeGroupId) + } else { + // 优先级2: subscribe.node_group_id + subInfo, ok := subInfoMap[us.SubscribeId] + if ok { + if subInfo.NodeGroupId != 0 { + nodeGroupId = subInfo.NodeGroupId + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using subscribe.node_group_id=%d", us.Id, nodeGroupId) + } else if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "null" && subInfo.NodeGroupIds != "[]" { + // 优先级3: subscribe.node_group_ids[0] + var nodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &nodeGroupIds); err == nil && len(nodeGroupIds) > 0 { + nodeGroupId = nodeGroupIds[0] + logger.Debugf("[PreviewUserNodes] user_subscribe_id=%d using subscribe.node_group_ids[0]=%d", us.Id, nodeGroupId) + } + } + } + } + + if nodeGroupId != 0 { + allNodeGroupIds = append(allNodeGroupIds, nodeGroupId) + } + } + + // 去重 + allNodeGroupIds = removeDuplicateInt64(allNodeGroupIds) + + logger.Infof("[PreviewUserNodes] collected node_group_ids with priority: %v", allNodeGroupIds) + + // 3. 收集所有订阅中直接分配的节点ID + var allDirectNodeIds []int64 + for _, subInfo := range subscribeInfos { + if subInfo.Nodes != "" && subInfo.Nodes != "null" { + // nodes 是逗号分隔的字符串,如 "1,2,3" + nodeIdStrs := strings.Split(subInfo.Nodes, ",") + for _, idStr := range nodeIdStrs { + idStr = strings.TrimSpace(idStr) + if idStr != "" { + var nodeId int64 + if _, err := fmt.Sscanf(idStr, "%d", &nodeId); err == nil { + allDirectNodeIds = append(allDirectNodeIds, nodeId) + } + } + } + logger.Debugf("[PreviewUserNodes] subscribe_id=%d has direct nodes: %s", subInfo.Id, subInfo.Nodes) + } + } + // 去重 + allDirectNodeIds = removeDuplicateInt64(allDirectNodeIds) + logger.Infof("[PreviewUserNodes] collected direct node_ids: %v", allDirectNodeIds) + + // 4. 判断分组功能是否启用 + type SystemConfig struct { + Value string + } + var config SystemConfig + l.svcCtx.DB.Model(&struct { + Category string `gorm:"column:category"` + Key string `gorm:"column:key"` + Value string `gorm:"column:value"` + }{}). + Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&config) + + logger.Infof("[PreviewUserNodes] groupEnabled: %v", config.Value) + + isGroupEnabled := config.Value == "true" || config.Value == "1" + + var filteredNodes []node.Node + + if isGroupEnabled { + // === 启用分组功能:通过用户订阅的 node_group_id 查询节点 === + logger.Infof("[PreviewUserNodes] using group-based node filtering") + + if len(allNodeGroupIds) == 0 && len(allDirectNodeIds) == 0 { + logger.Infof("[PreviewUserNodes] no node groups and no direct nodes found in user subscribes") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + // 5. 查询所有启用的节点(只有当有节点组时才查询) + if len(allNodeGroupIds) > 0 { + var dbNodes []node.Node + err = l.svcCtx.DB.Model(&node.Node{}). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err + } + + // 6. 过滤出包含至少一个匹配节点组的节点 + // node_group_ids 为空 = 公共节点,所有人可见 + // node_group_ids 与订阅的 node_group_id 匹配 = 该节点可见 + for _, n := range dbNodes { + // 公共节点(node_group_ids 为空),所有人可见 + if len(n.NodeGroupIds) == 0 { + filteredNodes = append(filteredNodes, n) + continue + } + + // 检查节点的 node_group_ids 是否与订阅的 node_group_id 有交集 + for _, nodeGroupId := range n.NodeGroupIds { + if tool.Contains(allNodeGroupIds, nodeGroupId) { + filteredNodes = append(filteredNodes, n) + break + } + } + } + + logger.Infof("[PreviewUserNodes] found %v nodes using group filter", len(filteredNodes)) + } + + } else { + // === 未启用分组功能:通过订阅的 node_tags 查询节点 === + logger.Infof("[PreviewUserNodes] using tag-based node filtering") + + // 从已查询的 subscribeInfos 中获取 node_tags + var allTags []string + for _, subInfo := range subscribeInfos { + if subInfo.NodeTags != "" { + tags := strings.Split(subInfo.NodeTags, ",") + allTags = append(allTags, tags...) + } + } + // 去重 + allTags = tool.RemoveDuplicateElements(allTags...) + // 去除空字符串 + allTags = tool.RemoveStringElement(allTags, "") + + logger.Infof("[PreviewUserNodes] merged tags from subscribes: %v", allTags) + + if len(allTags) == 0 && len(allDirectNodeIds) == 0 { + logger.Infof("[PreviewUserNodes] no tags and no direct nodes found in subscribes") + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: []types.NodeGroupItem{}, + } + return resp, nil + } + + // 8. 查询所有启用的节点(只有当有 tags 时才查询) + if len(allTags) > 0 { + var dbNodes []node.Node + err = l.svcCtx.DB.Model(&node.Node{}). + Where("enabled = ?", true). + Find(&dbNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get nodes: %v", err) + return nil, err + } + + // 9. 过滤出包含至少一个匹配标签的节点 + for _, n := range dbNodes { + if n.Tags == "" { + continue + } + nodeTags := strings.Split(n.Tags, ",") + // 检查是否有交集 + for _, tag := range nodeTags { + if tag != "" && tool.Contains(allTags, tag) { + filteredNodes = append(filteredNodes, n) + break + } + } + } + + logger.Infof("[PreviewUserNodes] found %v nodes using tag filter", len(filteredNodes)) + } + } + + // 10. 根据是否启用分组功能,选择不同的分组方式 + nodeGroupItems := make([]types.NodeGroupItem, 0) + + if isGroupEnabled { + // === 启用分组:按节点组分组 === + // 转换为 types.Node 并按节点组分组 + type NodeWithGroup struct { + Node node.Node + NodeGroupIds []int64 + } + + nodesWithGroup := make([]NodeWithGroup, 0, len(filteredNodes)) + for _, n := range filteredNodes { + nodesWithGroup = append(nodesWithGroup, NodeWithGroup{ + Node: n, + NodeGroupIds: n.NodeGroupIds, + }) + } + + // 按节点组分组节点 + type NodeGroupMap struct { + Id int64 + Nodes []types.Node + } + + // 创建节点组映射:group_id -> nodes + groupMap := make(map[int64]*NodeGroupMap) + + // 获取所有涉及的节点组ID + allGroupIds := make([]int64, 0) + for _, ng := range nodesWithGroup { + if len(ng.NodeGroupIds) > 0 { + // 如果节点属于节点组,按第一个节点组分组 + firstGroupId := ng.NodeGroupIds[0] + if _, exists := groupMap[firstGroupId]; !exists { + groupMap[firstGroupId] = &NodeGroupMap{ + Id: firstGroupId, + Nodes: []types.Node{}, + } + allGroupIds = append(allGroupIds, firstGroupId) + } + + // 转换节点 + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[firstGroupId].Nodes = append(groupMap[firstGroupId].Nodes, node) + } else { + // 没有节点组的节点,使用 group_id = 0 作为"无节点组"分组 + if _, exists := groupMap[0]; !exists { + groupMap[0] = &NodeGroupMap{ + Id: 0, + Nodes: []types.Node{}, + } + } + + tags := []string{} + if ng.Node.Tags != "" { + tags = strings.Split(ng.Node.Tags, ",") + } + node := types.Node{ + Id: ng.Node.Id, + Name: ng.Node.Name, + Tags: tags, + Port: ng.Node.Port, + Address: ng.Node.Address, + ServerId: ng.Node.ServerId, + Protocol: ng.Node.Protocol, + Enabled: ng.Node.Enabled, + Sort: ng.Node.Sort, + NodeGroupIds: []int64(ng.Node.NodeGroupIds), + CreatedAt: ng.Node.CreatedAt.Unix(), + UpdatedAt: ng.Node.UpdatedAt.Unix(), + } + + groupMap[0].Nodes = append(groupMap[0].Nodes, node) + } + } + + // 查询节点组信息并构建响应 + nodeGroupInfoMap := make(map[int64]string) + validGroupIds := make([]int64, 0) + + if len(allGroupIds) > 0 { + type NodeGroupInfo struct { + Id int64 + Name string + } + var nodeGroupInfos []NodeGroupInfo + err = l.svcCtx.DB.Model(&group.NodeGroup{}). + Select("id, name"). + Where("id IN ?", allGroupIds). + Find(&nodeGroupInfos).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get node group infos: %v", err) + return nil, err + } + + logger.Infof("[PreviewUserNodes] found %v node group infos from %v requested", len(nodeGroupInfos), len(allGroupIds)) + + // 创建节点组信息映射和有效节点组ID列表 + for _, ngInfo := range nodeGroupInfos { + nodeGroupInfoMap[ngInfo.Id] = ngInfo.Name + validGroupIds = append(validGroupIds, ngInfo.Id) + logger.Debugf("[PreviewUserNodes] node_group[%d] = %s", ngInfo.Id, ngInfo.Name) + } + + // 记录无效的节点组ID + for _, requestedId := range allGroupIds { + found := false + for _, validId := range validGroupIds { + if requestedId == validId { + found = true + break + } + } + if !found { + logger.Infof("[PreviewUserNodes] node_group_id %d not found in database, treating as public nodes", requestedId) + } + } + } + + // 构建响应:根据有效节点组ID重新分组节点 + publicNodes := make([]types.Node, 0) + + // 遍历所有分组,重新分类节点 + for groupId, gm := range groupMap { + if groupId == 0 { + // 本来就是无节点组的节点 + publicNodes = append(publicNodes, gm.Nodes...) + continue + } + + // 检查这个节点组ID是否有效 + isValid := false + for _, validId := range validGroupIds { + if groupId == validId { + isValid = true + break + } + } + + if isValid { + // 节点组有效,添加到对应的分组 + groupName := nodeGroupInfoMap[groupId] + if groupName == "" { + groupName = fmt.Sprintf("Group %d", groupId) + } + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: groupId, + Name: groupName, + Nodes: gm.Nodes, + }) + logger.Infof("[PreviewUserNodes] adding node group: id=%d, name=%s, nodes=%d", groupId, groupName, len(gm.Nodes)) + } else { + // 节点组无效,节点归入公共节点组 + logger.Infof("[PreviewUserNodes] node_group_id %d invalid, moving %d nodes to public group", groupId, len(gm.Nodes)) + publicNodes = append(publicNodes, gm.Nodes...) + } + } + + // 添加公共节点组(如果有) + if len(publicNodes) > 0 { + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: 0, + Name: "", + Nodes: publicNodes, + }) + logger.Infof("[PreviewUserNodes] adding public group: nodes=%d", len(publicNodes)) + } + + } else { + // === 未启用分组:按 tag 分组 === + // 按 tag 分组节点 + tagGroupMap := make(map[string][]types.Node) + + for _, n := range filteredNodes { + tags := []string{} + if n.Tags != "" { + tags = strings.Split(n.Tags, ",") + } + + // 转换节点 + node := types.Node{ + Id: n.Id, + Name: n.Name, + Tags: tags, + Port: n.Port, + Address: n.Address, + ServerId: n.ServerId, + Protocol: n.Protocol, + Enabled: n.Enabled, + Sort: n.Sort, + NodeGroupIds: []int64(n.NodeGroupIds), + CreatedAt: n.CreatedAt.Unix(), + UpdatedAt: n.UpdatedAt.Unix(), + } + + // 将节点添加到每个匹配的 tag 分组中 + if len(tags) > 0 { + for _, tag := range tags { + tag = strings.TrimSpace(tag) + if tag != "" { + tagGroupMap[tag] = append(tagGroupMap[tag], node) + } + } + } else { + // 没有 tag 的节点放入特殊分组 + tagGroupMap[""] = append(tagGroupMap[""], node) + } + } + + // 构建响应:按 tag 分组 + for tag, nodes := range tagGroupMap { + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: 0, // tag 分组使用 ID 0 + Name: tag, + Nodes: nodes, + }) + logger.Infof("[PreviewUserNodes] adding tag group: tag=%s, nodes=%d", tag, len(nodes)) + } + } + + // 添加套餐节点组(直接分配的节点) + if len(allDirectNodeIds) > 0 { + // 查询直接分配的节点详情 + var directNodes []node.Node + err = l.svcCtx.DB.Model(&node.Node{}). + Where("id IN ? AND enabled = ?", allDirectNodeIds, true). + Find(&directNodes).Error + if err != nil { + logger.Errorf("[PreviewUserNodes] failed to get direct nodes: %v", err) + return nil, err + } + + if len(directNodes) > 0 { + // 转换为 types.Node + directNodeItems := make([]types.Node, 0, len(directNodes)) + for _, n := range directNodes { + tags := []string{} + if n.Tags != "" { + tags = strings.Split(n.Tags, ",") + } + directNodeItems = append(directNodeItems, types.Node{ + Id: n.Id, + Name: n.Name, + Tags: tags, + Port: n.Port, + Address: n.Address, + ServerId: n.ServerId, + Protocol: n.Protocol, + Enabled: n.Enabled, + Sort: n.Sort, + NodeGroupIds: []int64(n.NodeGroupIds), + CreatedAt: n.CreatedAt.Unix(), + UpdatedAt: n.UpdatedAt.Unix(), + }) + } + + // 添加套餐节点组(使用特殊ID -1,Name 为空字符串,前端根据 ID -1 进行国际化) + nodeGroupItems = append(nodeGroupItems, types.NodeGroupItem{ + Id: -1, + Name: "", // 空字符串,前端根据 ID -1 识别并国际化 + Nodes: directNodeItems, + }) + logger.Infof("[PreviewUserNodes] adding subscription nodes group: nodes=%d", len(directNodeItems)) + } + } + + // 14. 返回结果 + resp = &types.PreviewUserNodesResponse{ + UserId: req.UserId, + NodeGroups: nodeGroupItems, + } + + logger.Infof("[PreviewUserNodes] returning %v node groups for user %v", len(resp.NodeGroups), req.UserId) + return resp, nil +} + +// removeDuplicateInt64 去重 []int64 +func removeDuplicateInt64(slice []int64) []int64 { + keys := make(map[int64]bool) + var list []int64 + for _, entry := range slice { + if !keys[entry] { + keys[entry] = true + list = append(list, entry) + } + } + return list +} diff --git a/internal/logic/admin/group/recalculateGroupLogic.go b/internal/logic/admin/group/recalculateGroupLogic.go new file mode 100644 index 0000000..cb16188 --- /dev/null +++ b/internal/logic/admin/group/recalculateGroupLogic.go @@ -0,0 +1,818 @@ +package group + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type RecalculateGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Recalculate group +func NewRecalculateGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RecalculateGroupLogic { + return &RecalculateGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *RecalculateGroupLogic) RecalculateGroup(req *types.RecalculateGroupRequest) error { + // 验证 mode 参数 + if req.Mode != "average" && req.Mode != "subscribe" && req.Mode != "traffic" { + return errors.New("invalid mode, must be one of: average, subscribe, traffic") + } + + // 创建 GroupHistory 记录(state=pending) + triggerType := req.TriggerType + if triggerType == "" { + triggerType = "manual" // 默认为手动触发 + } + + history := &group.GroupHistory{ + GroupMode: req.Mode, + TriggerType: triggerType, + TotalUsers: 0, + SuccessCount: 0, + FailedCount: 0, + } + now := time.Now() + history.StartTime = &now + + // 使用 GORM Transaction 执行分组重算 + err := l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 创建历史记录 + if err := tx.Create(history).Error; err != nil { + l.Errorw("failed to create group history", logger.Field("error", err.Error())) + return err + } + + // 更新状态为 running + if err := tx.Model(history).Update("state", "running").Error; err != nil { + l.Errorw("failed to update history state to running", logger.Field("error", err.Error())) + return err + } + + // 根据 mode 执行不同的分组算法 + var affectedCount int + var err error + + switch req.Mode { + case "average": + affectedCount, err = l.executeAverageGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute average grouping", logger.Field("error", err.Error())) + return err + } + case "subscribe": + affectedCount, err = l.executeSubscribeGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute subscribe grouping", logger.Field("error", err.Error())) + return err + } + case "traffic": + affectedCount, err = l.executeTrafficGrouping(tx, history.Id) + if err != nil { + l.Errorw("failed to execute traffic grouping", logger.Field("error", err.Error())) + return err + } + } + + // 更新 GroupHistory 记录(state=completed, 统计成功/失败数) + endTime := time.Now() + updates := map[string]interface{}{ + "state": "completed", + "total_users": affectedCount, + "success_count": affectedCount, // 暂时假设所有都成功 + "failed_count": 0, + "end_time": endTime, + } + + if err := tx.Model(history).Updates(updates).Error; err != nil { + l.Errorw("failed to update history state to completed", logger.Field("error", err.Error())) + return err + } + + l.Infof("group recalculation completed: mode=%s, affected_users=%d", req.Mode, affectedCount) + return nil + }) + + if err != nil { + // 如果失败,更新历史记录状态为 failed + updateErr := l.svcCtx.DB.Model(history).Updates(map[string]interface{}{ + "state": "failed", + "error_message": err.Error(), + "end_time": time.Now(), + }).Error + if updateErr != nil { + l.Errorw("failed to update history state to failed", logger.Field("error", updateErr.Error())) + } + return err + } + + return nil +} + +// getUserEmail 查询用户的邮箱 +func (l *RecalculateGroupLogic) getUserEmail(tx *gorm.DB, userId int64) string { + type UserAuthMethod struct { + AuthIdentifier string `json:"auth_identifier"` + } + + var authMethod UserAuthMethod + if err := tx.Model(&user.AuthMethods{}). + Select("auth_identifier"). + Where("user_id = ? AND (auth_type = ? OR auth_type = ?)", userId, "email", "6"). + First(&authMethod).Error; err != nil { + return "" + } + + return authMethod.AuthIdentifier +} + +// executeAverageGrouping 实现平均分组算法(随机分配节点组到用户订阅) +// 新逻辑:获取所有有效用户订阅,从订阅的节点组ID中随机选择一个,设置到用户订阅的 node_group_id 字段 +func (l *RecalculateGroupLogic) executeAverageGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 1. 查询所有有效且未锁定的用户订阅(status IN (0, 1)) + type UserSubscribeInfo struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + SubscribeId int64 `json:"subscribe_id"` + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Model(&user.Subscribe{}). + Select("id, user_id, subscribe_id"). + Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询未锁定且有效的用户订阅 + Scan(&userSubscribes).Error; err != nil { + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infof("average grouping: no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infof("average grouping: found %d valid and unlocked user subscribes", len(userSubscribes)) + + // 1.5 查询所有参与计算的节点组ID + var calculationNodeGroups []group.NodeGroup + if err := tx.Model(&group.NodeGroup{}). + Select("id"). + Where("for_calculation = ?", true). + Scan(&calculationNodeGroups).Error; err != nil { + l.Errorw("failed to query calculation node groups", logger.Field("error", err.Error())) + return 0, err + } + + // 创建参与计算的节点组ID集合(用于快速查找) + calculationNodeGroupIds := make(map[int64]bool) + for _, ng := range calculationNodeGroups { + calculationNodeGroupIds[ng.Id] = true + } + + l.Infof("average grouping: found %d node groups with for_calculation=true", len(calculationNodeGroupIds)) + + // 2. 批量查询订阅的节点组ID信息 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + type SubscribeInfo struct { + Id int64 `json:"id"` + NodeGroupIds string `json:"node_group_ids"` // JSON string + } + var subscribeInfos []SubscribeInfo + if err := tx.Model(&subscribe.Subscribe{}). + Select("id, node_group_ids"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error; err != nil { + l.Errorw("failed to query subscribe infos", logger.Field("error", err.Error())) + return 0, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 用于存储统计信息(按节点组ID统计用户数) + groupUsersMap := make(map[int64][]struct { + Id int64 `json:"id"` + Email string `json:"email"` + }) + nodeGroupUserCount := make(map[int64]int) // node_group_id -> user_count + nodeGroupNodeCount := make(map[int64]int) // node_group_id -> node_count + + // 3. 遍历所有用户订阅,按序平均分配节点组 + affectedCount := 0 + failedCount := 0 + + // 为每个订阅维护一个分配索引,用于按序循环分配 + subscribeAllocationIndex := make(map[int64]int) // subscribe_id -> current_index + + for _, us := range userSubscribes { + subInfo, ok := subInfoMap[us.SubscribeId] + if !ok { + l.Infow("subscribe not found", + logger.Field("user_subscribe_id", us.Id), + logger.Field("subscribe_id", us.SubscribeId)) + failedCount++ + continue + } + + // 解析订阅的节点组ID列表,并过滤出参与计算的节点组 + var nodeGroupIds []int64 + if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "[]" { + var allNodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &allNodeGroupIds); err != nil { + l.Errorw("failed to parse node_group_ids", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("node_group_ids", subInfo.NodeGroupIds), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只保留参与计算的节点组 + for _, ngId := range allNodeGroupIds { + if calculationNodeGroupIds[ngId] { + nodeGroupIds = append(nodeGroupIds, ngId) + } + } + + if len(nodeGroupIds) == 0 && len(allNodeGroupIds) > 0 { + l.Debugw("all node_group_ids are not for calculation, setting to 0", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("total_node_groups", len(allNodeGroupIds))) + } + } + + // 如果没有节点组ID,跳过 + if len(nodeGroupIds) == 0 { + l.Debugf("no valid node_group_ids for subscribe_id=%d, setting to 0", subInfo.Id) + if err := tx.Model(&user.Subscribe{}). + Where("id = ?", us.Id). + Update("node_group_id", 0).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + } + + // 按序选择节点组ID(循环轮询分配) + selectedNodeGroupId := int64(0) + if len(nodeGroupIds) > 0 { + // 获取当前订阅的分配索引 + currentIndex := subscribeAllocationIndex[us.SubscribeId] + // 选择当前索引对应的节点组 + selectedNodeGroupId = nodeGroupIds[currentIndex] + // 更新索引,循环使用(轮询) + subscribeAllocationIndex[us.SubscribeId] = (currentIndex + 1) % len(nodeGroupIds) + + l.Debugf("assigning user_subscribe_id=%d (subscribe_id=%d) to node_group_id=%d (index=%d, total_options=%d, mode=sequential)", + us.Id, us.SubscribeId, selectedNodeGroupId, currentIndex, len(nodeGroupIds)) + } + + // 更新 user_subscribe 的 node_group_id 字段(单个ID) + if err := tx.Model(&user.Subscribe{}). + Where("id = ?", us.Id). + Update("node_group_id", selectedNodeGroupId).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只统计有节点组的用户 + if selectedNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + groupUsersMap[selectedNodeGroupId] = append(groupUsersMap[selectedNodeGroupId], struct { + Id int64 `json:"id"` + Email string `json:"email"` + }{ + Id: us.UserId, + Email: email, + }) + nodeGroupUserCount[selectedNodeGroupId]++ + } + + affectedCount++ + } + + l.Infof("average grouping completed: affected=%d, failed=%d", affectedCount, failedCount) + + // 4. 创建分组历史详情记录(按节点组ID统计) + for nodeGroupId, users := range groupUsersMap { + userCount := len(users) + if userCount == 0 { + continue + } + + // 统计该节点组的节点数 + var nodeCount int64 = 0 + if nodeGroupId > 0 { + if err := tx.Model(&node.Node{}). + Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroupId)). + Count(&nodeCount).Error; err != nil { + l.Errorw("failed to count nodes", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + nodeGroupNodeCount[nodeGroupId] = int(nodeCount) + + // 序列化用户信息为 JSON + userDataJSON := "[]" + if jsonData, err := json.Marshal(users); err == nil { + userDataJSON = string(jsonData) + } else { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + // 创建历史详情(使用 node_group_id 作为分组标识) + detail := &group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: int(nodeCount), + UserData: userDataJSON, + } + + if err := tx.Create(detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + l.Infof("Average Group (node_group_id=%d): users=%d, nodes=%d", + nodeGroupId, userCount, nodeCount) + } + + return affectedCount, nil +} + +// executeSubscribeGrouping 实现基于订阅套餐的分组算法 +// 逻辑:查询有效订阅 → 获取订阅的 node_group_ids → 取第一个 node_group_id(如果有) → 更新 user_subscribe.node_group_id +// 订阅过期的用户 → 设置 node_group_id 为 0 +func (l *RecalculateGroupLogic) executeSubscribeGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 1. 查询所有有效且未锁定的用户订阅(status IN (0, 1), group_locked = 0) + type UserSubscribeInfo struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + SubscribeId int64 `json:"subscribe_id"` + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Model(&user.Subscribe{}). + Select("id, user_id, subscribe_id"). + Where("group_locked = ? AND status IN (0, 1)", 0). + Scan(&userSubscribes).Error; err != nil { + l.Errorw("failed to query user subscribes", logger.Field("error", err.Error())) + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infof("subscribe grouping: no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infof("subscribe grouping: found %d valid and unlocked user subscribes", len(userSubscribes)) + + // 1.5 查询所有参与计算的节点组ID + var calculationNodeGroups []group.NodeGroup + if err := tx.Model(&group.NodeGroup{}). + Select("id"). + Where("for_calculation = ?", true). + Scan(&calculationNodeGroups).Error; err != nil { + l.Errorw("failed to query calculation node groups", logger.Field("error", err.Error())) + return 0, err + } + + // 创建参与计算的节点组ID集合(用于快速查找) + calculationNodeGroupIds := make(map[int64]bool) + for _, ng := range calculationNodeGroups { + calculationNodeGroupIds[ng.Id] = true + } + + l.Infof("subscribe grouping: found %d node groups with for_calculation=true", len(calculationNodeGroupIds)) + + // 2. 批量查询订阅的节点组ID信息 + subscribeIds := make([]int64, len(userSubscribes)) + for i, us := range userSubscribes { + subscribeIds[i] = us.SubscribeId + } + + type SubscribeInfo struct { + Id int64 `json:"id"` + NodeGroupIds string `json:"node_group_ids"` // JSON string + } + var subscribeInfos []SubscribeInfo + if err := tx.Model(&subscribe.Subscribe{}). + Select("id, node_group_ids"). + Where("id IN ?", subscribeIds). + Find(&subscribeInfos).Error; err != nil { + l.Errorw("failed to query subscribe infos", logger.Field("error", err.Error())) + return 0, err + } + + // 创建 subscribe_id -> SubscribeInfo 的映射 + subInfoMap := make(map[int64]SubscribeInfo) + for _, si := range subscribeInfos { + subInfoMap[si.Id] = si + } + + // 用于存储统计信息(按节点组ID统计用户数) + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + groupUsersMap := make(map[int64][]UserInfo) + nodeGroupUserCount := make(map[int64]int) // node_group_id -> user_count + nodeGroupNodeCount := make(map[int64]int) // node_group_id -> node_count + + // 3. 遍历所有用户订阅,取第一个节点组ID + affectedCount := 0 + failedCount := 0 + + for _, us := range userSubscribes { + subInfo, ok := subInfoMap[us.SubscribeId] + if !ok { + l.Infow("subscribe not found", + logger.Field("user_subscribe_id", us.Id), + logger.Field("subscribe_id", us.SubscribeId)) + failedCount++ + continue + } + + // 解析订阅的节点组ID列表,并过滤出参与计算的节点组 + var nodeGroupIds []int64 + if subInfo.NodeGroupIds != "" && subInfo.NodeGroupIds != "[]" { + var allNodeGroupIds []int64 + if err := json.Unmarshal([]byte(subInfo.NodeGroupIds), &allNodeGroupIds); err != nil { + l.Errorw("failed to parse node_group_ids", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("node_group_ids", subInfo.NodeGroupIds), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只保留参与计算的节点组 + for _, ngId := range allNodeGroupIds { + if calculationNodeGroupIds[ngId] { + nodeGroupIds = append(nodeGroupIds, ngId) + } + } + + if len(nodeGroupIds) == 0 && len(allNodeGroupIds) > 0 { + l.Debugw("all node_group_ids are not for calculation, setting to 0", + logger.Field("subscribe_id", subInfo.Id), + logger.Field("total_node_groups", len(allNodeGroupIds))) + } + } + + // 取第一个参与计算的节点组ID(如果有),否则设置为 0 + selectedNodeGroupId := int64(0) + if len(nodeGroupIds) > 0 { + selectedNodeGroupId = nodeGroupIds[0] + } + + l.Debugf("assigning user_subscribe_id=%d (subscribe_id=%d) to node_group_id=%d (total_options=%d, selected_first)", + us.Id, us.SubscribeId, selectedNodeGroupId, len(nodeGroupIds)) + + // 更新 user_subscribe 的 node_group_id 字段 + if err := tx.Model(&user.Subscribe{}). + Where("id = ?", us.Id). + Update("node_group_id", selectedNodeGroupId).Error; err != nil { + l.Errorw("failed to update user_subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("error", err.Error())) + failedCount++ + continue + } + + // 只统计有节点组的用户 + if selectedNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + groupUsersMap[selectedNodeGroupId] = append(groupUsersMap[selectedNodeGroupId], UserInfo{ + Id: us.UserId, + Email: email, + }) + nodeGroupUserCount[selectedNodeGroupId]++ + } + + affectedCount++ + } + + l.Infof("subscribe grouping completed: affected=%d, failed=%d", affectedCount, failedCount) + + // 4. 处理订阅过期/失效的用户,设置 node_group_id 为 0 + // 查询所有没有有效订阅且未锁定的用户订阅记录 + var expiredUserSubscribes []struct { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + } + + if err := tx.Raw(` + SELECT us.id, us.user_id + FROM user_subscribe as us + WHERE us.group_locked = 0 + AND us.status NOT IN (0, 1) + `).Scan(&expiredUserSubscribes).Error; err != nil { + l.Errorw("failed to query expired user subscribes", logger.Field("error", err.Error())) + // 继续处理,不因为过期用户查询失败而影响 + } else { + l.Infof("found %d expired user subscribes for subscribe-based grouping, will set node_group_id to 0", len(expiredUserSubscribes)) + + expiredAffectedCount := 0 + for _, eu := range expiredUserSubscribes { + // 更新 user_subscribe 表的 node_group_id 字段到 0 + if err := tx.Model(&user.Subscribe{}). + Where("id = ?", eu.Id). + Update("node_group_id", 0).Error; err != nil { + l.Errorw("failed to update expired user subscribe node_group_id", + logger.Field("user_subscribe_id", eu.Id), + logger.Field("error", err.Error())) + continue + } + + expiredAffectedCount++ + } + + l.Infof("expired user subscribes grouping completed: affected=%d", expiredAffectedCount) + } + + // 5. 创建分组历史详情记录(按节点组ID统计) + for nodeGroupId, users := range groupUsersMap { + userCount := len(users) + if userCount == 0 { + continue + } + + // 统计该节点组的节点数 + var nodeCount int64 = 0 + if nodeGroupId > 0 { + if err := tx.Model(&node.Node{}). + Where("JSON_CONTAINS(node_group_ids, ?)", fmt.Sprintf("[%d]", nodeGroupId)). + Count(&nodeCount).Error; err != nil { + l.Errorw("failed to count nodes", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + nodeGroupNodeCount[nodeGroupId] = int(nodeCount) + + // 序列化用户信息为 JSON + userDataJSON := "[]" + if jsonData, err := json.Marshal(users); err == nil { + userDataJSON = string(jsonData) + } else { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + // 创建历史详情 + detail := &group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: int(nodeCount), + UserData: userDataJSON, + } + + if err := tx.Create(detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + + l.Infof("Subscribe Group (node_group_id=%d): users=%d, nodes=%d", + nodeGroupId, userCount, nodeCount) + } + + return affectedCount, nil +} + +// executeTrafficGrouping 实现基于流量的分组算法 +// 逻辑:根据配置的流量范围,将用户分配到对应的用户组 +func (l *RecalculateGroupLogic) executeTrafficGrouping(tx *gorm.DB, historyId int64) (int, error) { + // 用于存储每个节点组的用户信息(id 和 email) + type UserInfo struct { + Id int64 `json:"id"` + Email string `json:"email"` + } + groupUsersMap := make(map[int64][]UserInfo) // node_group_id -> []UserInfo + + // 1. 获取所有设置了流量区间的节点组 + var nodeGroups []group.NodeGroup + if err := tx.Where("for_calculation = ?", true). + Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)"). + Find(&nodeGroups).Error; err != nil { + l.Errorw("failed to query node groups", logger.Field("error", err.Error())) + return 0, err + } + + if len(nodeGroups) == 0 { + l.Infow("no node groups with traffic ranges configured") + return 0, nil + } + + l.Infow("executeTrafficGrouping loaded node groups", + logger.Field("node_groups_count", len(nodeGroups))) + + // 2. 查询所有有效且未锁定的用户订阅及其已用流量 + type UserSubscribeInfo struct { + Id int64 + UserId int64 + Upload int64 + Download int64 + UsedTraffic int64 // 已用流量 = upload + download (bytes) + } + + var userSubscribes []UserSubscribeInfo + if err := tx.Model(&user.Subscribe{}). + Select("id, user_id, upload, download, (upload + download) as used_traffic"). + Where("group_locked = ? AND status IN (0, 1)", 0). // 只查询有效且未锁定的用户订阅 + Scan(&userSubscribes).Error; err != nil { + l.Errorw("failed to query user subscribes", logger.Field("error", err.Error())) + return 0, err + } + + if len(userSubscribes) == 0 { + l.Infow("no valid and unlocked user subscribes found") + return 0, nil + } + + l.Infow("found user subscribes for traffic-based grouping", logger.Field("count", len(userSubscribes))) + + // 3. 根据流量范围分配节点组ID到用户订阅 + affectedCount := 0 + groupUserCount := make(map[int64]int) // node_group_id -> user_count + + for _, us := range userSubscribes { + // 将字节转换为 GB + usedTrafficGB := float64(us.UsedTraffic) / (1024 * 1024 * 1024) + + // 查找匹配的流量范围(使用左闭右开区间 [Min, Max)) + var targetNodeGroupId int64 = 0 + for _, ng := range nodeGroups { + if ng.MinTrafficGB == nil || ng.MaxTrafficGB == nil { + continue + } + minTraffic := float64(*ng.MinTrafficGB) + maxTraffic := float64(*ng.MaxTrafficGB) + + // 检查是否在区间内 [min, max) + if usedTrafficGB >= minTraffic && usedTrafficGB < maxTraffic { + targetNodeGroupId = ng.Id + break + } + } + + // 如果没有匹配到任何范围,targetNodeGroupId 保持为 0(不分配节点组) + + // 更新 user_subscribe 的 node_group_id 字段 + if err := tx.Model(&user.Subscribe{}). + Where("id = ?", us.Id). + Update("node_group_id", targetNodeGroupId).Error; err != nil { + l.Errorw("failed to update user subscribe node_group_id", + logger.Field("user_subscribe_id", us.Id), + logger.Field("target_node_group_id", targetNodeGroupId), + logger.Field("error", err.Error())) + continue + } + + // 只有分配了节点组的用户才记录到历史 + if targetNodeGroupId > 0 { + // 查询用户邮箱,用于保存到历史记录 + email := l.getUserEmail(tx, us.UserId) + userInfo := UserInfo{ + Id: us.UserId, + Email: email, + } + groupUsersMap[targetNodeGroupId] = append(groupUsersMap[targetNodeGroupId], userInfo) + groupUserCount[targetNodeGroupId]++ + + l.Debugf("assigned user subscribe %d (traffic: %.2fGB) to node group %d", + us.Id, usedTrafficGB, targetNodeGroupId) + } else { + l.Debugf("user subscribe %d (traffic: %.2fGB) not assigned to any node group", + us.Id, usedTrafficGB) + } + + affectedCount++ + } + + l.Infof("traffic-based grouping completed: affected_subscribes=%d", affectedCount) + + // 4. 创建分组历史详情记录(只统计有用户的节点组) + nodeGroupCount := make(map[int64]int) // node_group_id -> node_count + for _, ng := range nodeGroups { + nodeGroupCount[ng.Id] = 1 // 每个节点组计为1 + } + + for nodeGroupId, userCount := range groupUserCount { + userDataJSON, err := json.Marshal(groupUsersMap[nodeGroupId]) + if err != nil { + l.Errorw("failed to marshal user data", + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + continue + } + + detail := group.GroupHistoryDetail{ + HistoryId: historyId, + NodeGroupId: nodeGroupId, + UserCount: userCount, + NodeCount: nodeGroupCount[nodeGroupId], + UserData: string(userDataJSON), + } + if err := tx.Create(&detail).Error; err != nil { + l.Errorw("failed to create group history detail", + logger.Field("history_id", historyId), + logger.Field("node_group_id", nodeGroupId), + logger.Field("error", err.Error())) + } + } + + return affectedCount, nil +} + +// containsIgnoreCase checks if a string contains another substring (case-insensitive) +func containsIgnoreCase(s, substr string) bool { + if len(substr) == 0 { + return true + } + if len(s) < len(substr) { + return false + } + + // Simple case-insensitive contains check + sLower := toLower(s) + substrLower := toLower(substr) + + return contains(sLower, substrLower) +} + +// toLower converts a string to lowercase +func toLower(s string) string { + result := make([]rune, len(s)) + for i, r := range s { + if r >= 'A' && r <= 'Z' { + result[i] = r + ('a' - 'A') + } else { + result[i] = r + } + } + return string(result) +} + +// contains checks if a string contains another substring (case-sensitive) +func contains(s, substr string) bool { + return len(s) >= len(substr) && indexOf(s, substr) >= 0 +} + +// indexOf returns the index of the first occurrence of substr in s, or -1 if not found +func indexOf(s, substr string) int { + n := len(substr) + if n == 0 { + return 0 + } + if n > len(s) { + return -1 + } + + // Simple string search + for i := 0; i <= len(s)-n; i++ { + if s[i:i+n] == substr { + return i + } + } + return -1 +} diff --git a/internal/logic/admin/group/resetGroupsLogic.go b/internal/logic/admin/group/resetGroupsLogic.go new file mode 100644 index 0000000..eaaa098 --- /dev/null +++ b/internal/logic/admin/group/resetGroupsLogic.go @@ -0,0 +1,82 @@ +package group + +import ( + "context" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/node" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/logger" +) + +type ResetGroupsLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// NewResetGroupsLogic Reset all groups (delete all node groups and reset related data) +func NewResetGroupsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetGroupsLogic { + return &ResetGroupsLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *ResetGroupsLogic) ResetGroups() error { + // 1. Delete all node groups + err := l.svcCtx.DB.Where("1 = 1").Delete(&group.NodeGroup{}).Error + if err != nil { + l.Errorw("Failed to delete all node groups", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully deleted all node groups") + + // 2. Clear node_group_ids for all subscribes (products) + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}).Where("1 = 1").Update("node_group_ids", "[]").Error + if err != nil { + l.Errorw("Failed to clear subscribes' node_group_ids", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully cleared all subscribes' node_group_ids") + + // 3. Clear node_group_ids for all nodes + err = l.svcCtx.DB.Model(&node.Node{}).Where("1 = 1").Update("node_group_ids", "[]").Error + if err != nil { + l.Errorw("Failed to clear nodes' node_group_ids", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully cleared all nodes' node_group_ids") + + // 4. Clear group history + err = l.svcCtx.DB.Where("1 = 1").Delete(&group.GroupHistory{}).Error + if err != nil { + l.Errorw("Failed to clear group history", logger.Field("error", err.Error())) + // Non-critical error, continue anyway + } else { + l.Infow("Successfully cleared group history") + } + + // 7. Clear group history details + err = l.svcCtx.DB.Where("1 = 1").Delete(&group.GroupHistoryDetail{}).Error + if err != nil { + l.Errorw("Failed to clear group history details", logger.Field("error", err.Error())) + // Non-critical error, continue anyway + } else { + l.Infow("Successfully cleared group history details") + } + + // 5. Delete all group config settings + err = l.svcCtx.DB.Where("`category` = ?", "group").Delete(&system.System{}).Error + if err != nil { + l.Errorw("Failed to delete group config", logger.Field("error", err.Error())) + return err + } + l.Infow("Successfully deleted all group config settings") + + l.Infow("Group reset completed successfully") + return nil +} diff --git a/internal/logic/admin/group/updateGroupConfigLogic.go b/internal/logic/admin/group/updateGroupConfigLogic.go new file mode 100644 index 0000000..0980373 --- /dev/null +++ b/internal/logic/admin/group/updateGroupConfigLogic.go @@ -0,0 +1,188 @@ +package group + +import ( + "context" + "encoding/json" + + "github.com/perfect-panel/server/internal/model/system" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type UpdateGroupConfigLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Update group config +func NewUpdateGroupConfigLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateGroupConfigLogic { + return &UpdateGroupConfigLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *UpdateGroupConfigLogic) UpdateGroupConfig(req *types.UpdateGroupConfigRequest) error { + // 验证 mode 是否为合法值 + if req.Mode != "" { + if req.Mode != "average" && req.Mode != "subscribe" && req.Mode != "traffic" { + return errors.New("invalid mode, must be one of: average, subscribe, traffic") + } + } + + // 使用 GORM Transaction 更新配置 + err := l.svcCtx.DB.Transaction(func(tx *gorm.DB) error { + // 更新 enabled 配置(使用 Upsert 逻辑) + enabledValue := "false" + if req.Enabled { + enabledValue = "true" + } + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "enabled"). + Update("value", enabledValue) + if result.Error != nil { + l.Errorw("failed to update group enabled config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "enabled", + Value: enabledValue, + Desc: "Group Feature Enabled", + }).Error; err != nil { + l.Errorw("failed to create group enabled config", logger.Field("error", err.Error())) + return err + } + } + + // 更新 mode 配置(使用 Upsert 逻辑) + if req.Mode != "" { + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "mode"). + Update("value", req.Mode) + if result.Error != nil { + l.Errorw("failed to update group mode config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "mode", + Value: req.Mode, + Desc: "Group Mode", + }).Error; err != nil { + l.Errorw("failed to create group mode config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 JSON 配置 + if req.Config != nil { + // 更新 average_config + if averageConfig, ok := req.Config["average_config"]; ok { + jsonBytes, err := json.Marshal(averageConfig) + if err != nil { + l.Errorw("failed to marshal average_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal average_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "average_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group average_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "average_config", + Value: string(jsonBytes), + Desc: "Average Group Config", + }).Error; err != nil { + l.Errorw("failed to create group average_config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 subscribe_config + if subscribeConfig, ok := req.Config["subscribe_config"]; ok { + jsonBytes, err := json.Marshal(subscribeConfig) + if err != nil { + l.Errorw("failed to marshal subscribe_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal subscribe_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "subscribe_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group subscribe_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "subscribe_config", + Value: string(jsonBytes), + Desc: "Subscribe Group Config", + }).Error; err != nil { + l.Errorw("failed to create group subscribe_config", logger.Field("error", err.Error())) + return err + } + } + } + + // 更新 traffic_config + if trafficConfig, ok := req.Config["traffic_config"]; ok { + jsonBytes, err := json.Marshal(trafficConfig) + if err != nil { + l.Errorw("failed to marshal traffic_config", logger.Field("error", err.Error())) + return errors.Wrap(err, "failed to marshal traffic_config") + } + // 使用 Upsert 逻辑:先尝试 UPDATE,如果不存在则 INSERT + result := tx.Model(&system.System{}). + Where("`category` = 'group' and `key` = ?", "traffic_config"). + Update("value", string(jsonBytes)) + if result.Error != nil { + l.Errorw("failed to update group traffic_config", logger.Field("error", result.Error.Error())) + return result.Error + } + // 如果没有更新任何行,说明记录不存在,需要插入 + if result.RowsAffected == 0 { + if err := tx.Create(&system.System{ + Category: "group", + Key: "traffic_config", + Value: string(jsonBytes), + Desc: "Traffic Group Config", + }).Error; err != nil { + l.Errorw("failed to create group traffic_config", logger.Field("error", err.Error())) + return err + } + } + } + } + + return nil + }) + + if err != nil { + l.Errorw("failed to update group config", logger.Field("error", err.Error())) + return err + } + + l.Infof("group config updated successfully: enabled=%v, mode=%s", req.Enabled, req.Mode) + return nil +} diff --git a/internal/logic/admin/group/updateNodeGroupLogic.go b/internal/logic/admin/group/updateNodeGroupLogic.go new file mode 100644 index 0000000..b7d6fa4 --- /dev/null +++ b/internal/logic/admin/group/updateNodeGroupLogic.go @@ -0,0 +1,185 @@ +package group + +import ( + "context" + "errors" + "time" + + "github.com/perfect-panel/server/internal/model/group" + "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + "gorm.io/gorm" +) + +type UpdateNodeGroupLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +func NewUpdateNodeGroupLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateNodeGroupLogic { + return &UpdateNodeGroupLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *UpdateNodeGroupLogic) UpdateNodeGroup(req *types.UpdateNodeGroupRequest) error { + // 检查节点组是否存在 + var nodeGroup group.NodeGroup + if err := l.svcCtx.DB.Where("id = ?", req.Id).First(&nodeGroup).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("node group not found") + } + logger.Errorf("failed to find node group: %v", err) + return err + } + + // 验证:系统中只能有一个过期节点组 + if req.IsExpiredGroup != nil && *req.IsExpiredGroup { + var count int64 + err := l.svcCtx.DB.Model(&group.NodeGroup{}). + Where("is_expired_group = ? AND id != ?", true, req.Id). + Count(&count).Error + if err != nil { + logger.Errorf("failed to check expired group count: %v", err) + return err + } + if count > 0 { + return errors.New("system already has an expired node group, cannot create multiple") + } + + // 验证:被订阅商品设置为默认节点组的不能设置为过期节点组 + var subscribeCount int64 + err = l.svcCtx.DB.Model(&subscribe.Subscribe{}). + Where("node_group_id = ?", req.Id). + Count(&subscribeCount).Error + if err != nil { + logger.Errorf("failed to check subscribe usage: %v", err) + return err + } + if subscribeCount > 0 { + return errors.New("this node group is used as default node group in subscription products, cannot set as expired group") + } + } + + // 构建更新数据 + updates := map[string]interface{}{ + "updated_at": time.Now(), + } + if req.Name != "" { + updates["name"] = req.Name + } + if req.Description != "" { + updates["description"] = req.Description + } + if req.Sort != 0 { + updates["sort"] = req.Sort + } + if req.ForCalculation != nil { + updates["for_calculation"] = *req.ForCalculation + } + if req.IsExpiredGroup != nil { + updates["is_expired_group"] = *req.IsExpiredGroup + // 过期节点组不参与分组计算 + if *req.IsExpiredGroup { + updates["for_calculation"] = false + } + } + if req.ExpiredDaysLimit != nil { + updates["expired_days_limit"] = *req.ExpiredDaysLimit + } + if req.MaxTrafficGBExpired != nil { + updates["max_traffic_gb_expired"] = *req.MaxTrafficGBExpired + } + if req.SpeedLimit != nil { + updates["speed_limit"] = *req.SpeedLimit + } + + // 获取新的流量区间值 + newMinTraffic := nodeGroup.MinTrafficGB + newMaxTraffic := nodeGroup.MaxTrafficGB + if req.MinTrafficGB != nil { + newMinTraffic = req.MinTrafficGB + updates["min_traffic_gb"] = *req.MinTrafficGB + } + if req.MaxTrafficGB != nil { + newMaxTraffic = req.MaxTrafficGB + updates["max_traffic_gb"] = *req.MaxTrafficGB + } + + // 校验流量区间 + if err := l.validateTrafficRange(int(req.Id), newMinTraffic, newMaxTraffic); err != nil { + return err + } + + // 执行更新 + if err := l.svcCtx.DB.Model(&nodeGroup).Updates(updates).Error; err != nil { + logger.Errorf("failed to update node group: %v", err) + return err + } + + logger.Infof("updated node group: id=%d", req.Id) + return nil +} + +// validateTrafficRange 校验流量区间:不能重叠、不能留空档、最小值不能大于最大值 +func (l *UpdateNodeGroupLogic) validateTrafficRange(currentNodeGroupId int, newMin, newMax *int64) error { + // 处理指针值 + minVal := int64(0) + maxVal := int64(0) + if newMin != nil { + minVal = *newMin + } + if newMax != nil { + maxVal = *newMax + } + + // 检查最小值是否大于最大值 + if minVal > maxVal { + return errors.New("minimum traffic cannot exceed maximum traffic") + } + + // 如果两个值都为0,表示不参与流量分组,不需要校验 + if minVal == 0 && maxVal == 0 { + return nil + } + + // 查询所有其他设置了流量区间的节点组 + var otherGroups []group.NodeGroup + if err := l.svcCtx.DB. + Where("id != ?", currentNodeGroupId). + Where("(min_traffic_gb > 0 OR max_traffic_gb > 0)"). + Find(&otherGroups).Error; err != nil { + logger.Errorf("failed to query other node groups: %v", err) + return err + } + + // 检查是否有重叠 + for _, other := range otherGroups { + otherMin := int64(0) + otherMax := int64(0) + if other.MinTrafficGB != nil { + otherMin = *other.MinTrafficGB + } + if other.MaxTrafficGB != nil { + otherMax = *other.MaxTrafficGB + } + + // 如果对方也没设置区间,跳过 + if otherMin == 0 && otherMax == 0 { + continue + } + + // 检查是否有重叠: 如果两个区间相交,就是重叠 + // 不重叠的条件是: newMax <= otherMin OR newMin >= otherMax + if !(maxVal <= otherMin || minVal >= otherMax) { + return errors.New("traffic range overlaps with another node group") + } + } + + return nil +} diff --git a/internal/logic/admin/server/createNodeLogic.go b/internal/logic/admin/server/createNodeLogic.go index 78ce987..38044de 100644 --- a/internal/logic/admin/server/createNodeLogic.go +++ b/internal/logic/admin/server/createNodeLogic.go @@ -29,13 +29,14 @@ func NewCreateNodeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Create func (l *CreateNodeLogic) CreateNode(req *types.CreateNodeRequest) error { data := node.Node{ - Name: req.Name, - Tags: tool.StringSliceToString(req.Tags), - Enabled: req.Enabled, - Port: req.Port, - Address: req.Address, - ServerId: req.ServerId, - Protocol: req.Protocol, + Name: req.Name, + Tags: tool.StringSliceToString(req.Tags), + Enabled: req.Enabled, + Port: req.Port, + Address: req.Address, + ServerId: req.ServerId, + Protocol: req.Protocol, + NodeGroupIds: node.JSONInt64Slice(req.NodeGroupIds), } err := l.svcCtx.NodeModel.InsertNode(l.ctx, &data) if err != nil { diff --git a/internal/logic/admin/server/filterNodeListLogic.go b/internal/logic/admin/server/filterNodeListLogic.go index 2e41cec..47f8574 100644 --- a/internal/logic/admin/server/filterNodeListLogic.go +++ b/internal/logic/admin/server/filterNodeListLogic.go @@ -29,10 +29,17 @@ func NewFilterNodeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Fi } func (l *FilterNodeListLogic) FilterNodeList(req *types.FilterNodeListRequest) (resp *types.FilterNodeListResponse, err error) { + // Convert NodeGroupId to []int64 for model + var nodeGroupIds []int64 + if req.NodeGroupId != nil { + nodeGroupIds = []int64{*req.NodeGroupId} + } + total, data, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ - Page: req.Page, - Size: req.Size, - Search: req.Search, + Page: req.Page, + Size: req.Size, + Search: req.Search, + NodeGroupIds: nodeGroupIds, }) if err != nil { @@ -43,17 +50,18 @@ func (l *FilterNodeListLogic) FilterNodeList(req *types.FilterNodeListRequest) ( list := make([]types.Node, 0) for _, datum := range data { list = append(list, types.Node{ - Id: datum.Id, - Name: datum.Name, - Tags: tool.RemoveDuplicateElements(strings.Split(datum.Tags, ",")...), - Port: datum.Port, - Address: datum.Address, - ServerId: datum.ServerId, - Protocol: datum.Protocol, - Enabled: datum.Enabled, - Sort: datum.Sort, - CreatedAt: datum.CreatedAt.UnixMilli(), - UpdatedAt: datum.UpdatedAt.UnixMilli(), + Id: datum.Id, + Name: datum.Name, + Tags: tool.RemoveDuplicateElements(strings.Split(datum.Tags, ",")...), + Port: datum.Port, + Address: datum.Address, + ServerId: datum.ServerId, + Protocol: datum.Protocol, + Enabled: datum.Enabled, + Sort: datum.Sort, + NodeGroupIds: []int64(datum.NodeGroupIds), + CreatedAt: datum.CreatedAt.UnixMilli(), + UpdatedAt: datum.UpdatedAt.UnixMilli(), }) } diff --git a/internal/logic/admin/server/updateNodeLogic.go b/internal/logic/admin/server/updateNodeLogic.go index 2af8a4d..3b0c291 100644 --- a/internal/logic/admin/server/updateNodeLogic.go +++ b/internal/logic/admin/server/updateNodeLogic.go @@ -40,6 +40,7 @@ func (l *UpdateNodeLogic) UpdateNode(req *types.UpdateNodeRequest) error { data.Address = req.Address data.Protocol = req.Protocol data.Enabled = req.Enabled + data.NodeGroupIds = node.JSONInt64Slice(req.NodeGroupIds) err = l.svcCtx.NodeModel.UpdateNode(l.ctx, data) if err != nil { l.Errorw("[UpdateNode] Update Database Error: ", logger.Field("error", err.Error())) diff --git a/internal/logic/admin/subscribe/createSubscribeLogic.go b/internal/logic/admin/subscribe/createSubscribeLogic.go index 411b026..290da21 100644 --- a/internal/logic/admin/subscribe/createSubscribeLogic.go +++ b/internal/logic/admin/subscribe/createSubscribeLogic.go @@ -34,6 +34,12 @@ func (l *CreateSubscribeLogic) CreateSubscribe(req *types.CreateSubscribeRequest val, _ := json.Marshal(req.Discount) discount = string(val) } + + trafficLimit := "" + if len(req.TrafficLimit) > 0 { + val, _ := json.Marshal(req.TrafficLimit) + trafficLimit = string(val) + } sub := &subscribe.Subscribe{ Id: 0, Name: req.Name, @@ -51,6 +57,9 @@ func (l *CreateSubscribeLogic) CreateSubscribe(req *types.CreateSubscribeRequest NewUserOnly: req.NewUserOnly, Nodes: tool.Int64SliceToString(req.Nodes), NodeTags: tool.StringSliceToString(req.NodeTags), + NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), + NodeGroupId: req.NodeGroupId, + TrafficLimit: trafficLimit, Show: req.Show, Sell: req.Sell, Sort: 0, diff --git a/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go b/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go index 6defdf1..fc29938 100644 --- a/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go +++ b/internal/logic/admin/subscribe/getSubscribeDetailsLogic.go @@ -42,6 +42,12 @@ func (l *GetSubscribeDetailsLogic) GetSubscribeDetails(req *types.GetSubscribeDe l.Logger.Error("[GetSubscribeDetailsLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("discount", sub.Discount)) } } + if sub.TrafficLimit != "" { + err = json.Unmarshal([]byte(sub.TrafficLimit), &resp.TrafficLimit) + if err != nil { + l.Logger.Error("[GetSubscribeDetailsLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("traffic_limit", sub.TrafficLimit)) + } + } resp.Nodes = tool.StringToInt64Slice(sub.Nodes) resp.NodeTags = strings.Split(sub.NodeTags, ",") return resp, nil diff --git a/internal/logic/admin/subscribe/getSubscribeListLogic.go b/internal/logic/admin/subscribe/getSubscribeListLogic.go index e8c7866..6cf6ba6 100644 --- a/internal/logic/admin/subscribe/getSubscribeListLogic.go +++ b/internal/logic/admin/subscribe/getSubscribeListLogic.go @@ -30,12 +30,20 @@ func NewGetSubscribeListLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } func (l *GetSubscribeListLogic) GetSubscribeList(req *types.GetSubscribeListRequest) (resp *types.GetSubscribeListResponse, err error) { - total, list, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ + // Build filter params + filterParams := &subscribe.FilterParams{ Page: int(req.Page), Size: int(req.Size), Language: req.Language, Search: req.Search, - }) + } + + // Add NodeGroupId filter if provided + if req.NodeGroupId > 0 { + filterParams.NodeGroupId = &req.NodeGroupId + } + + total, list, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, filterParams) if err != nil { l.Logger.Error("[GetSubscribeListLogic] get subscribe list failed: ", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get subscribe list failed: %v", err.Error()) @@ -54,8 +62,22 @@ func (l *GetSubscribeListLogic) GetSubscribeList(req *types.GetSubscribeListRequ l.Logger.Error("[GetSubscribeListLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("discount", item.Discount)) } } + if item.TrafficLimit != "" { + err = json.Unmarshal([]byte(item.TrafficLimit), &sub.TrafficLimit) + if err != nil { + l.Logger.Error("[GetSubscribeListLogic] JSON unmarshal failed: ", logger.Field("error", err.Error()), logger.Field("traffic_limit", item.TrafficLimit)) + } + } sub.Nodes = tool.StringToInt64Slice(item.Nodes) sub.NodeTags = strings.Split(item.NodeTags, ",") + // Handle NodeGroupIds - convert from JSONInt64Slice to []int64 + if item.NodeGroupIds != nil { + sub.NodeGroupIds = []int64(item.NodeGroupIds) + } else { + sub.NodeGroupIds = []int64{} + } + // NodeGroupId is already int64, should be copied by DeepCopy + sub.NodeGroupId = item.NodeGroupId resultList = append(resultList, sub) } diff --git a/internal/logic/admin/subscribe/updateSubscribeLogic.go b/internal/logic/admin/subscribe/updateSubscribeLogic.go index 123d5e0..8e4af98 100644 --- a/internal/logic/admin/subscribe/updateSubscribeLogic.go +++ b/internal/logic/admin/subscribe/updateSubscribeLogic.go @@ -42,6 +42,12 @@ func (l *UpdateSubscribeLogic) UpdateSubscribe(req *types.UpdateSubscribeRequest val, _ := json.Marshal(req.Discount) discount = string(val) } + + trafficLimit := "" + if len(req.TrafficLimit) > 0 { + val, _ := json.Marshal(req.TrafficLimit) + trafficLimit = string(val) + } sub := &subscribe.Subscribe{ Id: req.Id, Name: req.Name, @@ -59,6 +65,9 @@ func (l *UpdateSubscribeLogic) UpdateSubscribe(req *types.UpdateSubscribeRequest NewUserOnly: req.NewUserOnly, Nodes: tool.Int64SliceToString(req.Nodes), NodeTags: tool.StringSliceToString(req.NodeTags), + NodeGroupIds: subscribe.JSONInt64Slice(req.NodeGroupIds), + NodeGroupId: req.NodeGroupId, + TrafficLimit: trafficLimit, Show: req.Show, Sell: req.Sell, Sort: req.Sort, diff --git a/internal/logic/admin/user/createUserSubscribeLogic.go b/internal/logic/admin/user/createUserSubscribeLogic.go index 08876f8..e294cb5 100644 --- a/internal/logic/admin/user/createUserSubscribeLogic.go +++ b/internal/logic/admin/user/createUserSubscribeLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/google/uuid" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -64,6 +65,7 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs Upload: 0, Token: uuidx.SubscribeToken(fmt.Sprintf("adminCreate:%d", time.Now().UnixMilli())), UUID: uuid.New().String(), + NodeGroupId: sub.NodeGroupId, Status: 1, } if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, &userSub); err != nil { @@ -71,6 +73,60 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "InsertSubscribe error: %v", err.Error()) } + // Trigger user group recalculation (runs in background) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + l.Debugf("Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + l.Errorw("Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + l.Debugf("Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svcCtx) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + l.Errorw("Failed to recalculate user group", + logger.Field("user_id", userInfo.Id), + logger.Field("error", err.Error()), + ) + return + } + + l.Infow("Successfully recalculated user group after admin created subscription", + logger.Field("user_id", userInfo.Id), + logger.Field("subscribe_id", userSub.Id), + logger.Field("mode", groupMode), + ) + }() + err = l.svcCtx.UserModel.UpdateUserCache(l.ctx, userInfo) if err != nil { l.Errorw("UpdateUserCache error", logger.Field("error", err.Error())) @@ -81,5 +137,6 @@ func (l *CreateUserSubscribeLogic) CreateUserSubscribe(req *types.CreateUserSubs if err != nil { logger.Errorw("ClearSubscribe error", logger.Field("error", err.Error())) } + return nil } diff --git a/internal/logic/admin/user/updateUserSubscribeLogic.go b/internal/logic/admin/user/updateUserSubscribeLogic.go index d86bac3..b521346 100644 --- a/internal/logic/admin/user/updateUserSubscribeLogic.go +++ b/internal/logic/admin/user/updateUserSubscribeLogic.go @@ -53,6 +53,7 @@ func (l *UpdateUserSubscribeLogic) UpdateUserSubscribe(req *types.UpdateUserSubs Token: userSub.Token, UUID: userSub.UUID, Status: userSub.Status, + NodeGroupId: userSub.NodeGroupId, }) if err != nil { @@ -81,5 +82,6 @@ func (l *UpdateUserSubscribeLogic) UpdateUserSubscribe(req *types.UpdateUserSubs l.Errorf("ClearServerAllCache error: %v", err.Error()) return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "failed to clear server cache: %v", err.Error()) } + return nil } diff --git a/internal/logic/auth/admin/adminGenerateCaptchaLogic.go b/internal/logic/auth/admin/adminGenerateCaptchaLogic.go new file mode 100644 index 0000000..4772855 --- /dev/null +++ b/internal/logic/auth/admin/adminGenerateCaptchaLogic.go @@ -0,0 +1,70 @@ +package admin + +import ( + "context" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type AdminGenerateCaptchaLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Generate captcha +func NewAdminGenerateCaptchaLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminGenerateCaptchaLogic { + return &AdminGenerateCaptchaLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminGenerateCaptchaLogic) AdminGenerateCaptcha() (resp *types.GenerateCaptchaResponse, err error) { + resp = &types.GenerateCaptchaResponse{} + + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminGenerateCaptchaLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + resp.Type = config.CaptchaType + + // If captcha type is local, generate captcha image + if config.CaptchaType == "local" { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + id, image, err := captchaService.Generate(l.ctx) + if err != nil { + l.Logger.Error("[AdminGenerateCaptchaLogic] Generate captcha error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Generate captcha error: %v", err.Error()) + } + + resp.Id = id + resp.Image = image + } else if config.CaptchaType == "turnstile" { + // For Turnstile, just return the site key + resp.Id = config.TurnstileSiteKey + } + + return resp, nil +} diff --git a/internal/logic/auth/admin/adminLoginLogic.go b/internal/logic/auth/admin/adminLoginLogic.go new file mode 100644 index 0000000..2167eaf --- /dev/null +++ b/internal/logic/auth/admin/adminLoginLogic.go @@ -0,0 +1,201 @@ +package admin + +import ( + "context" + "fmt" + "time" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/jwt" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/uuidx" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type AdminLoginLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Admin login +func NewAdminLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminLoginLogic { + return &AdminLoginLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminLoginLogic) AdminLogin(req *types.UserLoginRequest) (resp *types.LoginResponse, err error) { + loginStatus := false + var userInfo *user.User + // Record login status + defer func(svcCtx *svc.ServiceContext) { + if userInfo != nil && userInfo.Id != 0 { + loginLog := log.Login{ + Method: "email", + LoginIP: req.IP, + UserAgent: req.UserAgent, + Success: loginStatus, + Timestamp: time.Now().UnixMilli(), + } + content, _ := loginLog.Marshal() + if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{ + Type: log.TypeLogin.Uint8(), + Date: time.Now().Format("2006-01-02"), + ObjectID: userInfo.Id, + Content: string(content), + }); err != nil { + l.Errorw("failed to insert login log", + logger.Field("user_id", userInfo.Id), + logger.Field("ip", req.IP), + logger.Field("error", err.Error()), + ) + } + } + }(l.svcCtx) + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + + userInfo, err = l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) + + if userInfo.DeletedAt.Valid { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email deleted: %v", req.Email) + } + + if err != nil { + if errors.As(err, &gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + logger.WithContext(l.ctx).Error(err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user info failed: %v", err.Error()) + } + + // Check if user is admin + if userInfo.IsAdmin == nil || !*userInfo.IsAdmin { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.PermissionDenied), "user is not admin") + } + + // Verify password + if !tool.MultiPasswordVerify(userInfo.Algo, userInfo.Salt, req.Password, userInfo.Password) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserPasswordError), "user password") + } + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := auth.NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail login if device binding fails, just log the error + } + } + if l.ctx.Value(constant.CtxLoginType) != nil { + req.LoginType = l.ctx.Value(constant.CtxLoginType).(string) + } + // Generate session id + sessionId := uuidx.NewUUID().String() + // Generate token + token, err := jwt.NewJwtToken( + l.svcCtx.Config.JwtAuth.AccessSecret, + time.Now().Unix(), + l.svcCtx.Config.JwtAuth.AccessExpire, + jwt.WithOption("UserId", userInfo.Id), + jwt.WithOption("SessionId", sessionId), + jwt.WithOption("identifier", req.Identifier), + jwt.WithOption("CtxLoginType", req.LoginType), + ) + if err != nil { + l.Logger.Error("[AdminLogin] token generate error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) + } + sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) + if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) + } + loginStatus = true + return &types.LoginResponse{ + Token: token, + }, nil +} + +func (l *AdminLoginLogic) verifyCaptcha(req *types.UserLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for admin login + if !config.EnableAdminLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[AdminLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[AdminLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/admin/adminResetPasswordLogic.go b/internal/logic/auth/admin/adminResetPasswordLogic.go new file mode 100644 index 0000000..544b56b --- /dev/null +++ b/internal/logic/auth/admin/adminResetPasswordLogic.go @@ -0,0 +1,229 @@ +package admin + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/auth" + "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/jwt" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/uuidx" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +type AdminResetPasswordLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +type CacheKeyPayload struct { + Code string `json:"code"` +} + +// Admin reset password +func NewAdminResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *AdminResetPasswordLogic { + return &AdminResetPasswordLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *AdminResetPasswordLogic) AdminResetPassword(req *types.ResetPasswordRequest) (resp *types.LoginResponse, err error) { + var userInfo *user.User + loginStatus := false + + defer func() { + if userInfo != nil && userInfo.Id != 0 && loginStatus { + loginLog := log.Login{ + Method: "email", + LoginIP: req.IP, + UserAgent: req.UserAgent, + Success: loginStatus, + Timestamp: time.Now().UnixMilli(), + } + content, _ := loginLog.Marshal() + if err := l.svcCtx.LogModel.Insert(l.ctx, &log.SystemLog{ + Id: 0, + Type: log.TypeLogin.Uint8(), + Date: time.Now().Format("2006-01-02"), + ObjectID: userInfo.Id, + Content: string(content), + }); err != nil { + l.Errorw("failed to insert login log", + logger.Field("user_id", userInfo.Id), + logger.Field("ip", req.IP), + logger.Field("error", err.Error()), + ) + } + } + }() + + cacheKey := fmt.Sprintf("%s:%s:%s", config.AuthCodeCacheKey, constant.Security, req.Email) + // Check the verification code + if value, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result(); err != nil { + l.Errorw("Verification code error", logger.Field("cacheKey", cacheKey), logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } else { + var payload CacheKeyPayload + if err := json.Unmarshal([]byte(value), &payload); err != nil { + l.Errorw("Unmarshal errors", logger.Field("cacheKey", cacheKey), logger.Field("error", err.Error()), logger.Field("value", value)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } + if payload.Code != req.Code { + l.Errorw("Verification code error", logger.Field("cacheKey", cacheKey), logger.Field("error", "Verification code error"), logger.Field("reqCode", req.Code), logger.Field("payloadCode", payload.Code)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "Verification code error") + } + } + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + + // Check user + authMethod, err := l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "email", req.Email) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find user by email error: %v", err.Error()) + } + + userInfo, err = l.svcCtx.UserModel.FindOne(l.ctx, authMethod.UserId) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.UserNotExist), "user email not exist: %v", req.Email) + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query user info failed: %v", err.Error()) + } + + // Check if user is admin + if userInfo.IsAdmin == nil || !*userInfo.IsAdmin { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.PermissionDenied), "user is not admin") + } + + // Update password + userInfo.Password = tool.EncodePassWord(req.Password) + userInfo.Algo = "default" + if err = l.svcCtx.UserModel.Update(l.ctx, userInfo); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update user info failed: %v", err.Error()) + } + + // Bind device to user if identifier is provided + if req.Identifier != "" { + bindLogic := auth.NewBindDeviceLogic(l.ctx, l.svcCtx) + if err := bindLogic.BindDeviceToUser(req.Identifier, req.IP, req.UserAgent, userInfo.Id); err != nil { + l.Errorw("failed to bind device to user", + logger.Field("user_id", userInfo.Id), + logger.Field("identifier", req.Identifier), + logger.Field("error", err.Error()), + ) + // Don't fail register if device binding fails, just log the error + } + } + if l.ctx.Value(constant.CtxLoginType) != nil { + req.LoginType = l.ctx.Value(constant.CtxLoginType).(string) + } + // Generate session id + sessionId := uuidx.NewUUID().String() + // Generate token + token, err := jwt.NewJwtToken( + l.svcCtx.Config.JwtAuth.AccessSecret, + time.Now().Unix(), + l.svcCtx.Config.JwtAuth.AccessExpire, + jwt.WithOption("UserId", userInfo.Id), + jwt.WithOption("SessionId", sessionId), + jwt.WithOption("identifier", req.Identifier), + jwt.WithOption("CtxLoginType", req.LoginType), + ) + if err != nil { + l.Logger.Error("[AdminResetPassword] token generate error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "token generate error: %v", err.Error()) + } + sessionIdCacheKey := fmt.Sprintf("%v:%v", config.SessionIdKey, sessionId) + if err = l.svcCtx.Redis.Set(l.ctx, sessionIdCacheKey, userInfo.Id, time.Duration(l.svcCtx.Config.JwtAuth.AccessExpire)*time.Second).Err(); err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "set session id error: %v", err.Error()) + } + loginStatus = true + return &types.LoginResponse{ + Token: token, + }, nil +} + +func (l *AdminResetPasswordLogic) verifyCaptcha(req *types.ResetPasswordRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if admin login captcha is enabled (use admin login captcha for reset password) + if !config.EnableAdminLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[AdminResetPasswordLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/generateCaptchaLogic.go b/internal/logic/auth/generateCaptchaLogic.go new file mode 100644 index 0000000..068eb2b --- /dev/null +++ b/internal/logic/auth/generateCaptchaLogic.go @@ -0,0 +1,70 @@ +package auth + +import ( + "context" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type GenerateCaptchaLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Generate captcha +func NewGenerateCaptchaLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GenerateCaptchaLogic { + return &GenerateCaptchaLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GenerateCaptchaLogic) GenerateCaptcha() (resp *types.GenerateCaptchaResponse, err error) { + resp = &types.GenerateCaptchaResponse{} + + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[GenerateCaptchaLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + resp.Type = config.CaptchaType + + // If captcha type is local, generate captcha image + if config.CaptchaType == "local" { + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + id, image, err := captchaService.Generate(l.ctx) + if err != nil { + l.Logger.Error("[GenerateCaptchaLogic] Generate captcha error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "Generate captcha error: %v", err.Error()) + } + + resp.Id = id + resp.Image = image + } else if config.CaptchaType == "turnstile" { + // For Turnstile, just return the site key + resp.Id = config.TurnstileSiteKey + } + + return resp, nil +} diff --git a/internal/logic/auth/registerLimitLogic.go b/internal/logic/auth/registerLimitLogic.go index 048ef75..11ab3ca 100644 --- a/internal/logic/auth/registerLimitLogic.go +++ b/internal/logic/auth/registerLimitLogic.go @@ -16,6 +16,10 @@ func registerIpLimit(svcCtx *svc.ServiceContext, ctx context.Context, registerIp return true } + // Add timeout protection for Redis operations + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + // Use a sorted set to track IP registrations with timestamp as score // Key format: register:ip:{ip} key := fmt.Sprintf("%s%s", config.RegisterIpKeyPrefix, registerIp) diff --git a/internal/logic/auth/resetPasswordLogic.go b/internal/logic/auth/resetPasswordLogic.go index f504437..b975055 100644 --- a/internal/logic/auth/resetPasswordLogic.go +++ b/internal/logic/auth/resetPasswordLogic.go @@ -9,6 +9,7 @@ import ( "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/uuidx" @@ -91,6 +92,11 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res l.svcCtx.Redis.Del(l.ctx, cacheKey) } + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check user authMethod, err := l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "email", req.Email) if err != nil { @@ -155,3 +161,68 @@ func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordRequest) (res Token: token, }, nil } + +func (l *ResetPasswordLogic) verifyCaptcha(req *types.ResetPasswordRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if user reset password captcha is enabled + if !config.EnableUserResetPasswordCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[ResetPasswordLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} + diff --git a/internal/logic/auth/telephoneLoginLogic.go b/internal/logic/auth/telephoneLoginLogic.go index 3a8655c..772e006 100644 --- a/internal/logic/auth/telephoneLoginLogic.go +++ b/internal/logic/auth/telephoneLoginLogic.go @@ -12,6 +12,7 @@ import ( "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" @@ -94,6 +95,11 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r return nil, xerr.NewErrCodeMsg(xerr.InvalidParams, "password and telephone code is empty") } + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + if req.TelephoneCode == "" { // Verify password if !tool.MultiPasswordVerify(userInfo.Algo, userInfo.Salt, req.Password, userInfo.Password) { @@ -164,3 +170,67 @@ func (l *TelephoneLoginLogic) TelephoneLogin(req *types.TelephoneLoginRequest, r Token: token, }, nil } + +func (l *TelephoneLoginLogic) verifyCaptcha(req *types.TelephoneLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user login + if !config.EnableUserLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[TelephoneLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/telephoneUserRegisterLogic.go b/internal/logic/auth/telephoneUserRegisterLogic.go index e503190..53298c3 100644 --- a/internal/logic/auth/telephoneUserRegisterLogic.go +++ b/internal/logic/auth/telephoneUserRegisterLogic.go @@ -13,6 +13,7 @@ import ( "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" "github.com/perfect-panel/server/pkg/phone" @@ -81,6 +82,12 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR return nil, errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "code error") } l.svcCtx.Redis.Del(l.ctx, cacheKey) + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check if the user exists _, err = l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "mobile", phoneNumber) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -280,3 +287,67 @@ func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, er return userSub, nil } +func (l *TelephoneUserRegisterLogic) verifyCaptcha(req *types.TelephoneRegisterRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user register + if !config.EnableUserRegisterCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[TelephoneUserRegisterLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} + diff --git a/internal/logic/auth/userLoginLogic.go b/internal/logic/auth/userLoginLogic.go index 4204c53..f08f371 100644 --- a/internal/logic/auth/userLoginLogic.go +++ b/internal/logic/auth/userLoginLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/perfect-panel/server/internal/model/log" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" @@ -66,6 +67,11 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log } }(l.svcCtx) + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + userInfo, err = l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -134,3 +140,67 @@ func (l *UserLoginLogic) UserLogin(req *types.UserLoginRequest) (resp *types.Log Token: token, }, nil } + +func (l *UserLoginLogic) verifyCaptcha(req *types.UserLoginRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[UserLoginLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if captcha is enabled for user login + if !config.EnableUserLoginCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[UserLoginLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[UserLoginLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index 70def38..cbe7789 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -8,11 +8,13 @@ import ( "time" "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/captcha" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" @@ -85,6 +87,12 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } l.svcCtx.Redis.Del(l.ctx, cacheKey) } + + // Verify captcha + if err := l.verifyCaptcha(req); err != nil { + return nil, err + } + // Check if the user exists u, err := l.svcCtx.UserModel.FindOneByEmail(l.ctx, req.Email) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -132,22 +140,76 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * return err } - if l.svcCtx.Config.Register.EnableTrial { - // Active trial - var trialErr error - trialSubscribe, trialErr = l.activeTrial(userInfo.Id) - if trialErr != nil { - return trialErr - } - } return nil }) if err != nil { return nil, err } + // Activate trial subscription after transaction success (moved outside transaction to reduce lock time) + if l.svcCtx.Config.Register.EnableTrial { + trialSubscribe, err = l.activeTrial(userInfo.Id) + if err != nil { + l.Errorw("Failed to activate trial subscription", logger.Field("error", err.Error())) + // Don't fail registration if trial activation fails + } + } + // Clear cache after transaction success if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil { + // Trigger user group recalculation (runs in background) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + l.Debugf("Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + l.Errorw("Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + l.Debugf("Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svcCtx) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + l.Errorw("Failed to recalculate user group", + logger.Field("user_id", userInfo.Id), + logger.Field("error", err.Error()), + ) + return + } + + l.Infow("Successfully recalculated user group after registration", + logger.Field("user_id", userInfo.Id), + logger.Field("mode", groupMode), + ) + }() + // Clear user subscription cache if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil { l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id)) @@ -202,7 +264,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } loginStatus := true defer func() { - if token != "" && userInfo.Id != 0 { + if token != "" && userInfo != nil && userInfo.Id != 0 { loginLog := log.Login{ Method: "email", LoginIP: req.IP, @@ -275,3 +337,67 @@ func (l *UserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) { } return userSub, nil } + +func (l *UserRegisterLogic) verifyCaptcha(req *types.UserRegisterRequest) error { + // Get verify config from database + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[UserRegisterLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } + + var config struct { + CaptchaType string `json:"captcha_type"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + TurnstileSecret string `json:"turnstile_secret"` + } + tool.SystemConfigSliceReflectToStruct(verifyCfg, &config) + + // Check if user register captcha is enabled + if !config.EnableUserRegisterCaptcha { + return nil + } + + // Verify based on captcha type + if config.CaptchaType == "local" { + if req.CaptchaId == "" || req.CaptchaCode == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeLocal, + RedisClient: l.svcCtx.Redis, + }) + + valid, err := captchaService.Verify(l.ctx, req.CaptchaId, req.CaptchaCode, req.IP) + if err != nil { + l.Logger.Error("[UserRegisterLogic] Verify captcha error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } else if config.CaptchaType == "turnstile" { + if req.CfToken == "" { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "captcha required") + } + + captchaService := captcha.NewService(captcha.Config{ + Type: captcha.CaptchaTypeTurnstile, + TurnstileSecret: config.TurnstileSecret, + }) + + valid, err := captchaService.Verify(l.ctx, req.CfToken, "", req.IP) + if err != nil { + l.Logger.Error("[UserRegisterLogic] Verify turnstile error: ", logger.Field("error", err.Error())) + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "verify captcha error") + } + + if !valid { + return errors.Wrapf(xerr.NewErrCode(xerr.VerifyCodeError), "invalid captcha") + } + } + + return nil +} diff --git a/internal/logic/common/getGlobalConfigLogic.go b/internal/logic/common/getGlobalConfigLogic.go index 502e098..f470b46 100644 --- a/internal/logic/common/getGlobalConfigLogic.go +++ b/internal/logic/common/getGlobalConfigLogic.go @@ -41,6 +41,11 @@ func (l *GetGlobalConfigLogic) GetGlobalConfig() (resp *types.GetGlobalConfigRes l.Logger.Error("[GetGlobalConfigLogic] GetVerifyCodeConfig error: ", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyCodeConfig error: %v", err.Error()) } + verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) + if err != nil { + l.Logger.Error("[GetGlobalConfigLogic] GetVerifyConfig error: ", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetVerifyConfig error: %v", err.Error()) + } tool.DeepCopy(&resp.Site, l.svcCtx.Config.Site) tool.DeepCopy(&resp.Subscribe, l.svcCtx.Config.Subscribe) @@ -52,17 +57,12 @@ func (l *GetGlobalConfigLogic) GetGlobalConfig() (resp *types.GetGlobalConfigRes tool.DeepCopy(&resp.Invite, l.svcCtx.Config.Invite) tool.SystemConfigSliceReflectToStruct(currencyCfg, &resp.Currency) tool.SystemConfigSliceReflectToStruct(verifyCodeCfg, &resp.VerifyCode) + tool.SystemConfigSliceReflectToStruct(verifyCfg, &resp.Verify) if report.IsGatewayMode() { resp.Subscribe.SubscribePath = "/sub" + l.svcCtx.Config.Subscribe.SubscribePath } - resp.Verify = types.VeifyConfig{ - TurnstileSiteKey: l.svcCtx.Config.Verify.TurnstileSiteKey, - EnableLoginVerify: l.svcCtx.Config.Verify.LoginVerify, - EnableRegisterVerify: l.svcCtx.Config.Verify.RegisterVerify, - EnableResetPasswordVerify: l.svcCtx.Config.Verify.ResetPasswordVerify, - } var methods []string // auth methods diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go index 2530ae7..0e25486 100644 --- a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go @@ -104,18 +104,23 @@ func fillUserSubscribeInfoEntitlementFields(sub *types.UserSubscribeInfo, entitl func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (userSubscribeNodes []*types.UserSubscribeNodeInfo, err error) { userSubscribeNodes = make([]*types.UserSubscribeNodeInfo, 0) if l.isSubscriptionExpired(userSub) { - return l.createExpiredServers(), nil + return l.createExpiredServers(userSub), nil } - subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + // Check if group management is enabled + var groupEnabled string + err = l.svcCtx.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value").Scan(&groupEnabled).Error + if err != nil { - l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error())) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) + l.Debugw("[GetServers] Failed to check group enabled", logger.Field("error", err.Error())) + // Continue with tag-based filtering } nodeIds := tool.StringToInt64Slice(subDetails.Nodes) tags := normalizeSubscribeNodeTags(subDetails.NodeTags) - l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) + isGroupEnabled := (groupEnabled == "true" || groupEnabled == "1") enable := true @@ -127,6 +132,7 @@ func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (u Enabled: &enable, // Only get enabled nodes }) + // Process nodes and create response if len(nodes) > 0 { var serverMapIds = make(map[int64]*node.Server) for _, n := range nodes { @@ -174,21 +180,241 @@ func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (u } l.Debugf("[Query Subscribe]found servers: %v", len(nodes)) - - if err != nil { - l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error())) - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error()) - } - logger.Debugf("[Generate Subscribe]found servers: %v", len(nodes)) return userSubscribeNodes, nil } +// getNodesByGroup gets nodes based on user subscription node_group_id with priority fallback +func (l *QueryUserSubscribeNodeListLogic) getNodesByGroup(userSub *user.Subscribe) ([]*node.Node, error) { + // 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + nodeGroupId := int64(0) + source := "" + var directNodeIds []int64 + + // 优先级1: user_subscribe.node_group_id + if userSub.NodeGroupId != 0 { + nodeGroupId = userSub.NodeGroupId + source = "user_subscribe.node_group_id" + } + + // 获取 subscribe 详情(用于获取 node_group_id 和直接分配的节点) + subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + if err != nil { + l.Errorw("[GetNodesByGroup] find subscribe details error", logger.Field("error", err.Error())) + return nil, err + } + + // 获取直接分配的节点ID + directNodeIds = tool.StringToInt64Slice(subDetails.Nodes) + l.Debugf("[GetNodesByGroup] direct nodes: %v", directNodeIds) + + // 如果 user_subscribe 没有 node_group_id,从 subscribe 获取 + if nodeGroupId == 0 { + // 优先级2: subscribe.node_group_id + if subDetails.NodeGroupId != 0 { + nodeGroupId = subDetails.NodeGroupId + source = "subscribe.node_group_id" + } else if len(subDetails.NodeGroupIds) > 0 { + // 优先级3: subscribe.node_group_ids[0] + nodeGroupId = subDetails.NodeGroupIds[0] + source = "subscribe.node_group_ids[0]" + } + } + + l.Debugf("[GetNodesByGroup] Using %s: %v", source, nodeGroupId) + + // 查询所有启用的节点 + enable := true + _, allNodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 10000, + Enabled: &enable, + }) + if err != nil { + l.Errorw("[GetNodesByGroup] FilterNodeList error", logger.Field("error", err.Error())) + return nil, err + } + + // 过滤节点 + var resultNodes []*node.Node + nodeIdMap := make(map[int64]bool) + + for _, n := range allNodes { + // 1. 公共节点(node_group_ids 为空),所有人可见 + if len(n.NodeGroupIds) == 0 { + if !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + continue + } + + // 2. 如果有节点组,检查节点是否属于该节点组 + if nodeGroupId != 0 { + for _, gid := range n.NodeGroupIds { + if gid == nodeGroupId { + if !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + break + } + } + } + } + + // 3. 添加直接分配的节点 + if len(directNodeIds) > 0 { + for _, n := range allNodes { + if tool.Contains(directNodeIds, n.Id) && !nodeIdMap[n.Id] { + resultNodes = append(resultNodes, n) + nodeIdMap[n.Id] = true + } + } + } + + l.Debugf("[GetNodesByGroup] Found %d nodes (group=%d, direct=%d)", len(resultNodes), nodeGroupId, len(directNodeIds)) + return resultNodes, nil +} + +// getNodesByTag gets nodes based on subscribe node_ids and tags +func (l *QueryUserSubscribeNodeListLogic) getNodesByTag(userSub *user.Subscribe) ([]*node.Node, error) { + subDetails, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, userSub.SubscribeId) + if err != nil { + l.Errorw("[Generate Subscribe]find subscribe details error: %v", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) + } + + nodeIds := tool.StringToInt64Slice(subDetails.Nodes) + tags := strings.Split(subDetails.NodeTags, ",") + newTags := make([]string, 0) + for _, t := range tags { + if t != "" { + newTags = append(newTags, t) + } + } + tags = newTags + l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", nodeIds, tags) + + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeId: nodeIds, + Tag: tags, + Enabled: &enable, // Only get enabled nodes + }) + + return nodes, err +} + +// getAllNodes returns all enabled nodes +func (l *QueryUserSubscribeNodeListLogic) getAllNodes() ([]*node.Node, error) { + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + Enabled: &enable, + }) + + return nodes, err +} + func (l *QueryUserSubscribeNodeListLogic) isSubscriptionExpired(userSub *user.Subscribe) bool { return userSub.ExpireTime.Unix() < time.Now().Unix() && userSub.ExpireTime.Unix() != 0 } -func (l *QueryUserSubscribeNodeListLogic) createExpiredServers() []*types.UserSubscribeNodeInfo { - return nil +func (l *QueryUserSubscribeNodeListLogic) createExpiredServers(userSub *user.Subscribe) []*types.UserSubscribeNodeInfo { + // 1. 查询过期节点组 + var expiredGroup group.NodeGroup + err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error + if err != nil { + l.Debugw("no expired node group configured", logger.Field("error", err)) + return nil + } + + // 2. 检查用户是否在过期天数限制内 + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + l.Debugf("user subscription expired %d days, exceeds limit %d days", expiredDays, expiredGroup.ExpiredDaysLimit) + return nil + } + + // 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量) + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + l.Debugf("user expired traffic %d GB, exceeds expired group limit %d GB", usedTrafficGB, *expiredGroup.MaxTrafficGBExpired) + return nil + } + } + + // 4. 查询过期节点组的节点 + enable := true + _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{expiredGroup.Id}, + Enabled: &enable, + }) + if err != nil { + l.Errorw("failed to query expired group nodes", logger.Field("error", err)) + return nil + } + + if len(nodes) == 0 { + l.Debug("no nodes found in expired group") + return nil + } + + // 5. 查询服务器信息 + var serverMapIds = make(map[int64]*node.Server) + for _, n := range nodes { + serverMapIds[n.ServerId] = nil + } + var serverIds []int64 + for k := range serverMapIds { + serverIds = append(serverIds, k) + } + + servers, err := l.svcCtx.NodeModel.QueryServerList(l.ctx, serverIds) + if err != nil { + l.Errorw("failed to query servers", logger.Field("error", err)) + return nil + } + + for _, s := range servers { + serverMapIds[s.Id] = s + } + + // 6. 构建节点列表 + userSubscribeNodes := make([]*types.UserSubscribeNodeInfo, 0, len(nodes)) + for _, n := range nodes { + server := serverMapIds[n.ServerId] + if server == nil { + continue + } + userSubscribeNode := &types.UserSubscribeNodeInfo{ + Id: n.Id, + Name: n.Name, + Uuid: userSub.UUID, + Protocol: n.Protocol, + Protocols: server.Protocols, + Port: n.Port, + Address: n.Address, + Tags: strings.Split(n.Tags, ","), + Country: server.Country, + City: server.City, + Latitude: server.Latitude, + Longitude: server.Longitude, + LongitudeCenter: server.LongitudeCenter, + LatitudeCenter: server.LatitudeCenter, + CreatedAt: n.CreatedAt.Unix(), + } + userSubscribeNodes = append(userSubscribeNodes, userSubscribeNode) + } + + l.Infof("returned %d nodes from expired group for user %d (expired %d days)", len(userSubscribeNodes), userSub.UserId, expiredDays) + return userSubscribeNodes } func (l *QueryUserSubscribeNodeListLogic) getFirstHostLine() string { diff --git a/internal/logic/public/user/getUserTrafficStatsLogic.go b/internal/logic/public/user/getUserTrafficStatsLogic.go new file mode 100644 index 0000000..e46cb4b --- /dev/null +++ b/internal/logic/public/user/getUserTrafficStatsLogic.go @@ -0,0 +1,138 @@ +package user + +import ( + "context" + "strconv" + "time" + + "gorm.io/gorm" + + "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" + "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" +) + +type GetUserTrafficStatsLogic struct { + logger.Logger + ctx context.Context + svcCtx *svc.ServiceContext +} + +// Get User Traffic Statistics +func NewGetUserTrafficStatsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTrafficStatsLogic { + return &GetUserTrafficStatsLogic{ + Logger: logger.WithContext(ctx), + ctx: ctx, + svcCtx: svcCtx, + } +} + +func (l *GetUserTrafficStatsLogic) GetUserTrafficStats(req *types.GetUserTrafficStatsRequest) (resp *types.GetUserTrafficStatsResponse, err error) { + // 获取当前用户 + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok { + logger.Error("current user is not found in context") + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") + } + + // 将字符串 ID 转换为 int64 + userSubscribeId, err := strconv.ParseInt(req.UserSubscribeId, 10, 64) + if err != nil { + l.Errorw("[GetUserTrafficStats] Invalid User Subscribe ID:", + logger.Field("user_subscribe_id", req.UserSubscribeId), + logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid subscription ID") + } + + // 验证订阅归属权 - 直接查询 user_subscribe 表 + var userSubscribe struct { + Id int64 + UserId int64 + } + err = l.svcCtx.DB.WithContext(l.ctx). + Table("user_subscribe"). + Select("id, user_id"). + Where("id = ?", userSubscribeId). + First(&userSubscribe).Error + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + l.Errorw("[GetUserTrafficStats] User Subscribe Not Found:", + logger.Field("user_subscribe_id", userSubscribeId), + logger.Field("user_id", u.Id)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Subscription not found") + } + l.Errorw("[GetUserTrafficStats] Query User Subscribe Error:", logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Query User Subscribe Error") + } + + if userSubscribe.UserId != u.Id { + l.Errorw("[GetUserTrafficStats] User Subscribe Access Denied:", + logger.Field("user_subscribe_id", userSubscribeId), + logger.Field("subscribe_user_id", userSubscribe.UserId), + logger.Field("current_user_id", u.Id)) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") + } + + // 计算时间范围 + now := time.Now() + startDate := now.AddDate(0, 0, -req.Days+1) + startDate = time.Date(startDate.Year(), startDate.Month(), startDate.Day(), 0, 0, 0, 0, time.Local) + + // 初始化响应 + resp = &types.GetUserTrafficStatsResponse{ + List: make([]types.DailyTrafficStats, 0, req.Days), + TotalUpload: 0, + TotalDownload: 0, + TotalTraffic: 0, + } + + // 按天查询流量数据 + for i := 0; i < req.Days; i++ { + currentDate := startDate.AddDate(0, 0, i) + dayStart := time.Date(currentDate.Year(), currentDate.Month(), currentDate.Day(), 0, 0, 0, 0, time.Local) + dayEnd := dayStart.Add(24 * time.Hour).Add(-time.Nanosecond) + + // 查询当天流量 + var dailyTraffic struct { + Upload int64 + Download int64 + } + + // 直接使用 model 的查询方法 + err := l.svcCtx.DB.WithContext(l.ctx). + Table("traffic_log"). + Select("COALESCE(SUM(upload), 0) as upload, COALESCE(SUM(download), 0) as download"). + Where("user_id = ? AND subscribe_id = ? AND timestamp BETWEEN ? AND ?", + u.Id, userSubscribeId, dayStart, dayEnd). + Scan(&dailyTraffic).Error + + if err != nil { + l.Errorw("[GetUserTrafficStats] Query Daily Traffic Error:", + logger.Field("date", currentDate.Format("2006-01-02")), + logger.Field("err", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Query Traffic Error") + } + + // 添加到结果列表 + total := dailyTraffic.Upload + dailyTraffic.Download + resp.List = append(resp.List, types.DailyTrafficStats{ + Date: currentDate.Format("2006-01-02"), + Upload: dailyTraffic.Upload, + Download: dailyTraffic.Download, + Total: total, + }) + + // 累加总计 + resp.TotalUpload += dailyTraffic.Upload + resp.TotalDownload += dailyTraffic.Download + } + + resp.TotalTraffic = resp.TotalUpload + resp.TotalDownload + + return resp, nil +} diff --git a/internal/logic/public/user/queryUserSubscribeLogic.go b/internal/logic/public/user/queryUserSubscribeLogic.go index fbbcaeb..fff0edb 100644 --- a/internal/logic/public/user/queryUserSubscribeLogic.go +++ b/internal/logic/public/user/queryUserSubscribeLogic.go @@ -3,6 +3,7 @@ package user import ( "context" "encoding/json" + "strconv" "time" commonLogic "github.com/perfect-panel/server/internal/logic/common" @@ -58,6 +59,9 @@ func (l *QueryUserSubscribeLogic) QueryUserSubscribe() (resp *types.QueryUserSub var sub types.UserSubscribe tool.DeepCopy(&sub, item) + // 填充 IdStr 字段,避免前端精度丢失 + sub.IdStr = strconv.FormatInt(item.Id, 10) + // 解析Discount字段 避免在续订时只能续订一个月 if item.Subscribe != nil && item.Subscribe.Discount != "" { var discounts []types.SubscribeDiscount diff --git a/internal/logic/server/getServerUserListLogic.go b/internal/logic/server/getServerUserListLogic.go index 70ea51f..8a871d0 100644 --- a/internal/logic/server/getServerUserListLogic.go +++ b/internal/logic/server/getServerUserListLogic.go @@ -4,10 +4,13 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/model/subscribe" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -55,6 +58,7 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR return nil, err } + // 查询该服务器上该协议的所有节点(包括属于节点组的节点) _, nodes, err := l.svcCtx.NodeModel.FilterNodeList(l.ctx, &node.FilterNodeParams{ Page: 1, Size: 1000, @@ -65,25 +69,74 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR l.Errorw("FilterNodeList error", logger.Field("error", err.Error())) return nil, err } - var nodeTag []string + + if len(nodes) == 0 { + return &types.GetServerUserListResponse{ + Users: []types.ServerUser{ + { + Id: 1, + UUID: uuidx.NewUUID().String(), + }, + }, + }, nil + } + + // 收集所有唯一的节点组 ID + nodeGroupMap := make(map[int64]bool) // nodeGroupId -> true var nodeIds []int64 + var nodeTags []string + for _, n := range nodes { nodeIds = append(nodeIds, n.Id) if n.Tags != "" { - nodeTag = append(nodeTag, strings.Split(n.Tags, ",")...) + nodeTags = append(nodeTags, strings.Split(n.Tags, ",")...) + } + // 收集节点组 ID + if len(n.NodeGroupIds) > 0 { + for _, gid := range n.NodeGroupIds { + if gid > 0 { + nodeGroupMap[gid] = true + } + } } } - _, subs, err := l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ - Page: 1, - Size: 9999, - Node: nodeIds, - Tags: nodeTag, - }) - if err != nil { - l.Errorw("QuerySubscribeIdsByServerIdAndServerGroupId error", logger.Field("error", err.Error())) - return nil, err + // 获取所有节点组 ID + nodeGroupIds := make([]int64, 0, len(nodeGroupMap)) + for gid := range nodeGroupMap { + nodeGroupIds = append(nodeGroupIds, gid) } + + // 查询订阅: + // 1. 如果有节点组,查询匹配这些节点组的订阅 + // 2. 如果没有节点组,查询使用节点 ID 或 tags 的订阅 + var subs []*subscribe.Subscribe + if len(nodeGroupIds) > 0 { + // 节点组模式:查询 node_group_id 或 node_group_ids 匹配的订阅 + _, subs, err = l.svcCtx.SubscribeModel.FilterListByNodeGroups(l.ctx, &subscribe.FilterByNodeGroupsParams{ + Page: 1, + Size: 9999, + NodeGroupIds: nodeGroupIds, + }) + if err != nil { + l.Errorw("FilterListByNodeGroups error", logger.Field("error", err.Error())) + return nil, err + } + } else { + // 传统模式:查询匹配节点 ID 或 tags 的订阅 + nodeTags = tool.RemoveDuplicateElements(nodeTags...) + _, subs, err = l.svcCtx.SubscribeModel.FilterList(l.ctx, &subscribe.FilterParams{ + Page: 1, + Size: 9999, + Node: nodeIds, + Tags: nodeTags, + }) + if err != nil { + l.Errorw("FilterList error", logger.Field("error", err.Error())) + return nil, err + } + } + if len(subs) == 0 { return &types.GetServerUserListResponse{ Users: []types.ServerUser{ @@ -101,14 +154,33 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR return nil, err } for _, datum := range data { + if !l.shouldIncludeServerUser(datum, nodeGroupIds) { + continue + } + + // 计算该用户的实际限速值(考虑按量限速规则) + effectiveSpeedLimit := l.calculateEffectiveSpeedLimit(sub, datum) + users = append(users, types.ServerUser{ Id: datum.Id, UUID: datum.UUID, - SpeedLimit: sub.SpeedLimit, + SpeedLimit: effectiveSpeedLimit, DeviceLimit: sub.DeviceLimit, }) } } + + // 处理过期订阅用户:如果当前节点属于过期节点组,添加符合条件的过期用户 + if len(nodeGroupIds) > 0 { + expiredUsers, expiredSpeedLimit := l.getExpiredUsers(nodeGroupIds) + for i := range expiredUsers { + if expiredSpeedLimit > 0 { + expiredUsers[i].SpeedLimit = expiredSpeedLimit + } + } + users = append(users, expiredUsers...) + } + if len(users) == 0 { users = append(users, types.ServerUser{ Id: 1, @@ -131,3 +203,175 @@ func (l *GetServerUserListLogic) GetServerUserList(req *types.GetServerUserListR } return resp, nil } + +func (l *GetServerUserListLogic) shouldIncludeServerUser(userSub *user.Subscribe, serverNodeGroupIds []int64) bool { + if userSub == nil { + return false + } + + if userSub.ExpireTime.Unix() == 0 || userSub.ExpireTime.After(time.Now()) { + return true + } + + return l.canUseExpiredNodeGroup(userSub, serverNodeGroupIds) +} + +func (l *GetServerUserListLogic) getExpiredUsers(serverNodeGroupIds []int64) ([]types.ServerUser, int64) { + var expiredGroup group.NodeGroup + if err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error; err != nil { + return nil, 0 + } + + if !tool.Contains(serverNodeGroupIds, expiredGroup.Id) { + return nil, 0 + } + + var expiredSubs []*user.Subscribe + if err := l.svcCtx.DB.Where("status = ?", 3).Find(&expiredSubs).Error; err != nil { + l.Errorw("query expired subscriptions failed", logger.Field("error", err.Error())) + return nil, 0 + } + + users := make([]types.ServerUser, 0) + seen := make(map[int64]bool) + for _, userSub := range expiredSubs { + if !l.checkExpiredUserEligibility(userSub, &expiredGroup) { + continue + } + if seen[userSub.Id] { + continue + } + seen[userSub.Id] = true + users = append(users, types.ServerUser{ + Id: userSub.Id, + UUID: userSub.UUID, + }) + } + + return users, int64(expiredGroup.SpeedLimit) +} + +func (l *GetServerUserListLogic) checkExpiredUserEligibility(userSub *user.Subscribe, expiredGroup *group.NodeGroup) bool { + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + return false + } + + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + return false + } + } + + return true +} + +func (l *GetServerUserListLogic) canUseExpiredNodeGroup(userSub *user.Subscribe, serverNodeGroupIds []int64) bool { + var expiredGroup group.NodeGroup + if err := l.svcCtx.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error; err != nil { + return false + } + + if !tool.Contains(serverNodeGroupIds, expiredGroup.Id) { + return false + } + + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + return false + } + + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + return false + } + } + + return true +} + +// calculateEffectiveSpeedLimit 计算用户的实际限速值(考虑按量限速规则) +func (l *GetServerUserListLogic) calculateEffectiveSpeedLimit(sub *subscribe.Subscribe, userSub *user.Subscribe) int64 { + baseSpeedLimit := sub.SpeedLimit + + // 解析 traffic_limit 规则 + if sub.TrafficLimit == "" { + return baseSpeedLimit + } + + var trafficLimitRules []types.TrafficLimit + if err := json.Unmarshal([]byte(sub.TrafficLimit), &trafficLimitRules); err != nil { + l.Errorw("[calculateEffectiveSpeedLimit] Failed to unmarshal traffic_limit", + logger.Field("error", err.Error()), + logger.Field("traffic_limit", sub.TrafficLimit)) + return baseSpeedLimit + } + + if len(trafficLimitRules) == 0 { + return baseSpeedLimit + } + + // 查询用户指定时段的流量使用情况 + now := time.Now() + for _, rule := range trafficLimitRules { + var startTime, endTime time.Time + + if rule.StatType == "hour" { + // 按小时统计:根据 StatValue 计算时间范围(往前推 N 小时) + if rule.StatValue <= 0 { + continue + } + // 从当前时间往前推 StatValue 小时 + startTime = now.Add(-time.Duration(rule.StatValue) * time.Hour) + endTime = now + } else if rule.StatType == "day" { + // 按天统计:根据 StatValue 计算时间范围(往前推 N 天) + if rule.StatValue <= 0 { + continue + } + // 从当前时间往前推 StatValue 天 + startTime = now.AddDate(0, 0, -int(rule.StatValue)) + endTime = now + } else { + continue + } + + // 查询该时段的流量使用 + var usedTraffic struct { + Upload int64 + Download int64 + } + err := l.svcCtx.DB.WithContext(l.ctx.Request.Context()). + Table("traffic_log"). + Select("COALESCE(SUM(upload), 0) as upload, COALESCE(SUM(download), 0) as download"). + Where("user_id = ? AND subscribe_id = ? AND timestamp >= ? AND timestamp < ?", + userSub.UserId, userSub.Id, startTime, endTime). + Scan(&usedTraffic).Error + + if err != nil { + l.Errorw("[calculateEffectiveSpeedLimit] Failed to query traffic usage", + logger.Field("error", err.Error()), + logger.Field("user_id", userSub.UserId), + logger.Field("subscribe_id", userSub.Id)) + continue + } + + // 计算已使用流量(GB) + usedGB := float64(usedTraffic.Upload+usedTraffic.Download) / (1024 * 1024 * 1024) + + // 如果已使用流量达到或超过阈值,应用限速 + if usedGB >= float64(rule.TrafficUsage) { + // 如果规则限速大于0,应用该限速 + if rule.SpeedLimit > 0 { + // 如果基础限速为0(无限速)或规则限速更严格,使用规则限速 + if baseSpeedLimit == 0 || rule.SpeedLimit < baseSpeedLimit { + return rule.SpeedLimit + } + } + } + } + + return baseSpeedLimit +} diff --git a/internal/logic/subscribe/subscribeLogic.go b/internal/logic/subscribe/subscribeLogic.go index 28f9ecb..7b88160 100644 --- a/internal/logic/subscribe/subscribeLogic.go +++ b/internal/logic/subscribe/subscribeLogic.go @@ -8,6 +8,7 @@ import ( "github.com/perfect-panel/server/adapter" "github.com/perfect-panel/server/internal/model/client" + "github.com/perfect-panel/server/internal/model/group" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/report" @@ -206,6 +207,19 @@ func (l *SubscribeLogic) logSubscribeActivity(subscribeStatus bool, userSub *use func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, error) { if l.isSubscriptionExpired(userSub) { + // 尝试获取过期节点组的节点 + expiredNodes, err := l.getExpiredGroupNodes(userSub) + if err != nil { + l.Errorw("[Generate Subscribe]get expired group nodes error", logger.Field("error", err.Error())) + return l.createExpiredServers(), nil + } + // 如果有符合条件的过期节点组节点,返回它们 + if len(expiredNodes) > 0 { + l.Debugf("[Generate Subscribe]user %d can use expired node group, nodes count: %d", userSub.UserId, len(expiredNodes)) + return expiredNodes, nil + } + // 否则返回假的过期节点 + l.Debugf("[Generate Subscribe]user %d cannot use expired node group, return fake expired nodes", userSub.UserId) return l.createExpiredServers(), nil } @@ -215,14 +229,133 @@ func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, erro return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find subscribe details error: %v", err.Error()) } + // 判断是否使用分组模式 + isGroupMode := l.isGroupEnabled() + + if isGroupMode { + // === 分组模式:使用 node_group_id 获取节点 === + // 按优先级获取 node_group_id:user_subscribe.node_group_id > subscribe.node_group_id > subscribe.node_group_ids[0] + nodeGroupId := int64(0) + source := "" + + // 优先级1: user_subscribe.node_group_id + if userSub.NodeGroupId != 0 { + nodeGroupId = userSub.NodeGroupId + source = "user_subscribe.node_group_id" + } else { + // 优先级2 & 3: 从 subscribe 表获取 + if subDetails.NodeGroupId != 0 { + nodeGroupId = subDetails.NodeGroupId + source = "subscribe.node_group_id" + } else if len(subDetails.NodeGroupIds) > 0 { + // 优先级3: subscribe.node_group_ids[0] + nodeGroupId = subDetails.NodeGroupIds[0] + source = "subscribe.node_group_ids[0]" + } + } + + l.Debugf("[Generate Subscribe]group mode, using %s: %v", source, nodeGroupId) + + // 根据 node_group_id 获取节点 + enable := true + + // 1. 获取分组节点 + var groupNodes []*node.Node + if nodeGroupId > 0 { + params := &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{nodeGroupId}, + Enabled: &enable, + Preload: true, + } + _, groupNodes, err = l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), params) + + if err != nil { + l.Errorw("[Generate Subscribe]filter nodes by group error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter nodes by group error: %v", err.Error()) + } + l.Debugf("[Generate Subscribe]found %d nodes for node_group_id=%d", len(groupNodes), nodeGroupId) + } + + // 2. 获取公共节点(NodeGroupIds 为空的节点) + _, allNodes, err := l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ + Page: 0, + Size: 1000, + Enabled: &enable, + Preload: true, + }) + + if err != nil { + l.Errorw("[Generate Subscribe]filter all nodes error", logger.Field("error", err.Error())) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "filter all nodes error: %v", err.Error()) + } + + // 过滤出公共节点 + var publicNodes []*node.Node + for _, n := range allNodes { + if len(n.NodeGroupIds) == 0 { + publicNodes = append(publicNodes, n) + } + } + l.Debugf("[Generate Subscribe]found %d public nodes (node_group_ids is empty)", len(publicNodes)) + + // 3. 合并分组节点和公共节点 + nodesMap := make(map[int64]*node.Node) + for _, n := range groupNodes { + nodesMap[n.Id] = n + } + for _, n := range publicNodes { + if _, exists := nodesMap[n.Id]; !exists { + nodesMap[n.Id] = n + } + } + + // 转换为切片 + var result []*node.Node + for _, n := range nodesMap { + result = append(result, n) + } + + l.Debugf("[Generate Subscribe]total nodes (group + public): %d (group: %d, public: %d)", len(result), len(groupNodes), len(publicNodes)) + + // 查询节点组信息,获取节点组名称(仅当用户有分组时) + if nodeGroupId > 0 { + type NodeGroupInfo struct { + Id int64 + Name string + } + var nodeGroupInfo NodeGroupInfo + err = l.svc.DB.Table("node_group").Select("id, name").Where("id = ?", nodeGroupId).First(&nodeGroupInfo).Error + if err != nil { + l.Infow("[Generate Subscribe]node group not found", logger.Field("nodeGroupId", nodeGroupId), logger.Field("error", err.Error())) + } + + // 如果节点组信息存在,为没有 tag 的分组节点设置节点组名称为 tag + if nodeGroupInfo.Id != 0 && nodeGroupInfo.Name != "" { + for _, n := range result { + // 只为分组节点设置 tag,公共节点不设置 + if n.Tags == "" && len(n.NodeGroupIds) > 0 { + n.Tags = nodeGroupInfo.Name + l.Debugf("[Generate Subscribe]set node_group name as tag for node %d: %s", n.Id, nodeGroupInfo.Name) + } + } + } + } + + return result, nil + } + + // === 标签模式:使用 node_ids 和 tags 获取节点 === nodeIds := tool.StringToInt64Slice(subDetails.Nodes) tags := tool.RemoveStringElement(strings.Split(subDetails.NodeTags, ","), "") - l.Debugf("[Generate Subscribe]nodes: %v, NodeTags: %v", len(nodeIds), len(tags)) + l.Debugf("[Generate Subscribe]tag mode, nodes: %v, NodeTags: %v", len(nodeIds), len(tags)) if len(nodeIds) == 0 && len(tags) == 0 { - logger.Infow("[Generate Subscribe]no subscribe nodes") + logger.Infow("[Generate Subscribe]no subscribe nodes configured") return []*node.Node{}, nil } + enable := true var nodes []*node.Node _, nodes, err = l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ @@ -231,16 +364,15 @@ func (l *SubscribeLogic) getServers(userSub *user.Subscribe) ([]*node.Node, erro NodeId: nodeIds, Tag: tool.RemoveDuplicateElements(tags...), Preload: true, - Enabled: &enable, // Only get enabled nodes + Enabled: &enable, }) - l.Debugf("[Query Subscribe]found servers: %v", len(nodes)) - if err != nil { l.Errorw("[Generate Subscribe]find server details error: %v", logger.Field("error", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "find server details error: %v", err.Error()) } - logger.Debugf("[Generate Subscribe]found servers: %v", len(nodes)) + + l.Debugf("[Generate Subscribe]found %d nodes in tag mode", len(nodes)) return nodes, nil } @@ -290,3 +422,66 @@ func (l *SubscribeLogic) getFirstHostLine() string { } return host } + +// isGroupEnabled 判断分组功能是否启用 +func (l *SubscribeLogic) isGroupEnabled() bool { + var value string + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&value).Error + if err != nil { + l.Debugf("[SubscribeLogic]check group enabled failed: %v", err) + return false + } + return value == "true" || value == "1" +} + +// getExpiredGroupNodes 获取过期节点组的节点 +func (l *SubscribeLogic) getExpiredGroupNodes(userSub *user.Subscribe) ([]*node.Node, error) { + // 1. 查询过期节点组 + var expiredGroup group.NodeGroup + err := l.svc.DB.Where("is_expired_group = ?", true).First(&expiredGroup).Error + if err != nil { + l.Debugw("[SubscribeLogic]no expired node group configured", logger.Field("error", err.Error())) + return nil, err + } + + // 2. 检查用户是否在过期天数限制内 + expiredDays := int(time.Since(userSub.ExpireTime).Hours() / 24) + if expiredDays > expiredGroup.ExpiredDaysLimit { + l.Debugf("[SubscribeLogic]user %d subscription expired %d days, exceeds limit %d days", userSub.UserId, expiredDays, expiredGroup.ExpiredDaysLimit) + return nil, nil + } + + // 3. 检查用户已使用流量是否超过限制(仅使用过期期间的流量) + if expiredGroup.MaxTrafficGBExpired != nil && *expiredGroup.MaxTrafficGBExpired > 0 { + usedTrafficGB := (userSub.ExpiredDownload + userSub.ExpiredUpload) / (1024 * 1024 * 1024) + if usedTrafficGB >= *expiredGroup.MaxTrafficGBExpired { + l.Debugf("[SubscribeLogic]user %d expired traffic %d GB, exceeds expired group limit %d GB", userSub.UserId, usedTrafficGB, *expiredGroup.MaxTrafficGBExpired) + return nil, nil + } + } + + // 4. 查询过期节点组的节点 + enable := true + _, nodes, err := l.svc.NodeModel.FilterNodeList(l.ctx.Request.Context(), &node.FilterNodeParams{ + Page: 0, + Size: 1000, + NodeGroupIds: []int64{expiredGroup.Id}, + Enabled: &enable, + Preload: true, + }) + if err != nil { + l.Errorw("[SubscribeLogic]failed to query expired group nodes", logger.Field("error", err.Error())) + return nil, err + } + + if len(nodes) == 0 { + l.Debug("[SubscribeLogic]no nodes found in expired group") + return nil, nil + } + + l.Infof("[SubscribeLogic]returned %d nodes from expired group for user %d (expired %d days)", len(nodes), userSub.UserId, expiredDays) + return nodes, nil +} diff --git a/internal/model/group/history.go b/internal/model/group/history.go new file mode 100644 index 0000000..5ab1b0b --- /dev/null +++ b/internal/model/group/history.go @@ -0,0 +1,54 @@ +package group + +import ( + "time" + + "gorm.io/gorm" +) + +// GroupHistory 分组历史记录模型 +type GroupHistory struct { + Id int64 `gorm:"primaryKey"` + GroupMode string `gorm:"type:varchar(50);not null;index:idx_group_mode;comment:Group Mode: average/subscribe/traffic"` + TriggerType string `gorm:"type:varchar(50);not null;index:idx_trigger_type;comment:Trigger Type: manual/auto/schedule"` + State string `gorm:"type:varchar(50);not null;index:idx_state;comment:State: pending/running/completed/failed"` + TotalUsers int `gorm:"default:0;not null;comment:Total Users"` + SuccessCount int `gorm:"default:0;not null;comment:Success Count"` + FailedCount int `gorm:"default:0;not null;comment:Failed Count"` + StartTime *time.Time `gorm:"comment:Start Time"` + EndTime *time.Time `gorm:"comment:End Time"` + Operator string `gorm:"type:varchar(100);comment:Operator"` + ErrorMessage string `gorm:"type:TEXT;comment:Error Message"` + CreatedAt time.Time `gorm:"<-:create;index:idx_created_at;comment:Create Time"` +} + +// TableName 指定表名 +func (*GroupHistory) TableName() string { + return "group_history" +} + +// BeforeCreate GORM hook - 创建前回调 +func (gh *GroupHistory) BeforeCreate(tx *gorm.DB) error { + return nil +} + +// GroupHistoryDetail 分组历史详情模型 +type GroupHistoryDetail struct { + Id int64 `gorm:"primaryKey"` + HistoryId int64 `gorm:"not null;index:idx_history_id;comment:History ID"` + NodeGroupId int64 `gorm:"not null;index:idx_node_group_id;comment:Node Group ID"` + UserCount int `gorm:"default:0;not null;comment:User Count"` + NodeCount int `gorm:"default:0;not null;comment:Node Count"` + UserData string `gorm:"type:text;comment:User data JSON (id and email/phone)"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` +} + +// TableName 指定表名 +func (*GroupHistoryDetail) TableName() string { + return "group_history_detail" +} + +// BeforeCreate GORM hook - 创建前回调 +func (ghd *GroupHistoryDetail) BeforeCreate(tx *gorm.DB) error { + return nil +} diff --git a/internal/model/group/model.go b/internal/model/group/model.go new file mode 100644 index 0000000..77d5e1b --- /dev/null +++ b/internal/model/group/model.go @@ -0,0 +1,14 @@ +package group + +import ( + "gorm.io/gorm" +) + +// AutoMigrate 自动迁移数据库表 +func AutoMigrate(db *gorm.DB) error { + return db.AutoMigrate( + &NodeGroup{}, + &GroupHistory{}, + &GroupHistoryDetail{}, + ) +} diff --git a/internal/model/group/node_group.go b/internal/model/group/node_group.go new file mode 100644 index 0000000..a2fe3ee --- /dev/null +++ b/internal/model/group/node_group.go @@ -0,0 +1,34 @@ +package group + +import ( + "time" + + "gorm.io/gorm" +) + +// NodeGroup 节点组模型 +type NodeGroup struct { + Id int64 `gorm:"primaryKey"` + Name string `gorm:"type:varchar(255);not null;comment:Name"` + Description string `gorm:"type:varchar(500);comment:Description"` + Sort int `gorm:"default:0;index:idx_sort;comment:Sort Order"` + ForCalculation *bool `gorm:"default:true;not null;comment:For Calculation: whether this node group participates in grouping calculation"` + IsExpiredGroup *bool `gorm:"default:false;not null;index:idx_is_expired_group;comment:Is Expired Group"` + ExpiredDaysLimit int `gorm:"default:7;not null;comment:Expired days limit (days)"` + MaxTrafficGBExpired *int64 `gorm:"default:0;comment:Max traffic for expired users (GB)"` + SpeedLimit int `gorm:"default:0;not null;comment:Speed limit (KB/s)"` + MinTrafficGB *int64 `gorm:"default:0;comment:Minimum Traffic (GB) for this node group"` + MaxTrafficGB *int64 `gorm:"default:0;comment:Maximum Traffic (GB) for this node group"` + CreatedAt time.Time `gorm:"<-:create;comment:Create Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` +} + +// TableName 指定表名 +func (*NodeGroup) TableName() string { + return "node_group" +} + +// BeforeCreate GORM hook - 创建前回调 +func (ng *NodeGroup) BeforeCreate(tx *gorm.DB) error { + return nil +} diff --git a/internal/model/node/model.go b/internal/model/node/model.go index bede293..eb250bb 100644 --- a/internal/model/node/model.go +++ b/internal/model/node/model.go @@ -34,15 +34,16 @@ type FilterParams struct { } type FilterNodeParams struct { - Page int // Page Number - Size int // Page Size - NodeId []int64 // Node IDs - ServerId []int64 // Server IDs - Tag []string // Tags - Search string // Search Address or Name - Protocol string // Protocol - Preload bool // Preload Server - Enabled *bool // Enabled + Page int // Page Number + Size int // Page Size + NodeId []int64 // Node IDs + ServerId []int64 // Server IDs + Tag []string // Tags + NodeGroupIds []int64 // Node Group IDs + Search string // Search Address or Name + Protocol string // Protocol + Preload bool // Preload Server + Enabled *bool // Enabled } // FilterServerList Filter Server List @@ -97,6 +98,18 @@ func (m *customServerModel) FilterNodeList(ctx context.Context, params *FilterNo if len(params.Tag) > 0 { query = query.Scopes(InSet("tags", params.Tag)) } + if len(params.NodeGroupIds) > 0 { + // Filter by node_group_ids using JSON_CONTAINS for each group ID + // Multiple group IDs: node must belong to at least one of the groups + var conditions []string + for _, gid := range params.NodeGroupIds { + conditions = append(conditions, fmt.Sprintf("JSON_CONTAINS(node_group_ids, %d)", gid)) + } + if len(conditions) > 0 { + query = query.Where("(" + strings.Join(conditions, " OR ") + ")") + } + } + // If no NodeGroupIds specified, return all nodes (including public nodes) if params.Protocol != "" { query = query.Where("protocol = ?", params.Protocol) } diff --git a/internal/model/node/node.go b/internal/model/node/node.go index 89d665d..787ea32 100644 --- a/internal/model/node/node.go +++ b/internal/model/node/node.go @@ -1,25 +1,73 @@ package node import ( + "database/sql/driver" + "encoding/json" "time" "github.com/perfect-panel/server/pkg/logger" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type Node struct { - Id int64 `gorm:"primary_key"` - Name string `gorm:"type:varchar(100);not null;default:'';comment:Node Name"` - Tags string `gorm:"type:varchar(255);not null;default:'';comment:Tags"` - Port uint16 `gorm:"not null;default:0;comment:Connect Port"` - Address string `gorm:"type:varchar(255);not null;default:'';comment:Connect Address"` - ServerId int64 `gorm:"not null;default:0;comment:Server ID"` - Server *Server `gorm:"foreignKey:ServerId;references:Id"` - Protocol string `gorm:"type:varchar(100);not null;default:'';comment:Protocol"` - Enabled *bool `gorm:"type:boolean;not null;default:true;comment:Enabled"` - Sort int `gorm:"uniqueIndex;not null;default:0;comment:Sort"` - CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primary_key"` + Name string `gorm:"type:varchar(100);not null;default:'';comment:Node Name"` + Tags string `gorm:"type:varchar(255);not null;default:'';comment:Tags"` + Port uint16 `gorm:"not null;default:0;comment:Connect Port"` + Address string `gorm:"type:varchar(255);not null;default:'';comment:Connect Address"` + ServerId int64 `gorm:"not null;default:0;comment:Server ID"` + Server *Server `gorm:"foreignKey:ServerId;references:Id"` + Protocol string `gorm:"type:varchar(100);not null;default:'';comment:Protocol"` + Enabled *bool `gorm:"type:boolean;not null;default:true;comment:Enabled"` + Sort int `gorm:"uniqueIndex;not null;default:0;comment:Sort"` + NodeGroupIds JSONInt64Slice `gorm:"type:json;comment:Node Group IDs (JSON array, multiple groups)"` + CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } func (n *Node) TableName() string { diff --git a/internal/model/subscribe/model.go b/internal/model/subscribe/model.go index 9942046..06c0c6d 100644 --- a/internal/model/subscribe/model.go +++ b/internal/model/subscribe/model.go @@ -2,6 +2,7 @@ package subscribe import ( "context" + "strings" "github.com/perfect-panel/server/pkg/tool" "github.com/redis/go-redis/v9" @@ -19,6 +20,13 @@ type FilterParams struct { Language string // Language DefaultLanguage bool // Default Subscribe Language Data Search string // Search Keywords + NodeGroupId *int64 // Node Group ID +} + +type FilterByNodeGroupsParams struct { + Page int // Page Number + Size int // Page Size + NodeGroupIds []int64 // Node Group IDs (multiple) } func (p *FilterParams) Normalize() { @@ -32,6 +40,7 @@ func (p *FilterParams) Normalize() { type customSubscribeLogicModel interface { FilterList(ctx context.Context, params *FilterParams) (int64, []*Subscribe, error) + FilterListByNodeGroups(ctx context.Context, params *FilterByNodeGroupsParams) (int64, []*Subscribe, error) ClearCache(ctx context.Context, id ...int64) error QuerySubscribeMinSortByIds(ctx context.Context, ids []int64) (int64, error) } @@ -102,6 +111,10 @@ func (m *customSubscribeModel) FilterList(ctx context.Context, params *FilterPar if len(params.Tags) > 0 { query = query.Scopes(InSet("node_tags", params.Tags)) } + if params.NodeGroupId != nil { + // Filter by node_group_ids using JSON_CONTAINS + query = query.Where("JSON_CONTAINS(node_group_ids, ?)", *params.NodeGroupId) + } if lang != "" { query = query.Where("language = ?", lang) } else if params.DefaultLanguage { @@ -154,3 +167,67 @@ func InSet(field string, values []string) func(db *gorm.DB) *gorm.DB { return query } } + +// FilterListByNodeGroups Filter subscribes by node groups +// Match if subscribe's node_group_id OR node_group_ids contains any of the provided node group IDs +func (m *customSubscribeModel) FilterListByNodeGroups(ctx context.Context, params *FilterByNodeGroupsParams) (int64, []*Subscribe, error) { + if params == nil { + params = &FilterByNodeGroupsParams{ + Page: 1, + Size: 10, + } + } + if params.Page <= 0 { + params.Page = 1 + } + if params.Size <= 0 { + params.Size = 10 + } + + var list []*Subscribe + var total int64 + + err := m.QueryNoCacheCtx(ctx, &list, func(conn *gorm.DB, v interface{}) error { + query := conn.Model(&Subscribe{}) + + // Filter by node groups: match if node_group_id or node_group_ids contains any of the provided IDs + if len(params.NodeGroupIds) > 0 { + var conditions []string + var args []interface{} + + // Condition 1: node_group_id IN (...) + placeholders := make([]string, len(params.NodeGroupIds)) + for i, id := range params.NodeGroupIds { + placeholders[i] = "?" + args = append(args, id) + } + conditions = append(conditions, "node_group_id IN ("+strings.Join(placeholders, ",")+")") + + // Condition 2: JSON_CONTAINS(node_group_ids, id) for each id + for _, id := range params.NodeGroupIds { + conditions = append(conditions, "JSON_CONTAINS(node_group_ids, ?)") + args = append(args, id) + } + + // Combine with OR: (node_group_id IN (...) OR JSON_CONTAINS(node_group_ids, id1) OR ...) + query = query.Where("("+strings.Join(conditions, " OR ")+")", args...) + } + + // Count total + if err := query.Count(&total).Error; err != nil { + return err + } + + // Find with pagination + return query.Order("sort ASC"). + Limit(params.Size). + Offset((params.Page - 1) * params.Size). + Find(v).Error + }) + + if err != nil { + return 0, nil, err + } + + return total, list, nil +} diff --git a/internal/model/subscribe/subscribe.go b/internal/model/subscribe/subscribe.go index cf363af..49cbcc6 100644 --- a/internal/model/subscribe/subscribe.go +++ b/internal/model/subscribe/subscribe.go @@ -1,11 +1,58 @@ package subscribe import ( + "database/sql/driver" + "encoding/json" "time" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type Subscribe struct { Id int64 `gorm:"primaryKey"` Name string `gorm:"type:varchar(255);not null;default:'';comment:Subscribe Name"` diff --git a/internal/model/user/model.go b/internal/model/user/model.go index d457520..ecaac6e 100644 --- a/internal/model/user/model.go +++ b/internal/model/user/model.go @@ -29,6 +29,7 @@ type SubscribeDetails struct { OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` Subscribe *subscribe.Subscribe `gorm:"foreignKey:SubscribeId;references:Id"` + NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` @@ -89,7 +90,7 @@ type customUserLogicModel interface { FindOneSubscribeDetailsById(ctx context.Context, id int64) (*SubscribeDetails, error) FindOneUserSubscribe(ctx context.Context, id int64) (*SubscribeDetails, error) FindUsersSubscribeBySubscribeId(ctx context.Context, subscribeId int64) ([]*Subscribe, error) - UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, tx ...*gorm.DB) error + UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, isExpired bool, tx ...*gorm.DB) error QueryResisterUserTotalByDate(ctx context.Context, date time.Time) (int64, error) QueryResisterUserTotalByMonthly(ctx context.Context, date time.Time) (int64, error) QueryResisterUserTotal(ctx context.Context) (int64, error) @@ -276,7 +277,7 @@ func (m *customUserModel) BatchDeleteUser(ctx context.Context, ids []int64, tx . }, m.batchGetCacheKeys(users...)...) } -func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, tx ...*gorm.DB) error { +func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id, download, upload int64, isExpired bool, tx ...*gorm.DB) error { sub, err := m.FindOneSubscribe(ctx, id) if err != nil { return err @@ -293,10 +294,21 @@ func (m *customUserModel) UpdateUserSubscribeWithTraffic(ctx context.Context, id if len(tx) > 0 { conn = tx[0] } - return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ - "download": gorm.Expr("download + ?", download), - "upload": gorm.Expr("upload + ?", upload), - }).Error + + // 根据订阅状态更新对应的流量字段 + if isExpired { + // 过期期间,更新过期流量字段 + return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ + "expired_download": gorm.Expr("expired_download + ?", download), + "expired_upload": gorm.Expr("expired_upload + ?", upload), + }).Error + } else { + // 正常期间,更新正常流量字段 + return conn.Model(&Subscribe{}).Where("id = ?", id).Updates(map[string]interface{}{ + "download": gorm.Expr("download + ?", download), + "upload": gorm.Expr("upload + ?", upload), + }).Error + } }) } diff --git a/internal/model/user/user.go b/internal/model/user/user.go index 5ac50d9..425af75 100644 --- a/internal/model/user/user.go +++ b/internal/model/user/user.go @@ -1,11 +1,58 @@ package user import ( + "database/sql/driver" + "encoding/json" "time" "gorm.io/gorm" ) +// JSONInt64Slice is a custom type for handling []int64 as JSON in database +type JSONInt64Slice []int64 + +// Scan implements sql.Scanner interface +func (j *JSONInt64Slice) Scan(value interface{}) error { + if value == nil { + *j = []int64{} + return nil + } + + // Handle []byte + bytes, ok := value.([]byte) + if !ok { + // Try to handle string + str, ok := value.(string) + if !ok { + *j = []int64{} + return nil + } + bytes = []byte(str) + } + + if len(bytes) == 0 { + *j = []int64{} + return nil + } + + // Check if it's a JSON array + if bytes[0] != '[' { + // Not a JSON array, return empty slice + *j = []int64{} + return nil + } + + return json.Unmarshal(bytes, j) +} + +// Value implements driver.Valuer interface +func (j JSONInt64Slice) Value() (driver.Value, error) { + if len(j) == 0 { + return "[]", nil + } + return json.Marshal(j) +} + type User struct { Id int64 `gorm:"primaryKey"` Password string `gorm:"type:varchar(100);not null;comment:User Password"` @@ -41,23 +88,27 @@ func (*User) TableName() string { } type Subscribe struct { - Id int64 `gorm:"primaryKey"` - UserId int64 `gorm:"index:idx_user_id;not null;comment:User ID"` - User User `gorm:"foreignKey:UserId;references:Id"` - OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` - SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` - StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` - ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` - FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` - Traffic int64 `gorm:"default:0;comment:Traffic"` - Download int64 `gorm:"default:0;comment:Download Traffic"` - Upload int64 `gorm:"default:0;comment:Upload Traffic"` - Token string `gorm:"index:idx_token;unique;type:varchar(255);default:'';comment:Token"` - UUID string `gorm:"type:varchar(255);unique;index:idx_uuid;default:'';comment:UUID"` - Status uint8 `gorm:"type:tinyint(1);default:0;comment:Subscription Status: 0: Pending 1: Active 2: Finished 3: Expired 4: Deducted 5: stopped"` - Note string `gorm:"type:varchar(500);default:'';comment:User note for subscription"` - CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` - UpdatedAt time.Time `gorm:"comment:Update Time"` + Id int64 `gorm:"primaryKey"` + UserId int64 `gorm:"index:idx_user_id;not null;comment:User ID"` + User User `gorm:"foreignKey:UserId;references:Id"` + OrderId int64 `gorm:"index:idx_order_id;not null;comment:Order ID"` + SubscribeId int64 `gorm:"index:idx_subscribe_id;not null;comment:Subscription ID"` + NodeGroupId int64 `gorm:"index:idx_node_group_id;not null;default:0;comment:Node Group ID (single ID)"` + GroupLocked *bool `gorm:"type:tinyint(1);not null;default:0;comment:Group Locked"` + StartTime time.Time `gorm:"default:CURRENT_TIMESTAMP(3);not null;comment:Subscription Start Time"` + ExpireTime time.Time `gorm:"default:NULL;comment:Subscription Expire Time"` + FinishedAt *time.Time `gorm:"default:NULL;comment:Finished Time"` + Traffic int64 `gorm:"default:0;comment:Traffic"` + Download int64 `gorm:"default:0;comment:Download Traffic"` + Upload int64 `gorm:"default:0;comment:Upload Traffic"` + ExpiredDownload int64 `gorm:"default:0;comment:Expired period download traffic (bytes)"` + ExpiredUpload int64 `gorm:"default:0;comment:Expired period upload traffic (bytes)"` + Token string `gorm:"index:idx_token;unique;type:varchar(255);default:'';comment:Token"` + UUID string `gorm:"type:varchar(255);unique;index:idx_uuid;default:'';comment:UUID"` + Status uint8 `gorm:"type:tinyint(1);default:0;comment:Subscription Status: 0: Pending 1: Active 2: Finished 3: Expired 4: Deducted 5: stopped"` + Note string `gorm:"type:varchar(500);default:'';comment:User note for subscription"` + CreatedAt time.Time `gorm:"<-:create;comment:Creation Time"` + UpdatedAt time.Time `gorm:"comment:Update Time"` } func (*Subscribe) TableName() string { diff --git a/internal/types/types.go b/internal/types/types.go index aa88d6e..919e483 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -361,14 +361,28 @@ type CreateDocumentRequest struct { Show *bool `json:"show"` } +type CreateNodeGroupRequest struct { + Name string `json:"name" validate:"required"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation *bool `json:"for_calculation"` + IsExpiredGroup *bool `json:"is_expired_group"` + ExpiredDaysLimit *int `json:"expired_days_limit"` + MaxTrafficGBExpired *int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit *int `json:"speed_limit"` + MinTrafficGB *int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB *int64 `json:"max_traffic_gb,omitempty"` +} + type CreateNodeRequest struct { - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } type CreateOrderRequest struct { @@ -462,6 +476,9 @@ type CreateSubscribeRequest struct { NewUserOnly *bool `json:"new_user_only"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` DeductionRatio int64 `json:"deduction_ratio"` @@ -469,6 +486,7 @@ type CreateSubscribeRequest struct { ResetCycle int64 `json:"reset_cycle"` RenewalReset *bool `json:"renewal_reset"` ShowOriginalPrice bool `json:"show_original_price"` + AutoCreateGroup bool `json:"auto_create_group"` } type CreateTicketFollowRequest struct { @@ -559,6 +577,10 @@ type DeleteDocumentRequest struct { Id int64 `json:"id" validate:"required"` } +type DeleteNodeGroupRequest struct { + Id int64 `json:"id" validate:"required"` +} + type DeleteNodeRequest struct { Id int64 `json:"id"` } @@ -761,9 +783,10 @@ type FilterMobileLogResponse struct { } type FilterNodeListRequest struct { - Page int `form:"page"` - Size int `form:"size"` - Search string `form:"search,omitempty"` + Page int `form:"page"` + Size int `form:"size"` + Search string `form:"search,omitempty"` + NodeGroupId *int64 `form:"node_group_id,omitempty"` } type FilterNodeListResponse struct { @@ -855,6 +878,12 @@ type Follow struct { CreatedAt int64 `json:"created_at"` } +type GenerateCaptchaResponse struct { + Id string `json:"id"` + Image string `json:"image"` + Type string `json:"type"` +} + type GetAdsDetailRequest struct { Id int64 `form:"id"` } @@ -1107,6 +1136,17 @@ type GetMessageLogListResponse struct { List []MessageLog `json:"list"` } +type GetNodeGroupListRequest struct { + Page int `form:"page"` + Size int `form:"size"` + GroupId string `form:"group_id,omitempty"` +} + +type GetNodeGroupListResponse struct { + Total int64 `json:"total"` + List []NodeGroup `json:"list"` +} + type GetNodeMultiplierResponse struct { Periods []TimePeriod `json:"periods"` } @@ -1234,11 +1274,19 @@ type GetSubscribeGroupListResponse struct { Total int64 `json:"total"` } +type GetSubscribeGroupMappingRequest struct { +} + +type GetSubscribeGroupMappingResponse struct { + List []SubscribeGroupMappingItem `json:"list"` +} + type GetSubscribeListRequest struct { - Page int64 `form:"page" validate:"required"` - Size int64 `form:"size" validate:"required"` - Language string `form:"language,omitempty"` - Search string `form:"search,omitempty"` + Page int64 `form:"page" validate:"required"` + Size int64 `form:"size" validate:"required"` + Language string `form:"language,omitempty"` + Search string `form:"search,omitempty"` + NodeGroupId int64 `form:"node_group_id,omitempty"` } type GetSubscribeListResponse struct { @@ -1423,6 +1471,18 @@ type GetUserTicketListResponse struct { List []Ticket `json:"list"` } +type GetUserTrafficStatsRequest struct { + UserSubscribeId string `form:"user_subscribe_id" validate:"required"` + Days int `form:"days" validate:"required,oneof=7 30"` +} + +type GetUserTrafficStatsResponse struct { + List []DailyTrafficStats `json:"list"` + TotalUpload int64 `json:"total_upload"` + TotalDownload int64 `json:"total_download"` + TotalTraffic int64 `json:"total_traffic"` +} + type GiftLog struct { Type uint16 `json:"type"` UserId int64 `json:"user_id"` @@ -1439,6 +1499,25 @@ type GoogleLoginCallbackRequest struct { State string `form:"state"` } +type GroupHistory struct { + Id int64 `json:"id"` + GroupMode string `json:"group_mode"` + TriggerType string `json:"trigger_type"` + TotalUsers int `json:"total_users"` + SuccessCount int `json:"success_count"` + FailedCount int `json:"failed_count"` + StartTime *int64 `json:"start_time,omitempty"` + EndTime *int64 `json:"end_time,omitempty"` + Operator string `json:"operator,omitempty"` + ErrorLog string `json:"error_log,omitempty"` + CreatedAt int64 `json:"created_at"` +} + +type GroupHistoryDetail struct { + GroupHistory + ConfigSnapshot map[string]interface{} `json:"config_snapshot,omitempty"` +} + type HasMigrateSeverNodeResponse struct { HasMigrate bool `json:"has_migrate"` } @@ -1527,17 +1606,19 @@ type ModuleConfig struct { } type Node struct { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` - Sort int `json:"sort,omitempty"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + Sort int `json:"sort,omitempty"` + NodeGroupId int64 `json:"node_group_id,omitempty"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type NodeConfig struct { @@ -1557,6 +1638,29 @@ type NodeDNS struct { Domains []string `json:"domains"` } +type NodeGroup struct { + Id int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Sort int `json:"sort"` + ForCalculation bool `json:"for_calculation"` + IsExpiredGroup bool `json:"is_expired_group"` + ExpiredDaysLimit int `json:"expired_days_limit"` + MaxTrafficGBExpired int64 `json:"max_traffic_gb_expired,omitempty"` + SpeedLimit int `json:"speed_limit"` + MinTrafficGB int64 `json:"min_traffic_gb,omitempty"` + MaxTrafficGB int64 `json:"max_traffic_gb,omitempty"` + NodeCount int64 `json:"node_count,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type NodeGroupItem struct { + Id int64 `json:"id"` + Name string `json:"name"` + Nodes []Node `json:"nodes"` +} + type NodeOutbound struct { Name string `json:"name"` Protocol string `json:"protocol"` @@ -1774,6 +1878,15 @@ type PreviewSubscribeTemplateResponse struct { Template string `json:"template"` // 预览的模板内容 } +type PreviewUserNodesRequest struct { + UserId int64 `form:"user_id" validate:"required"` +} + +type PreviewUserNodesResponse struct { + UserId int64 `json:"user_id"` + NodeGroups []NodeGroupItem `json:"node_groups"` +} + type PrivacyPolicyConfig struct { PrivacyPolicy string `json:"privacy_policy"` } @@ -2055,6 +2168,17 @@ type QuotaTask struct { UpdatedAt int64 `json:"updated_at"` } +type RecalculateGroupRequest struct { + Mode string `json:"mode" validate:"required"` + TriggerType string `json:"trigger_type"` // "manual" or "scheduled" +} + +type RecalculationState struct { + State string `json:"state"` + Progress int `json:"progress"` + Total int `json:"total"` +} + type RechargeOrderRequest struct { Amount int64 `json:"amount" validate:"required,gt=0,lte=2000000000"` Payment int64 `json:"payment"` @@ -2139,15 +2263,21 @@ type ResetAllSubscribeTokenResponse struct { Success bool `json:"success"` } +type ResetGroupsRequest struct { + Confirm bool `json:"confirm" validate:"required"` +} + type ResetPasswordRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type ResetSortRequest struct { @@ -2412,6 +2542,9 @@ type Subscribe struct { NewUserOnly bool `json:"new_user_only"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show bool `json:"show"` Sell bool `json:"sell"` Sort int64 `json:"sort"` @@ -2472,6 +2605,11 @@ type SubscribeGroup struct { UpdatedAt int64 `json:"updated_at"` } +type SubscribeGroupMappingItem struct { + SubscribeName string `json:"subscribe_name"` + NodeGroupName string `json:"node_group_name"` +} + type SubscribeItem struct { Subscribe Sold int64 `json:"sold"` @@ -2520,6 +2658,8 @@ type TelephoneLoginRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TelephoneRegisterRequest struct { @@ -2533,6 +2673,8 @@ type TelephoneRegisterRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TelephoneResetPasswordRequest struct { @@ -2545,6 +2687,8 @@ type TelephoneResetPasswordRequest struct { UserAgent string `header:"User-Agent"` LoginType string `header:"Login-Type,optional"` CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type TestEmailSendRequest struct { @@ -2595,6 +2739,13 @@ type TosConfig struct { TosContent string `json:"tos_content"` } +type TrafficLimit struct { + StatType string `json:"stat_type"` + StatValue int64 `json:"stat_value"` + TrafficUsage int64 `json:"traffic_usage"` + SpeedLimit int64 `json:"speed_limit"` +} + type TrafficLog struct { Id int64 `json:"id"` ServerId int64 `json:"server_id"` @@ -2731,14 +2882,15 @@ type UpdateFamilyMaxMembersRequest struct { } type UpdateNodeRequest struct { - Id int64 `json:"id"` - Name string `json:"name"` - Tags []string `json:"tags,omitempty"` - Port uint16 `json:"port"` - Address string `json:"address"` - ServerId int64 `json:"server_id"` - Protocol string `json:"protocol"` - Enabled *bool `json:"enabled"` + Id int64 `json:"id"` + Name string `json:"name"` + Tags []string `json:"tags,omitempty"` + Port uint16 `json:"port"` + Address string `json:"address"` + ServerId int64 `json:"server_id"` + Protocol string `json:"protocol"` + Enabled *bool `json:"enabled"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` } type UpdateOrderStatusRequest struct { @@ -2821,6 +2973,9 @@ type UpdateSubscribeRequest struct { NewUserOnly *bool `json:"new_user_only"` Nodes []int64 `json:"nodes"` NodeTags []string `json:"node_tags"` + NodeGroupIds []int64 `json:"node_group_ids,omitempty"` + NodeGroupId int64 `json:"node_group_id"` + TrafficLimit []TrafficLimit `json:"traffic_limit"` Show *bool `json:"show"` Sell *bool `json:"sell"` Sort int64 `json:"sort"` @@ -2975,25 +3130,29 @@ type UserLoginLog struct { } type UserLoginRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type UserRegisterRequest struct { - Identifier string `json:"identifier"` - Email string `json:"email" validate:"required"` - Password string `json:"password" validate:"required"` - Invite string `json:"invite,optional"` - Code string `json:"code,optional"` - IP string `header:"X-Original-Forwarded-For"` - UserAgent string `header:"User-Agent"` - LoginType string `header:"Login-Type"` - CfToken string `json:"cf_token,optional"` + Identifier string `json:"identifier"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + Invite string `json:"invite,optional"` + Code string `json:"code,optional"` + IP string `header:"X-Original-Forwarded-For"` + UserAgent string `header:"User-Agent"` + LoginType string `header:"Login-Type"` + CfToken string `json:"cf_token,optional"` + CaptchaId string `json:"captcha_id,optional"` + CaptchaCode string `json:"captcha_code,optional"` } type UserStatistics struct { @@ -3041,6 +3200,8 @@ type UserSubscribeDetail struct { OrderId int64 `json:"order_id"` SubscribeId int64 `json:"subscribe_id"` Subscribe Subscribe `json:"subscribe"` + NodeGroupId int64 `json:"node_group_id"` + GroupLocked bool `json:"group_locked"` StartTime int64 `json:"start_time"` ExpireTime int64 `json:"expire_time"` ResetTime int64 `json:"reset_time"` @@ -3127,10 +3288,12 @@ type UserTrafficData struct { } type VeifyConfig struct { - TurnstileSiteKey string `json:"turnstile_site_key"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` + TurnstileSiteKey string `json:"turnstile_site_key"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` } type VerifyCodeConfig struct { @@ -3140,11 +3303,13 @@ type VerifyCodeConfig struct { } type VerifyConfig struct { - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecret string `json:"turnstile_secret"` - EnableLoginVerify bool `json:"enable_login_verify"` - EnableRegisterVerify bool `json:"enable_register_verify"` - EnableResetPasswordVerify bool `json:"enable_reset_password_verify"` + CaptchaType string `json:"captcha_type"` // local or turnstile + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecret string `json:"turnstile_secret"` + EnableUserLoginCaptcha bool `json:"enable_user_login_captcha"` // User login captcha + EnableUserRegisterCaptcha bool `json:"enable_user_register_captcha"` // User register captcha + EnableAdminLoginCaptcha bool `json:"enable_admin_login_captcha"` // Admin login captcha + EnableUserResetPasswordCaptcha bool `json:"enable_user_reset_password_captcha"` // User reset password captcha } type VerifyEmailRequest struct { diff --git a/pkg/captcha/local.go b/pkg/captcha/local.go new file mode 100644 index 0000000..86b9af6 --- /dev/null +++ b/pkg/captcha/local.go @@ -0,0 +1,109 @@ +package captcha + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/mojocn/base64Captcha" + "github.com/redis/go-redis/v9" +) + +type localService struct { + redis *redis.Client + driver base64Captcha.Driver +} + +func newLocalService(redisClient *redis.Client) Service { + // Configure captcha driver - alphanumeric with visual effects (letters + numbers) + driver := base64Captcha.NewDriverString( + 80, // height + 240, // width + 20, // noise count (more interference) + base64Captcha.OptionShowSlimeLine|base64Captcha.OptionShowSineLine, // show curved lines + 5, // length (5 characters) + "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789", // source (exclude confusing chars) + nil, // bg color (use default) + nil, // fonts (use default) + nil, // fonts storage (use default) + ) + return &localService{ + redis: redisClient, + driver: driver, + } +} + +func (s *localService) Generate(ctx context.Context) (id string, image string, err error) { + // Generate captcha + captcha := base64Captcha.NewCaptcha(s.driver, &redisStore{ + redis: s.redis, + ctx: ctx, + }) + + id, b64s, answer, err := captcha.Generate() + if err != nil { + return "", "", err + } + + // Store answer in Redis with 5 minute expiration + key := fmt.Sprintf("captcha:%s", id) + err = s.redis.Set(ctx, key, answer, 5*time.Minute).Err() + if err != nil { + return "", "", err + } + + return id, b64s, nil +} + +func (s *localService) Verify(ctx context.Context, id string, code string, ip string) (bool, error) { + if id == "" || code == "" { + return false, nil + } + + key := fmt.Sprintf("captcha:%s", id) + + // Get answer from Redis + answer, err := s.redis.Get(ctx, key).Result() + if err != nil { + return false, err + } + + // Delete captcha after verification (one-time use) + s.redis.Del(ctx, key) + + // Verify code (case-insensitive) + return strings.EqualFold(answer, code), nil +} + +func (s *localService) GetType() CaptchaType { + return CaptchaTypeLocal +} + +// redisStore implements base64Captcha.Store interface +type redisStore struct { + redis *redis.Client + ctx context.Context +} + +func (r *redisStore) Set(id string, value string) error { + key := fmt.Sprintf("captcha:%s", id) + return r.redis.Set(r.ctx, key, value, 5*time.Minute).Err() +} + +func (r *redisStore) Get(id string, clear bool) string { + key := fmt.Sprintf("captcha:%s", id) + val, err := r.redis.Get(r.ctx, key).Result() + if err != nil { + return "" + } + if clear { + r.redis.Del(r.ctx, key) + } + return val +} + +func (r *redisStore) Verify(id, answer string, clear bool) bool { + v := r.Get(id, clear) + return strings.EqualFold(v, answer) +} diff --git a/pkg/captcha/service.go b/pkg/captcha/service.go new file mode 100644 index 0000000..4536227 --- /dev/null +++ b/pkg/captcha/service.go @@ -0,0 +1,49 @@ +package captcha + +import ( + "context" + + "github.com/redis/go-redis/v9" +) + +type CaptchaType string + +const ( + CaptchaTypeLocal CaptchaType = "local" + CaptchaTypeTurnstile CaptchaType = "turnstile" +) + +// Service defines the captcha service interface +type Service interface { + // Generate generates a new captcha + // For local captcha: returns id and base64 image + // For turnstile: returns empty strings + Generate(ctx context.Context) (id string, image string, err error) + + // Verify verifies the captcha + // For local captcha: token is captcha id, code is user input + // For turnstile: token is cf-turnstile-response, code is ignored + Verify(ctx context.Context, token string, code string, ip string) (bool, error) + + // GetType returns the captcha type + GetType() CaptchaType +} + +// Config holds the configuration for captcha service +type Config struct { + Type CaptchaType + RedisClient *redis.Client + TurnstileSecret string +} + +// NewService creates a new captcha service based on the config +func NewService(config Config) Service { + switch config.Type { + case CaptchaTypeTurnstile: + return newTurnstileService(config.TurnstileSecret) + case CaptchaTypeLocal: + fallthrough + default: + return newLocalService(config.RedisClient) + } +} diff --git a/pkg/captcha/turnstile.go b/pkg/captcha/turnstile.go new file mode 100644 index 0000000..52e5bca --- /dev/null +++ b/pkg/captcha/turnstile.go @@ -0,0 +1,37 @@ +package captcha + +import ( + "context" + + "github.com/perfect-panel/server/pkg/turnstile" +) + +type turnstileService struct { + service turnstile.Service +} + +func newTurnstileService(secret string) Service { + return &turnstileService{ + service: turnstile.New(turnstile.Config{ + Secret: secret, + }), + } +} + +func (s *turnstileService) Generate(ctx context.Context) (id string, image string, err error) { + // Turnstile doesn't need server-side generation + return "", "", nil +} + +func (s *turnstileService) Verify(ctx context.Context, token string, code string, ip string) (bool, error) { + if token == "" { + return false, nil + } + + // Verify with Cloudflare Turnstile + return s.service.Verify(ctx, token, ip) +} + +func (s *turnstileService) GetType() CaptchaType { + return CaptchaTypeTurnstile +} diff --git a/pkg/turnstile/service.go b/pkg/turnstile/service.go index e3be9a7..9af71e5 100644 --- a/pkg/turnstile/service.go +++ b/pkg/turnstile/service.go @@ -56,7 +56,9 @@ func (s *service) verify(ctx context.Context, secret string, token string, ip st _ = writer.WriteField("idempotency_key", key) } _ = writer.Close() - client := &http.Client{} + client := &http.Client{ + Timeout: 5 * time.Second, + } req, _ := http.NewRequest("POST", s.url, body) req.Header.Set("Content-Type", writer.FormDataContentType()) firstResult, err := client.Do(req) diff --git a/ppanel.api b/ppanel.api index 324f565..0786080 100644 --- a/ppanel.api +++ b/ppanel.api @@ -30,6 +30,7 @@ import ( "apis/admin/ads.api" "apis/admin/marketing.api" "apis/admin/application.api" + "apis/admin/group.api" "apis/public/user.api" "apis/public/subscribe.api" "apis/public/redemption.api" diff --git a/queue/logic/group/recalculateGroupLogic.go b/queue/logic/group/recalculateGroupLogic.go new file mode 100644 index 0000000..91b3685 --- /dev/null +++ b/queue/logic/group/recalculateGroupLogic.go @@ -0,0 +1,87 @@ +package group + +import ( + "context" + "time" + + "github.com/perfect-panel/server/internal/logic/admin/group" + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/logger" + + "github.com/hibiken/asynq" +) + +type RecalculateGroupLogic struct { + svc *svc.ServiceContext +} + +func NewRecalculateGroupLogic(svc *svc.ServiceContext) *RecalculateGroupLogic { + return &RecalculateGroupLogic{ + svc: svc, + } +} + +func (l *RecalculateGroupLogic) ProcessTask(ctx context.Context, t *asynq.Task) error { + logger.Infof("[RecalculateGroup] Starting scheduled group recalculation: %s", time.Now().Format("2006-01-02 15:04:05")) + + // 1. Check if group management is enabled + var enabledConfig struct { + Value string `gorm:"column:value"` + } + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + First(&enabledConfig).Error + if err != nil { + logger.Errorw("[RecalculateGroup] Failed to read group enabled config", logger.Field("error", err.Error())) + return err + } + + // If not enabled, skip execution + if enabledConfig.Value != "true" && enabledConfig.Value != "1" { + logger.Debugf("[RecalculateGroup] Group management is not enabled, skipping") + return nil + } + + // 2. Get grouping mode + var modeConfig struct { + Value string `gorm:"column:value"` + } + err = l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + First(&modeConfig).Error + if err != nil { + logger.Errorw("[RecalculateGroup] Failed to read group mode config", logger.Field("error", err.Error())) + return err + } + + mode := modeConfig.Value + if mode == "" { + mode = "average" // default mode + } + + // 3. Only execute if mode is "traffic" + if mode != "traffic" { + logger.Debugf("[RecalculateGroup] Group mode is not 'traffic' (current: %s), skipping", mode) + return nil + } + + // 4. Execute traffic-based grouping + logger.Infof("[RecalculateGroup] Executing traffic-based grouping") + + logic := group.NewRecalculateGroupLogic(ctx, l.svc) + req := &types.RecalculateGroupRequest{ + Mode: "traffic", + TriggerType: "scheduled", + } + + if err := logic.RecalculateGroup(req); err != nil { + logger.Errorw("[RecalculateGroup] Failed to execute traffic grouping", logger.Field("error", err.Error())) + return err + } + + logger.Infof("[RecalculateGroup] Successfully completed traffic-based grouping: %s", time.Now().Format("2006-01-02 15:04:05")) + return nil +} diff --git a/queue/logic/order/activateOrderLogic.go b/queue/logic/order/activateOrderLogic.go index 2ba039b..f2fdad6 100644 --- a/queue/logic/order/activateOrderLogic.go +++ b/queue/logic/order/activateOrderLogic.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/perfect-panel/server/internal/logic/admin/group" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" @@ -24,9 +25,10 @@ import ( "github.com/perfect-panel/server/internal/model/subscribe" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/uuidx" - "github.com/perfect-panel/server/queue/types" + queueTypes "github.com/perfect-panel/server/queue/types" "gorm.io/gorm" ) @@ -126,8 +128,8 @@ func (l *ActivateOrderLogic) ProcessTask(ctx context.Context, task *asynq.Task) } // parsePayload unMarshals the task payload into a structured format -func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (*types.ForthwithActivateOrderPayload, error) { - var p types.ForthwithActivateOrderPayload +func (l *ActivateOrderLogic) parsePayload(ctx context.Context, payload []byte) (*queueTypes.ForthwithActivateOrderPayload, error) { + var p queueTypes.ForthwithActivateOrderPayload if err := json.Unmarshal(payload, &p); err != nil { logger.WithContext(ctx).Error("[ActivateOrderLogic] Unmarshal payload failed", logger.Field("error", err.Error()), @@ -322,6 +324,9 @@ func (l *ActivateOrderLogic) NewPurchase(ctx context.Context, orderInfo *order.O } } + // Trigger user group recalculation (runs in background) + l.triggerUserGroupRecalculation(ctx, userInfo.Id) + // Handle commission in separate goroutine to avoid blocking go l.handleCommission(context.Background(), userInfo, orderInfo) @@ -782,6 +787,63 @@ func (l *ActivateOrderLogic) clearServerCache(ctx context.Context, sub *subscrib } } +// triggerUserGroupRecalculation triggers user group recalculation after subscription changes +// This runs asynchronously in background to avoid blocking the main order processing flow +func (l *ActivateOrderLogic) triggerUserGroupRecalculation(ctx context.Context, userId int64) { + go func() { + // Use a new context with timeout for group recalculation + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Check if group management is enabled + var groupEnabled string + err := l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "enabled"). + Select("value"). + Scan(&groupEnabled).Error + if err != nil || groupEnabled != "true" && groupEnabled != "1" { + logger.Debugf("[Group Trigger] Group management not enabled, skipping recalculation") + return + } + + // Get the configured grouping mode + var groupMode string + err = l.svc.DB.Table("system"). + Where("`category` = ? AND `key` = ?", "group", "mode"). + Select("value"). + Scan(&groupMode).Error + if err != nil { + logger.Errorw("[Group Trigger] Failed to get group mode", logger.Field("error", err.Error())) + return + } + + // Validate group mode + if groupMode != "average" && groupMode != "subscribe" && groupMode != "traffic" { + logger.Debugf("[Group Trigger] Invalid group mode (current: %s), skipping", groupMode) + return + } + + // Trigger group recalculation with the configured mode + logic := group.NewRecalculateGroupLogic(ctx, l.svc) + req := &types.RecalculateGroupRequest{ + Mode: groupMode, + } + + if err := logic.RecalculateGroup(req); err != nil { + logger.Errorw("[Group Trigger] Failed to recalculate user group", + logger.Field("user_id", userId), + logger.Field("error", err.Error()), + ) + return + } + + logger.Infow("[Group Trigger] Successfully recalculated user group", + logger.Field("user_id", userId), + logger.Field("mode", groupMode), + ) + }() +} + // Renewal handles subscription renewal including subscription extension, // traffic reset (if configured), commission processing, and notifications func (l *ActivateOrderLogic) Renewal(ctx context.Context, orderInfo *order.Order, iapExpireAt int64) error { @@ -907,6 +969,9 @@ func (l *ActivateOrderLogic) updateSubscriptionForRenewal(ctx context.Context, u userSub.ExpireTime = tool.AddTime(sub.UnitTime, orderInfo.Quantity, userSub.ExpireTime) userSub.Status = 1 + // 续费时重置过期流量字段 + userSub.ExpiredDownload = 0 + userSub.ExpiredUpload = 0 if err := l.svc.UserModel.UpdateSubscribe(ctx, userSub); err != nil { logger.WithContext(ctx).Error("Update user subscribe failed", logger.Field("error", err.Error())) @@ -931,6 +996,8 @@ func (l *ActivateOrderLogic) ResetTraffic(ctx context.Context, orderInfo *order. // Reset traffic userSub.Download = 0 userSub.Upload = 0 + userSub.ExpiredDownload = 0 + userSub.ExpiredUpload = 0 userSub.Status = 1 if err := l.svc.UserModel.UpdateSubscribe(ctx, userSub); err != nil { @@ -1242,6 +1309,7 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * Traffic: us.Traffic, Download: us.Download, Upload: us.Upload, + NodeGroupId: us.NodeGroupId, } break } @@ -1328,6 +1396,7 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * Token: uuidx.SubscribeToken(orderInfo.OrderNo), UUID: uuid.New().String(), Status: 1, + NodeGroupId: sub.NodeGroupId, // Inherit node_group_id from subscription plan } err = l.svc.UserModel.InsertSubscribe(ctx, newSubscribe, tx) @@ -1374,6 +1443,9 @@ func (l *ActivateOrderLogic) RedemptionActivate(ctx context.Context, orderInfo * return err } + // Trigger user group recalculation (runs in background) + l.triggerUserGroupRecalculation(ctx, userInfo.Id) + // 7. 清理缓存(关键步骤:让节点获取最新订阅) l.clearServerCache(ctx, sub) diff --git a/queue/logic/subscription/checkSubscriptionLogic.go b/queue/logic/subscription/checkSubscriptionLogic.go index 81b86e7..02fbf01 100644 --- a/queue/logic/subscription/checkSubscriptionLogic.go +++ b/queue/logic/subscription/checkSubscriptionLogic.go @@ -62,7 +62,6 @@ func (l *CheckSubscriptionLogic) ProcessTask(ctx context.Context, _ *asynq.Task) } l.clearServerCache(ctx, list...) logger.Infow("[Check Subscription Traffic] Update subscribe status", logger.Field("user_ids", ids), logger.Field("count", int64(len(ids)))) - } else { logger.Info("[Check Subscription Traffic] No subscribe need to update") } @@ -108,6 +107,7 @@ func (l *CheckSubscriptionLogic) ProcessTask(ctx context.Context, _ *asynq.Task) } else { logger.Info("[Check Subscription Expire] No subscribe need to update") } + return nil }) if err != nil { diff --git a/queue/logic/traffic/trafficStatisticsLogic.go b/queue/logic/traffic/trafficStatisticsLogic.go index 37614cb..a89df57 100644 --- a/queue/logic/traffic/trafficStatisticsLogic.go +++ b/queue/logic/traffic/trafficStatisticsLogic.go @@ -98,11 +98,13 @@ func (l *TrafficStatisticsLogic) ProcessTask(ctx context.Context, task *asynq.Ta // update user subscribe with log d := int64(float32(log.Download) * ratio * realTimeMultiplier) u := int64(float32(log.Upload) * ratio * realTimeMultiplier) - if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, sub.Id, d, u); err != nil { + isExpired := now.After(sub.ExpireTime) + if err := l.svc.UserModel.UpdateUserSubscribeWithTraffic(ctx, sub.Id, d, u, isExpired); err != nil { logger.WithContext(ctx).Error("[TrafficStatistics] Update user subscribe with log failed", logger.Field("sid", log.SID), logger.Field("download", float32(log.Download)*ratio), logger.Field("upload", float32(log.Upload)*ratio), + logger.Field("is_expired", isExpired), logger.Field("error", err.Error()), ) continue