diff --git a/internal/config/config.go b/internal/config/config.go index d742c3f..476cdda 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -90,7 +90,9 @@ type RegisterConfig struct { IpRegisterLimit int64 `yaml:"IpRegisterLimit" default:"0"` IpRegisterLimitDuration int64 `yaml:"IpRegisterLimitDuration" default:"0"` EnableIpRegisterLimit bool `yaml:"EnableIpRegisterLimit" default:"false"` - DeviceLimit int64 `yaml:"DeviceLimit" default:"2"` + DeviceLimit int64 `yaml:"DeviceLimit" default:"2"` + EnableTrialEmailWhitelist bool `yaml:"EnableTrialEmailWhitelist" default:"false"` + TrialEmailDomainWhitelist string `yaml:"TrialEmailDomainWhitelist" default:""` } type EmailConfig struct { diff --git a/internal/logic/auth/deviceLoginLogic.go b/internal/logic/auth/deviceLoginLogic.go index 5ccce8a..e0a0d0e 100644 --- a/internal/logic/auth/deviceLoginLogic.go +++ b/internal/logic/auth/deviceLoginLogic.go @@ -12,7 +12,6 @@ import ( "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" - "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/uuidx" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" @@ -180,7 +179,6 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest) ) var userInfo *user.User - var trialSubscribe *user.Subscribe err := l.svcCtx.UserModel.Transaction(l.ctx, func(db *gorm.DB) error { // Create new user userInfo = &user.User{ @@ -239,15 +237,6 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest) return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "insert device failed: %v", err) } - // Activate trial if enabled - if l.svcCtx.Config.Register.EnableTrial { - var trialErr error - trialSubscribe, trialErr = l.activeTrial(userInfo.Id, db) - if trialErr != nil { - return trialErr - } - } - return nil }) @@ -259,25 +248,6 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest) return nil, err } - // Clear cache after transaction success - if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil { - // Clear user subscription cache - if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil { - l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id)) - // Don't return error, just log it - } - // Clear subscription cache - if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil { - l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId)) - // Don't return error, just log it - } - // Clear all server cache - if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil { - l.Errorf("ClearServerAllCache error: %v", err.Error()) - // Don't return error, just log it - } - } - l.Infow("device registration completed successfully", logger.Field("user_id", userInfo.Id), logger.Field("identifier", req.Identifier), @@ -309,51 +279,3 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest) return userInfo, nil } - -func (l *DeviceLoginLogic) activeTrial(userId int64, db *gorm.DB) (*user.Subscribe, error) { - sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe) - if err != nil { - l.Errorw("failed to find trial subscription template", - logger.Field("user_id", userId), - logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe), - logger.Field("error", err.Error()), - ) - return nil, err - } - - startTime := time.Now() - expireTime := tool.AddTime(l.svcCtx.Config.Register.TrialTimeUnit, l.svcCtx.Config.Register.TrialTime, startTime) - subscribeToken := uuidx.NewUUID().String() - subscribeUUID := uuidx.NewUUID().String() - - userSub := &user.Subscribe{ - UserId: userId, - OrderId: 0, - SubscribeId: sub.Id, - StartTime: startTime, - ExpireTime: expireTime, - Traffic: sub.Traffic, - Download: 0, - Upload: 0, - Token: subscribeToken, - UUID: subscribeUUID, - Status: 1, - } - - if err := db.Create(userSub).Error; err != nil { - l.Errorw("failed to insert trial subscription", - logger.Field("user_id", userId), - logger.Field("error", err.Error()), - ) - return nil, err - } - - l.Infow("trial subscription activated successfully", - logger.Field("user_id", userId), - logger.Field("subscribe_id", sub.Id), - logger.Field("expire_time", expireTime), - logger.Field("traffic", sub.Traffic), - ) - - return userSub, nil -} diff --git a/internal/logic/auth/emailLoginLogic.go b/internal/logic/auth/emailLoginLogic.go index 9701a79..e4fb016 100644 --- a/internal/logic/auth/emailLoginLogic.go +++ b/internal/logic/auth/emailLoginLogic.go @@ -125,7 +125,8 @@ func (l *EmailLoginLogic) EmailLogin(req *types.EmailLoginRequest) (resp *types. if err = db.Create(authInfo).Error; err != nil { return err } - if l.svcCtx.Config.Register.EnableTrial { + rc := l.svcCtx.Config.Register + if rc.EnableTrial && (!rc.EnableTrialEmailWhitelist || IsEmailDomainWhitelisted(req.Email, rc.TrialEmailDomainWhitelist)) { if err = l.activeTrial(userInfo.Id); err != nil { return err } diff --git a/internal/logic/auth/telephoneUserRegisterLogic.go b/internal/logic/auth/telephoneUserRegisterLogic.go index b40aca6..10b3033 100644 --- a/internal/logic/auth/telephoneUserRegisterLogic.go +++ b/internal/logic/auth/telephoneUserRegisterLogic.go @@ -46,7 +46,6 @@ func NewTelephoneUserRegisterLogic(ctx context.Context, svcCtx *svc.ServiceConte func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneRegisterRequest) (resp *types.LoginResponse, err error) { c := l.svcCtx.Config.Register - var trialSubscribe *user.Subscribe // Check if the registration is stopped if c.StopRegister { return nil, errors.Wrapf(xerr.NewErrCode(xerr.StopRegister), "stop register") @@ -141,39 +140,12 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR if err := db.Model(&user.User{}).Where("id = ?", userInfo.Id).Update("refer_code", userInfo.ReferCode).Error; err != nil { return err } - if l.svcCtx.Config.Register.EnableTrial { - // Active trial - var trialErr error - trialSubscribe, trialErr = l.activeTrial(userInfo.Id) - if trialErr != nil { - return trialErr - } - } return nil }) if err != nil { return nil, err } - // Clear cache after transaction success - if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil { - // Clear user subscription cache - if err = l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); err != nil { - l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("userSubscribeId", trialSubscribe.Id)) - // Don't return error, just log it - } - // Clear subscription cache - if err = l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); err != nil { - l.Errorw("ClearSubscribeCache failed", logger.Field("error", err.Error()), logger.Field("subscribeId", trialSubscribe.SubscribeId)) - // Don't return error, just log it - } - // Clear all server cache - if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil { - l.Errorf("ClearServerAllCache error: %v", err.Error()) - // Don't return error, just log it - } - } - // Bind device to user if identifier is provided if req.Identifier != "" { bindLogic := NewBindDeviceLogic(l.ctx, l.svcCtx) @@ -261,32 +233,6 @@ func (l *TelephoneUserRegisterLogic) TelephoneUserRegister(req *types.TelephoneR }, nil } -func (l *TelephoneUserRegisterLogic) activeTrial(uid int64) (*user.Subscribe, error) { - sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe) - if err != nil { - return nil, err - } - userSub := &user.Subscribe{ - Id: 0, - UserId: uid, - OrderId: 0, - SubscribeId: sub.Id, - StartTime: time.Now(), - ExpireTime: tool.AddTime(l.svcCtx.Config.Register.TrialTimeUnit, l.svcCtx.Config.Register.TrialTime, time.Now()), - Traffic: sub.Traffic, - Download: 0, - Upload: 0, - Token: uuidx.NewUUID().String(), - UUID: uuidx.NewUUID().String(), - Status: 1, - } - err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub) - if err != nil { - return nil, err - } - return userSub, nil -} - func (l *TelephoneUserRegisterLogic) verifyCaptcha(req *types.TelephoneRegisterRequest) error { verifyCfg, err := l.svcCtx.SystemModel.GetVerifyConfig(l.ctx) if err != nil { diff --git a/internal/logic/auth/trialEmailWhitelist.go b/internal/logic/auth/trialEmailWhitelist.go new file mode 100644 index 0000000..4decd98 --- /dev/null +++ b/internal/logic/auth/trialEmailWhitelist.go @@ -0,0 +1,22 @@ +package auth + +import "strings" + +// IsEmailDomainWhitelisted checks if the email's domain is in the comma-separated whitelist. +// Returns false if the email format is invalid. +func IsEmailDomainWhitelisted(email, whitelistCSV string) bool { + if whitelistCSV == "" { + return false + } + parts := strings.SplitN(email, "@", 2) + if len(parts) != 2 { + return false + } + domain := strings.ToLower(strings.TrimSpace(parts[1])) + for _, d := range strings.Split(whitelistCSV, ",") { + if strings.ToLower(strings.TrimSpace(d)) == domain { + return true + } + } + return false +} diff --git a/internal/logic/auth/userRegisterLogic.go b/internal/logic/auth/userRegisterLogic.go index 5aa7ce7..710dff1 100644 --- a/internal/logic/auth/userRegisterLogic.go +++ b/internal/logic/auth/userRegisterLogic.go @@ -147,7 +147,8 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } // Activate trial subscription after transaction success (moved outside transaction to reduce lock time) - if l.svcCtx.Config.Register.EnableTrial { + rc := l.svcCtx.Config.Register + if rc.EnableTrial && (!rc.EnableTrialEmailWhitelist || IsEmailDomainWhitelisted(req.Email, rc.TrialEmailDomainWhitelist)) { trialSubscribe, err = l.activeTrial(userInfo.Id) if err != nil { l.Errorw("Failed to activate trial subscription", logger.Field("error", err.Error())) @@ -156,7 +157,7 @@ func (l *UserRegisterLogic) UserRegister(req *types.UserRegisterRequest) (resp * } // Clear cache after transaction success - if l.svcCtx.Config.Register.EnableTrial && trialSubscribe != nil { + if trialSubscribe != nil { // Trigger user group recalculation (runs in background) go func() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/internal/logic/public/user/bindEmailWithVerificationLogic.go b/internal/logic/public/user/bindEmailWithVerificationLogic.go index e4ef91b..6578124 100644 --- a/internal/logic/public/user/bindEmailWithVerificationLogic.go +++ b/internal/logic/public/user/bindEmailWithVerificationLogic.go @@ -8,12 +8,14 @@ import ( "time" "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/logic/auth" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/jwt" "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/tool" "github.com/perfect-panel/server/pkg/uuidx" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" @@ -127,6 +129,8 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi if err != nil { return nil, err } + // Grant trial subscription if email domain is whitelisted + l.tryGrantTrialOnEmailBind(emailUser.Id, req.Email) return &types.BindEmailWithVerificationResponse{ Success: true, Message: "email user created and joined family", @@ -154,6 +158,9 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi return nil, err } + // Grant trial subscription if email domain is whitelisted + l.tryGrantTrialOnEmailBind(existingMethod.UserId, req.Email) + return &types.BindEmailWithVerificationResponse{ Success: true, Message: "joined family successfully", @@ -200,3 +207,74 @@ func (l *BindEmailWithVerificationLogic) refreshBindSessionToken(userId int64) ( return token, nil } + +// tryGrantTrialOnEmailBind grants trial subscription to the email user (family owner) +// if EnableTrialEmailWhitelist is on and the email domain matches. +func (l *BindEmailWithVerificationLogic) tryGrantTrialOnEmailBind(ownerUserId int64, email string) { + rc := l.svcCtx.Config.Register + if !rc.EnableTrial || !rc.EnableTrialEmailWhitelist { + return + } + if !auth.IsEmailDomainWhitelisted(email, rc.TrialEmailDomainWhitelist) { + l.Infow("email domain not in trial whitelist, skip", + logger.Field("email", email), + logger.Field("owner_user_id", ownerUserId), + ) + return + } + + // Anti-duplicate: check if owner already has trial subscription + var count int64 + if err := l.svcCtx.DB.WithContext(l.ctx). + Model(&user.Subscribe{}). + Where("user_id = ? AND subscribe_id = ?", ownerUserId, rc.TrialSubscribe). + Count(&count).Error; err != nil { + l.Errorw("failed to check existing trial", logger.Field("error", err.Error())) + return + } + if count > 0 { + l.Infow("trial already granted, skip", + logger.Field("owner_user_id", ownerUserId), + ) + return + } + + sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, rc.TrialSubscribe) + if err != nil { + l.Errorw("failed to find trial subscribe template", logger.Field("error", err.Error())) + return + } + + userSub := &user.Subscribe{ + UserId: ownerUserId, + OrderId: 0, + SubscribeId: sub.Id, + StartTime: time.Now(), + ExpireTime: tool.AddTime(rc.TrialTimeUnit, rc.TrialTime, time.Now()), + Traffic: sub.Traffic, + Download: 0, + Upload: 0, + Token: uuidx.NewUUID().String(), + UUID: uuidx.NewUUID().String(), + Status: 1, + } + if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub); err != nil { + l.Errorw("failed to insert trial subscribe", + logger.Field("error", err.Error()), + logger.Field("owner_user_id", ownerUserId), + ) + return + } + + // InsertSubscribe auto-clears user subscribe cache via execSubscribeMutation. + // Clear server cache so nodes pick up the new subscription. + if err = l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); err != nil { + l.Errorw("ClearServerAllCache error", logger.Field("error", err.Error())) + } + + l.Infow("trial granted on email bind", + logger.Field("owner_user_id", ownerUserId), + logger.Field("email", email), + logger.Field("subscribe_id", sub.Id), + ) +}