Compare commits
2 Commits
79427c9f4c
...
6b64e8c461
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b64e8c461 | |||
| 47696b9e68 |
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/perfect-panel/server/internal/types"
|
"github.com/perfect-panel/server/internal/types"
|
||||||
"github.com/perfect-panel/server/pkg/jwt"
|
"github.com/perfect-panel/server/pkg/jwt"
|
||||||
"github.com/perfect-panel/server/pkg/logger"
|
"github.com/perfect-panel/server/pkg/logger"
|
||||||
|
"github.com/perfect-panel/server/pkg/tool"
|
||||||
"github.com/perfect-panel/server/pkg/uuidx"
|
"github.com/perfect-panel/server/pkg/uuidx"
|
||||||
"github.com/perfect-panel/server/pkg/xerr"
|
"github.com/perfect-panel/server/pkg/xerr"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -135,6 +136,8 @@ func (l *DeviceLoginLogic) DeviceLogin(req *types.DeviceLoginRequest) (resp *typ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l.tryGrantTrialForDeviceLogin(userInfo, req.Identifier)
|
||||||
|
|
||||||
// Generate session id
|
// Generate session id
|
||||||
sessionId := uuidx.NewUUID().String()
|
sessionId := uuidx.NewUUID().String()
|
||||||
|
|
||||||
@ -291,3 +294,108 @@ func (l *DeviceLoginLogic) registerUserAndDevice(req *types.DeviceLoginRequest)
|
|||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *DeviceLoginLogic) tryGrantTrialForDeviceLogin(userInfo *user.User, identifier string) {
|
||||||
|
if userInfo == nil || userInfo.Id == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !IsTrialConfigReady(l.svcCtx.Config.Register) {
|
||||||
|
l.Debugw("skip device trial grant because trial config is not ready",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("enable_trial", l.svcCtx.Config.Register.EnableTrial),
|
||||||
|
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
|
||||||
|
logger.Field("trial_time", l.svcCtx.Config.Register.TrialTime),
|
||||||
|
logger.Field("trial_time_unit", l.svcCtx.Config.Register.TrialTimeUnit),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userInfo.CreatedAt.IsZero() || time.Since(userInfo.CreatedAt) > 24*time.Hour {
|
||||||
|
l.Debugw("skip device trial grant because user is outside trial backfill window",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("user_created_at", userInfo.CreatedAt),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
if err := l.svcCtx.DB.WithContext(l.ctx).
|
||||||
|
Model(&user.Subscribe{}).
|
||||||
|
Where("user_id = ?", userInfo.Id).
|
||||||
|
Count(&count).Error; err != nil {
|
||||||
|
l.Errorw("failed to query existing subscriptions before device trial grant",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
l.Debugw("skip device trial grant because user already has subscriptions",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("subscription_count", count),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
trialSubscribe, err := l.activeTrial(userInfo.Id)
|
||||||
|
if err != nil {
|
||||||
|
l.Errorw("failed to activate trial subscription for device login",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("trial_subscribe_id", l.svcCtx.Config.Register.TrialSubscribe),
|
||||||
|
logger.Field("error", err.Error()),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if clearErr := l.svcCtx.UserModel.ClearSubscribeCache(l.ctx, trialSubscribe); clearErr != nil {
|
||||||
|
l.Errorw("ClearSubscribeCache failed",
|
||||||
|
logger.Field("error", clearErr.Error()),
|
||||||
|
logger.Field("userSubscribeId", trialSubscribe.Id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if clearErr := l.svcCtx.SubscribeModel.ClearCache(l.ctx, trialSubscribe.SubscribeId); clearErr != nil {
|
||||||
|
l.Errorw("ClearSubscribeCache failed",
|
||||||
|
logger.Field("error", clearErr.Error()),
|
||||||
|
logger.Field("subscribeId", trialSubscribe.SubscribeId),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if clearErr := l.svcCtx.NodeModel.ClearServerAllCache(l.ctx); clearErr != nil {
|
||||||
|
l.Errorf("ClearServerAllCache error: %v", clearErr.Error())
|
||||||
|
}
|
||||||
|
l.Infow("device trial subscription granted",
|
||||||
|
logger.Field("user_id", userInfo.Id),
|
||||||
|
logger.Field("identifier", identifier),
|
||||||
|
logger.Field("user_subscribe_id", trialSubscribe.Id),
|
||||||
|
logger.Field("trial_subscribe_id", trialSubscribe.SubscribeId),
|
||||||
|
logger.Field("expire_time", trialSubscribe.ExpireTime),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *DeviceLoginLogic) activeTrial(uid int64) (*user.Subscribe, error) {
|
||||||
|
sub, err := l.svcCtx.SubscribeModel.FindOne(l.ctx, l.svcCtx.Config.Register.TrialSubscribe)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
startTime := time.Now()
|
||||||
|
userSub := &user.Subscribe{
|
||||||
|
UserId: uid,
|
||||||
|
OrderId: 0,
|
||||||
|
SubscribeId: sub.Id,
|
||||||
|
StartTime: startTime,
|
||||||
|
ExpireTime: tool.AddTime(l.svcCtx.Config.Register.TrialTimeUnit, l.svcCtx.Config.Register.TrialTime, startTime),
|
||||||
|
Traffic: sub.Traffic,
|
||||||
|
Download: 0,
|
||||||
|
Upload: 0,
|
||||||
|
Token: uuidx.NewUUID().String(),
|
||||||
|
UUID: uuidx.NewUUID().String(),
|
||||||
|
Status: 1,
|
||||||
|
}
|
||||||
|
if err = l.svcCtx.UserModel.InsertSubscribe(l.ctx, userSub); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return userSub, nil
|
||||||
|
}
|
||||||
|
|||||||
@ -54,12 +54,19 @@ func ShouldGrantTrialForEmail(register config.RegisterConfig, email string) bool
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTrialConfigReady verifies that trial auto-grant has all required config.
|
||||||
|
func IsTrialConfigReady(register config.RegisterConfig) bool {
|
||||||
|
return register.EnableTrial &&
|
||||||
|
register.TrialSubscribe > 0 &&
|
||||||
|
register.TrialTime > 0 &&
|
||||||
|
strings.TrimSpace(register.TrialTimeUnit) != ""
|
||||||
|
}
|
||||||
|
|
||||||
// ShouldAutoGrantTrialOnPublicEmailFlows defines whether browser/email-originated
|
// ShouldAutoGrantTrialOnPublicEmailFlows defines whether browser/email-originated
|
||||||
// flows may auto-create a trial subscription. The current policy disables trial
|
// flows may auto-create a trial subscription. Email-specific abuse protection
|
||||||
// creation for email registration, email login auto-register, OAuth-with-email,
|
// is still handled by ShouldGrantTrialForEmail and NormalizedEmailHasTrial.
|
||||||
// and email binding/verification to avoid abuse through public email channels.
|
|
||||||
func ShouldAutoGrantTrialOnPublicEmailFlows(register config.RegisterConfig) bool {
|
func ShouldAutoGrantTrialOnPublicEmailFlows(register config.RegisterConfig) bool {
|
||||||
return false
|
return IsTrialConfigReady(register)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDisposableAlias detects Gmail dot trick and + alias abuse.
|
// IsDisposableAlias detects Gmail dot trick and + alias abuse.
|
||||||
|
|||||||
@ -304,20 +304,17 @@ func (l *ActivateOrderLogic) reconcilePostOrderSubscriptions(ctx context.Context
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
maxExpire := survivor.ExpireTime
|
now := time.Now()
|
||||||
|
accumulatedExpire := now
|
||||||
for i := range ownerSubs {
|
for i := range ownerSubs {
|
||||||
item := ownerSubs[i]
|
item := ownerSubs[i]
|
||||||
if item.Id == survivor.Id {
|
if (item.Id == survivor.Id || orderMergeRemainingTimeStatus(item.Status)) && item.ExpireTime.After(now) {
|
||||||
if item.ExpireTime.After(maxExpire) {
|
accumulatedExpire = accumulatedExpire.Add(item.ExpireTime.Sub(now))
|
||||||
maxExpire = item.ExpireTime
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
losers = append(losers, item)
|
if item.Id != survivor.Id {
|
||||||
mergedIDs = append(mergedIDs, item.Id)
|
losers = append(losers, item)
|
||||||
if item.ExpireTime.After(maxExpire) {
|
mergedIDs = append(mergedIDs, item.Id)
|
||||||
maxExpire = item.ExpireTime
|
|
||||||
}
|
}
|
||||||
if item.SubscribeId > 0 {
|
if item.SubscribeId > 0 {
|
||||||
subscribeIDsToClear[item.SubscribeId] = struct{}{}
|
subscribeIDsToClear[item.SubscribeId] = struct{}{}
|
||||||
@ -341,9 +338,9 @@ func (l *ActivateOrderLogic) reconcilePostOrderSubscriptions(ctx context.Context
|
|||||||
"status": 1,
|
"status": 1,
|
||||||
"finished_at": nil,
|
"finished_at": nil,
|
||||||
}
|
}
|
||||||
if maxExpire.After(survivor.ExpireTime) {
|
if accumulatedExpire.After(survivor.ExpireTime) {
|
||||||
survivor.ExpireTime = maxExpire
|
survivor.ExpireTime = accumulatedExpire
|
||||||
updateFields["expire_time"] = maxExpire
|
updateFields["expire_time"] = accumulatedExpire
|
||||||
}
|
}
|
||||||
if identitySource != nil {
|
if identitySource != nil {
|
||||||
if identitySource.Token != "" {
|
if identitySource.Token != "" {
|
||||||
@ -441,6 +438,15 @@ func shouldReconcilePostOrderSubscriptions(orderInfo *order.Order) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func orderMergeRemainingTimeStatus(status uint8) bool {
|
||||||
|
switch status {
|
||||||
|
case 0, 1, 2:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func pickSubscriptionIdentitySource(candidates []user.Subscribe) *user.Subscribe {
|
func pickSubscriptionIdentitySource(candidates []user.Subscribe) *user.Subscribe {
|
||||||
if len(candidates) == 0 {
|
if len(candidates) == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -1434,6 +1440,7 @@ func (l *ActivateOrderLogic) updateSubscriptionWithIAPExpire(ctx context.Context
|
|||||||
userSub.FinishedAt = nil
|
userSub.FinishedAt = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userSub.OrderId = orderInfo.Id
|
||||||
userSub.ExpireTime = newExpire
|
userSub.ExpireTime = newExpire
|
||||||
userSub.Status = 1
|
userSub.Status = 1
|
||||||
|
|
||||||
|
|||||||
249
scripts/test_device_trial_registration.go
Normal file
249
scripts/test_device_trial_registration.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
//go:build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/perfect-panel/server/initialize"
|
||||||
|
"github.com/perfect-panel/server/internal/config"
|
||||||
|
authlogic "github.com/perfect-panel/server/internal/logic/auth"
|
||||||
|
modelAuth "github.com/perfect-panel/server/internal/model/auth"
|
||||||
|
modelLog "github.com/perfect-panel/server/internal/model/log"
|
||||||
|
modelNode "github.com/perfect-panel/server/internal/model/node"
|
||||||
|
modelSubscribe "github.com/perfect-panel/server/internal/model/subscribe"
|
||||||
|
modelSystem "github.com/perfect-panel/server/internal/model/system"
|
||||||
|
modelUser "github.com/perfect-panel/server/internal/model/user"
|
||||||
|
"github.com/perfect-panel/server/internal/svc"
|
||||||
|
"github.com/perfect-panel/server/internal/types"
|
||||||
|
"github.com/perfect-panel/server/pkg/conf"
|
||||||
|
"github.com/perfect-panel/server/pkg/orm"
|
||||||
|
"github.com/perfect-panel/server/pkg/tool"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var (
|
||||||
|
configPath = flag.String("config", "etc/ppanel.yaml", "config file path on the test server")
|
||||||
|
dsn = flag.String("dsn", "", "optional MySQL DSN override")
|
||||||
|
identifier = flag.String("identifier", "", "optional device identifier; defaults to a unique test identifier")
|
||||||
|
ip = flag.String("ip", "", "optional request IP; defaults to a reserved test IP")
|
||||||
|
userAgent = flag.String("user-agent", "CodexDeviceTrialTest/1.0", "device user agent")
|
||||||
|
write = flag.Bool("write", false, "actually create a test device user by running DeviceLogin")
|
||||||
|
cleanup = flag.Bool("cleanup", false, "delete the test user/device/subscription/log rows after verification")
|
||||||
|
)
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if !*write {
|
||||||
|
fmt.Println("Refusing to write DB without -write.")
|
||||||
|
fmt.Println("Example:")
|
||||||
|
fmt.Printf(" go run scripts/test_device_trial_registration.go -config %s -write\n", *configPath)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := loadConfig(*configPath, *dsn)
|
||||||
|
env := mustNewDeviceTrialEnv(ctx, cfg)
|
||||||
|
defer env.close()
|
||||||
|
|
||||||
|
initialize.Device(env.svcCtx)
|
||||||
|
initialize.Register(env.svcCtx)
|
||||||
|
|
||||||
|
if *identifier == "" {
|
||||||
|
*identifier = fmt.Sprintf("codex-device-trial-%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
if *ip == "" {
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
*ip = fmt.Sprintf("198.18.%d.%d", now%200+1, now/200%200+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("== device trial registration test ==")
|
||||||
|
fmt.Printf("mysql: %s/%s\n", env.cfg.MySQL.Addr, env.cfg.MySQL.Dbname)
|
||||||
|
fmt.Printf("redis: %s db=%d\n", env.cfg.Redis.Host, env.cfg.Redis.DB)
|
||||||
|
fmt.Printf("device.enable=%v\n", env.svcCtx.Config.Device.Enable)
|
||||||
|
fmt.Printf("register.enable_trial=%v trial_subscribe=%d trial_time=%d trial_time_unit=%s\n",
|
||||||
|
env.svcCtx.Config.Register.EnableTrial,
|
||||||
|
env.svcCtx.Config.Register.TrialSubscribe,
|
||||||
|
env.svcCtx.Config.Register.TrialTime,
|
||||||
|
env.svcCtx.Config.Register.TrialTimeUnit,
|
||||||
|
)
|
||||||
|
fmt.Printf("identifier=%s ip=%s user_agent=%s\n", *identifier, *ip, *userAgent)
|
||||||
|
|
||||||
|
if err := ensureIdentifierUnused(ctx, env.db, *identifier); err != nil {
|
||||||
|
fail(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logic := authlogic.NewDeviceLoginLogic(ctx, env.svcCtx)
|
||||||
|
resp, err := logic.DeviceLogin(&types.DeviceLoginRequest{
|
||||||
|
Identifier: *identifier,
|
||||||
|
IP: *ip,
|
||||||
|
UserAgent: *userAgent,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
fail(fmt.Errorf("DeviceLogin failed: %w", err))
|
||||||
|
}
|
||||||
|
if resp == nil || strings.TrimSpace(resp.Token) == "" {
|
||||||
|
fail(fmt.Errorf("DeviceLogin returned empty token"))
|
||||||
|
}
|
||||||
|
fmt.Printf("login token: ok len=%d\n", len(resp.Token))
|
||||||
|
|
||||||
|
device, err := env.svcCtx.UserModel.FindOneDeviceByIdentifier(ctx, *identifier)
|
||||||
|
if err != nil {
|
||||||
|
fail(fmt.Errorf("query created device failed: %w", err))
|
||||||
|
}
|
||||||
|
fmt.Printf("device: id=%d sn=%s user_id=%d created_at=%s\n",
|
||||||
|
device.Id,
|
||||||
|
tool.DeviceIdToHash(device.Id),
|
||||||
|
device.UserId,
|
||||||
|
device.CreatedAt.Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
|
||||||
|
var subs []modelUser.Subscribe
|
||||||
|
if err = env.db.WithContext(ctx).
|
||||||
|
Where("user_id = ?", device.UserId).
|
||||||
|
Order("id ASC").
|
||||||
|
Find(&subs).Error; err != nil {
|
||||||
|
fail(fmt.Errorf("query user_subscribe failed: %w", err))
|
||||||
|
}
|
||||||
|
if len(subs) == 0 {
|
||||||
|
fail(fmt.Errorf("FAIL: no user_subscribe rows created for user_id=%d", device.UserId))
|
||||||
|
}
|
||||||
|
|
||||||
|
var trial *modelUser.Subscribe
|
||||||
|
for i := range subs {
|
||||||
|
sub := &subs[i]
|
||||||
|
fmt.Printf("subscribe: id=%d order_id=%d subscribe_id=%d status=%d start=%s expire=%s token_empty=%v\n",
|
||||||
|
sub.Id,
|
||||||
|
sub.OrderId,
|
||||||
|
sub.SubscribeId,
|
||||||
|
sub.Status,
|
||||||
|
sub.StartTime.Format(time.RFC3339),
|
||||||
|
sub.ExpireTime.Format(time.RFC3339),
|
||||||
|
sub.Token == "",
|
||||||
|
)
|
||||||
|
if sub.OrderId == 0 &&
|
||||||
|
sub.SubscribeId == env.svcCtx.Config.Register.TrialSubscribe &&
|
||||||
|
(sub.Status == 0 || sub.Status == 1) &&
|
||||||
|
sub.ExpireTime.After(time.Now()) {
|
||||||
|
trial = sub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if trial == nil {
|
||||||
|
fail(fmt.Errorf("FAIL: trial subscription was not granted for user_id=%d", device.UserId))
|
||||||
|
}
|
||||||
|
fmt.Printf("PASS: trial granted user_subscribe_id=%d expire_time=%s\n",
|
||||||
|
trial.Id,
|
||||||
|
trial.ExpireTime.Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
|
||||||
|
if *cleanup {
|
||||||
|
if err = cleanupTestRows(ctx, env.db, device.UserId); err != nil {
|
||||||
|
fail(fmt.Errorf("cleanup failed: %w", err))
|
||||||
|
}
|
||||||
|
fmt.Printf("cleanup: deleted test rows for user_id=%d\n", device.UserId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type deviceTrialEnv struct {
|
||||||
|
db *gorm.DB
|
||||||
|
rds *redis.Client
|
||||||
|
cfg config.Config
|
||||||
|
svcCtx *svc.ServiceContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustNewDeviceTrialEnv(ctx context.Context, cfg config.Config) *deviceTrialEnv {
|
||||||
|
db, err := orm.ConnectMysql(orm.Mysql{Config: cfg.MySQL})
|
||||||
|
must(err)
|
||||||
|
|
||||||
|
rds := redis.NewClient(&redis.Options{
|
||||||
|
Addr: cfg.Redis.Host,
|
||||||
|
Password: cfg.Redis.Pass,
|
||||||
|
DB: cfg.Redis.DB,
|
||||||
|
PoolSize: cfg.Redis.PoolSize,
|
||||||
|
MinIdleConns: cfg.Redis.MinIdleConns,
|
||||||
|
})
|
||||||
|
must(rds.Ping(ctx).Err())
|
||||||
|
|
||||||
|
svcCtx := &svc.ServiceContext{
|
||||||
|
DB: db,
|
||||||
|
Redis: rds,
|
||||||
|
Config: cfg,
|
||||||
|
AuthModel: modelAuth.NewModel(db, rds),
|
||||||
|
LogModel: modelLog.NewModel(db),
|
||||||
|
NodeModel: modelNode.NewModel(db, rds),
|
||||||
|
SystemModel: modelSystem.NewModel(db, rds),
|
||||||
|
UserModel: modelUser.NewModel(db, rds),
|
||||||
|
SubscribeModel: modelSubscribe.NewModel(db, rds),
|
||||||
|
}
|
||||||
|
return &deviceTrialEnv{db: db, rds: rds, cfg: cfg, svcCtx: svcCtx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *deviceTrialEnv) close() {
|
||||||
|
if e == nil || e.rds == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = e.rds.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadConfig(path, dsn string) config.Config {
|
||||||
|
var cfg config.Config
|
||||||
|
conf.MustLoad(path, &cfg)
|
||||||
|
if dsn != "" {
|
||||||
|
parsed := orm.ParseDSN(dsn)
|
||||||
|
if parsed == nil {
|
||||||
|
fail(fmt.Errorf("invalid dsn"))
|
||||||
|
}
|
||||||
|
cfg.MySQL = *parsed
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureIdentifierUnused(ctx context.Context, db *gorm.DB, identifier string) error {
|
||||||
|
var count int64
|
||||||
|
if err := db.WithContext(ctx).
|
||||||
|
Model(&modelUser.Device{}).
|
||||||
|
Where("identifier = ?", identifier).
|
||||||
|
Count(&count).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
return fmt.Errorf("identifier already exists: %s", identifier)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupTestRows(ctx context.Context, db *gorm.DB, userID int64) error {
|
||||||
|
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Where("object_id = ?", userID).Delete(&modelLog.SystemLog{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Subscribe{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.AuthMethods{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("user_id = ?", userID).Delete(&modelUser.Device{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Where("id = ?", userID).Delete(&modelUser.User{}).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func must(err error) {
|
||||||
|
if err != nil {
|
||||||
|
fail(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fail(err error) {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user