diff --git a/internal/config/config.go b/internal/config/config.go index 07e1e50..ad7c3ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,9 +37,18 @@ type Config struct { } type RedisConfig struct { - Host string `yaml:"Host" default:"localhost:6379"` - Pass string `yaml:"Pass" default:""` - DB int `yaml:"DB" default:"0"` + Host string `yaml:"Host" default:"localhost:6379"` + Pass string `yaml:"Pass" default:""` + DB int `yaml:"DB" default:"0"` + PoolSize int `yaml:"PoolSize" default:"100"` // 连接池大小(最大连接数) + MinIdleConns int `yaml:"MinIdleConns" default:"10"` // 最小空闲连接数 + MaxRetries int `yaml:"MaxRetries" default:"3"` // 最大重试次数 + PoolTimeout int `yaml:"PoolTimeout" default:"4"` // 连接池超时时间(秒) + IdleTimeout int `yaml:"IdleTimeout" default:"300"` // 空闲连接超时时间(秒) + MaxConnAge int `yaml:"MaxConnAge" default:"0"` // 连接最大生命周期(秒),0表示不限制 + DialTimeout int `yaml:"DialTimeout" default:"5"` // 连接超时时间(秒) + ReadTimeout int `yaml:"ReadTimeout" default:"3"` // 读超时时间(秒) + WriteTimeout int `yaml:"WriteTimeout" default:"3"` // 写超时时间(秒) } type JwtAuth struct { diff --git a/internal/logic/auth/registerLimitLogic.go b/internal/logic/auth/registerLimitLogic.go index 40ce3ac..048ef75 100644 --- a/internal/logic/auth/registerLimitLogic.go +++ b/internal/logic/auth/registerLimitLogic.go @@ -7,6 +7,7 @@ import ( "github.com/perfect-panel/server/internal/config" "github.com/perfect-panel/server/internal/svc" + "github.com/redis/go-redis/v9" "go.uber.org/zap" ) @@ -14,32 +15,51 @@ func registerIpLimit(svcCtx *svc.ServiceContext, ctx context.Context, registerIp if !svcCtx.Config.Register.EnableIpRegisterLimit { return true } - cacheKey := fmt.Sprintf("%s%s:*", config.RegisterIpKeyPrefix, registerIp) - var cacheKeys []string - var cursor uint64 - for { - keys, newCursor, err := svcCtx.Redis.Scan(ctx, 0, cacheKey, 100).Result() - if err != nil { - zap.S().Errorf("[registerIpLimit] Err: %v", err) - return true - } - if len(keys) > 0 { - cacheKeys = append(cacheKeys, keys...) - } - cursor = newCursor - if cursor == 0 { - break - } - } - defer func() { - key := fmt.Sprintf("%s%s:%s:%s", config.RegisterIpKeyPrefix, registerIp, authType, account) - if err := svcCtx.Redis.Set(ctx, key, account, time.Minute*time.Duration(svcCtx.Config.Register.IpRegisterLimitDuration)).Err(); err != nil { - zap.S().Errorf("[registerIpLimit] Set Err: %v", err) - } - }() - if len(cacheKeys) < int(svcCtx.Config.Register.IpRegisterLimit) { + // Use a sorted set to track IP registrations with timestamp as score + // Key format: register:ip:{ip} + key := fmt.Sprintf("%s%s", config.RegisterIpKeyPrefix, registerIp) + now := time.Now().Unix() + expiration := int64(svcCtx.Config.Register.IpRegisterLimitDuration) * 60 + + // Clean up expired entries first (remove entries older than expiration duration) + expireTimestamp := now - expiration + removed, err := svcCtx.Redis.ZRemRangeByScore(ctx, key, "0", fmt.Sprintf("%d", expireTimestamp)).Result() + if err != nil { + zap.S().Errorf("[registerIpLimit] ZRemRangeByScore Err: %v", err) return true } - return false + if removed > 0 { + zap.S().Debugf("[registerIpLimit] Cleaned %d expired entries for IP: %s", removed, registerIp) + } + + // Get current count of registrations within the time window + count, err := svcCtx.Redis.ZCard(ctx, key).Result() + if err != nil { + zap.S().Errorf("[registerIpLimit] ZCard Err: %v", err) + return true + } + + // Check if limit is reached + if count >= svcCtx.Config.Register.IpRegisterLimit { + zap.S().Warnf("[registerIpLimit] IP %s exceeded limit: %d/%d", registerIp, count, svcCtx.Config.Register.IpRegisterLimit) + return false + } + + // Add new registration entry with current timestamp as score + member := fmt.Sprintf("%s:%s", authType, account) + if err := svcCtx.Redis.ZAdd(ctx, key, redis.Z{ + Score: float64(now), + Member: member, + }).Err(); err != nil { + zap.S().Errorf("[registerIpLimit] ZAdd Err: %v", err) + return true + } + + // Set expiration on the sorted set key + if err := svcCtx.Redis.Expire(ctx, key, time.Minute*time.Duration(svcCtx.Config.Register.IpRegisterLimitDuration)).Err(); err != nil { + zap.S().Errorf("[registerIpLimit] Expire Err: %v", err) + } + + return true } diff --git a/internal/svc/serviceContext.go b/internal/svc/serviceContext.go index 5650b4b..b78bea6 100644 --- a/internal/svc/serviceContext.go +++ b/internal/svc/serviceContext.go @@ -2,6 +2,7 @@ package svc import ( "context" + "time" "github.com/perfect-panel/server/internal/model/client" "github.com/perfect-panel/server/internal/model/node" @@ -84,9 +85,18 @@ func NewServiceContext(c config.Config) *ServiceContext { } rds := redis.NewClient(&redis.Options{ - Addr: c.Redis.Host, - Password: c.Redis.Pass, - DB: c.Redis.DB, + Addr: c.Redis.Host, + Password: c.Redis.Pass, + DB: c.Redis.DB, + PoolSize: c.Redis.PoolSize, // 连接池大小:根据应用并发量调整,建议 100-500 + MinIdleConns: c.Redis.MinIdleConns, // 最小空闲连接:保持一定数量的空闲连接,减少建立连接的开销 + MaxRetries: c.Redis.MaxRetries, // 最大重试次数:网络抖动时自动重试 + PoolTimeout: time.Second * time.Duration(c.Redis.PoolTimeout), // 从连接池获取连接的超时时间 + ConnMaxIdleTime: time.Second * time.Duration(c.Redis.IdleTimeout), // 空闲连接的超时时间,自动回收长时间空闲的连接 + ConnMaxLifetime: time.Second * time.Duration(c.Redis.MaxConnAge), // 连接的最大生命周期,定期重建连接避免长时间使用的问题 + DialTimeout: time.Second * time.Duration(c.Redis.DialTimeout), // 建立新连接的超时时间 + ReadTimeout: time.Second * time.Duration(c.Redis.ReadTimeout), // 读操作超时时间 + WriteTimeout: time.Second * time.Duration(c.Redis.WriteTimeout), // 写操作超时时间 }) err = rds.Ping(context.Background()).Err() if err != nil {