diff --git a/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go b/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go index a1dc2a4..58b75f4 100644 --- a/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go +++ b/internal/logic/auth/oauth/oAuthLoginGetTokenLogic.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" "github.com/perfect-panel/server/internal/config" @@ -393,10 +394,15 @@ func (l *OAuthLoginGetTokenLogic) register(email, avatar, method, openid, reques } } - if l.svcCtx.Config.Register.EnableTrial { + rc := l.svcCtx.Config.Register + // Only activate trial if whitelist is not enabled, or email domain matches whitelist + shouldActivateTrial := rc.EnableTrial && (!rc.EnableTrialEmailWhitelist || (email != "" && l.isEmailDomainWhitelisted(email, rc.TrialEmailDomainWhitelist))) + + if shouldActivateTrial { l.Debugw("activating trial subscription", logger.Field("request_id", requestID), logger.Field("user_id", userInfo.Id), + logger.Field("email", email), ) var trialErr error trialSubscribe, trialErr = l.activeTrial(userInfo.Id, requestID) @@ -882,3 +888,22 @@ func (l *OAuthLoginGetTokenLogic) activeTrial(uid int64, requestID string) (*use ) return userSub, nil } + +// isEmailDomainWhitelisted checks if the email's domain is in the comma-separated whitelist. +// Returns false if the email format is invalid. +func (l *OAuthLoginGetTokenLogic) isEmailDomainWhitelisted(email, whitelistCSV string) bool { + if whitelistCSV == "" { + return false + } + parts := strings.SplitN(email, "@", 2) + if len(parts) != 2 { + return false + } + domain := strings.ToLower(strings.TrimSpace(parts[1])) + for _, d := range strings.Split(whitelistCSV, ",") { + if strings.ToLower(strings.TrimSpace(d)) == domain { + return true + } + } + return false +}