diff --git a/internal/logic/public/user/bindEmailWithVerificationLogic.go b/internal/logic/public/user/bindEmailWithVerificationLogic.go index 3abce2a..0702631 100644 --- a/internal/logic/public/user/bindEmailWithVerificationLogic.go +++ b/internal/logic/public/user/bindEmailWithVerificationLogic.go @@ -135,16 +135,18 @@ func (l *BindEmailWithVerificationLogic) BindEmailWithVerification(req *types.Bi logger.Field("email", req.Email), ) - // Join family: email user as owner, device user as member + // Join family: email user as owner, device user as member. + // For a newly-created email owner, preserve the device user's paid entitlement by moving + // its subscriptions to the email owner and issuing owner-side subscription tokens. if err = familyHelper.validateJoinFamily(emailUser.Id, u.Id); err != nil { return nil, err } - joinResult, err := familyHelper.joinFamily(emailUser.Id, u.Id, "bind_email_with_verification") + joinResult, err := familyHelper.joinFamilyAndMoveSubscribesToOwner(emailUser.Id, u.Id, "bind_email_with_verification") if err != nil { return nil, err } commonLogic.SubscriptionTraceInfo(l.Logger, commonLogic.SubscriptionTraceFlowEmailBind, "family_joined", - "[SubscriptionFlow] device user joined email owner family", + "[SubscriptionFlow] device user joined email owner family and subscriptions moved to owner with reset tokens", logger.Field("device_user_id", u.Id), logger.Field("owner_user_id", emailUser.Id), logger.Field("family_id", joinResult.FamilyId), diff --git a/internal/logic/public/user/familyBindingHelper.go b/internal/logic/public/user/familyBindingHelper.go index d4962e2..b768750 100644 --- a/internal/logic/public/user/familyBindingHelper.go +++ b/internal/logic/public/user/familyBindingHelper.go @@ -2,10 +2,13 @@ package user import ( "context" + "fmt" "time" + modelOrder "github.com/perfect-panel/server/internal/model/order" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/uuidx" "github.com/perfect-panel/server/pkg/xerr" "github.com/pkg/errors" "gorm.io/gorm" @@ -22,6 +25,13 @@ type familyBindingHelper struct { svcCtx *svc.ServiceContext } +type familySubscribeTransferMode uint8 + +const ( + familySubscribeDiscard familySubscribeTransferMode = iota + familySubscribeMoveToOwner +) + func newFamilyBindingHelper(ctx context.Context, svcCtx *svc.ServiceContext) *familyBindingHelper { return &familyBindingHelper{ ctx: ctx, @@ -111,6 +121,14 @@ func (h *familyBindingHelper) validateJoinFamily(ownerUserId, memberUserId int64 } func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source string) (*familyJoinResult, error) { + return h.joinFamilyWithSubscribeMode(ownerUserId, memberUserId, source, familySubscribeDiscard) +} + +func (h *familyBindingHelper) joinFamilyAndMoveSubscribesToOwner(ownerUserId, memberUserId int64, source string) (*familyJoinResult, error) { + return h.joinFamilyWithSubscribeMode(ownerUserId, memberUserId, source, familySubscribeMoveToOwner) +} + +func (h *familyBindingHelper) joinFamilyWithSubscribeMode(ownerUserId, memberUserId int64, source string, subscribeMode familySubscribeTransferMode) (*familyJoinResult, error) { if ownerUserId == memberUserId { return nil, errors.Wrapf(xerr.NewErrCode(xerr.FamilyAlreadyBound), "user already bound to this family") } @@ -118,7 +136,7 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source result := &familyJoinResult{ OwnerUserId: ownerUserId, } - removedSubscribes := make([]user.Subscribe, 0) + affectedSubscribes := make([]user.Subscribe, 0) err := h.svcCtx.DB.WithContext(h.ctx).Transaction(func(tx *gorm.DB) error { ownerFamily, err := h.getOrCreateOwnerFamily(tx, ownerUserId) @@ -182,9 +200,17 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source } } - removedSubscribes, err = transferMemberSubscribesToOwner(tx, memberUserId, ownerUserId) - if err != nil { - return err + switch subscribeMode { + case familySubscribeMoveToOwner: + affectedSubscribes, err = moveMemberSubscribesToOwner(tx, memberUserId, ownerUserId) + if err != nil { + return err + } + default: + affectedSubscribes, err = transferMemberSubscribesToOwner(tx, memberUserId, ownerUserId) + if err != nil { + return err + } } return nil }) @@ -193,12 +219,56 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source return nil, err } - if err = h.clearRemovedMemberSubscribeCache(removedSubscribes); err != nil { - return nil, err + if subscribeMode == familySubscribeMoveToOwner { + if err = h.clearMovedMemberSubscribeCache(affectedSubscribes, ownerUserId); err != nil { + return nil, err + } + } else { + if err = h.clearRemovedMemberSubscribeCache(affectedSubscribes); err != nil { + return nil, err + } } return result, nil } +func moveMemberSubscribesToOwner(tx *gorm.DB, memberUserId, ownerUserId int64) ([]user.Subscribe, error) { + var subscribes []user.Subscribe + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Model(&user.Subscribe{}). + Where("user_id = ?", memberUserId). + Find(&subscribes).Error; err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query member subscribe list failed") + } + if len(subscribes) == 0 { + return nil, nil + } + + for _, sub := range subscribes { + newToken := uuidx.SubscribeToken(fmt.Sprintf("familyMove:%d:%d:%s", ownerUserId, sub.Id, uuidx.NewUUID().String())) + if err := tx.Model(&user.Subscribe{}). + Where("id = ? AND user_id = ?", sub.Id, memberUserId). + Updates(map[string]interface{}{ + "user_id": ownerUserId, + "token": newToken, + "uuid": uuidx.NewUUID().String(), + "updated_at": time.Now(), + }).Error; err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "move member subscribe to owner failed") + } + if sub.OrderId > 0 { + if err := tx.Model(&modelOrder.Order{}). + Where("id = ? AND subscription_user_id = ?", sub.OrderId, memberUserId). + Updates(map[string]interface{}{ + "subscription_user_id": ownerUserId, + "subscribe_token": newToken, + }).Error; err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "move member subscribe order to owner failed") + } + } + } + return subscribes, nil +} + func transferMemberSubscribesToOwner(tx *gorm.DB, memberUserId, ownerUserId int64) ([]user.Subscribe, error) { var subscribes []user.Subscribe if err := tx.Model(&user.Subscribe{}). @@ -238,6 +308,37 @@ func (h *familyBindingHelper) clearRemovedMemberSubscribeCache(removedSubscribes return nil } +func (h *familyBindingHelper) clearMovedMemberSubscribeCache(movedSubscribes []user.Subscribe, ownerUserId int64) error { + if len(movedSubscribes) == 0 { + return nil + } + + cacheModels := make([]*user.Subscribe, 0, len(movedSubscribes)*2) + ownerCopies := make([]user.Subscribe, len(movedSubscribes)) + for i := range movedSubscribes { + cacheModels = append(cacheModels, &movedSubscribes[i]) + ownerCopies[i] = movedSubscribes[i] + ownerCopies[i].UserId = ownerUserId + cacheModels = append(cacheModels, &ownerCopies[i]) + } + + if err := h.svcCtx.UserModel.ClearSubscribeCache(h.ctx, cacheModels...); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear moved subscribe cache failed") + } + + _, subscribeIDSet := buildRemovedSubscribeCacheMeta(movedSubscribes) + for subscribeID := range subscribeIDSet { + if err := h.svcCtx.SubscribeModel.ClearCache(h.ctx, subscribeID); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear subscribe cache failed") + } + } + if err := h.svcCtx.NodeModel.ClearServerAllCache(h.ctx); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear node cache failed") + } + + return nil +} + func buildRemovedSubscribeCacheMeta(removedSubscribes []user.Subscribe) ([]*user.Subscribe, map[int64]struct{}) { subscribeModels := make([]*user.Subscribe, 0, len(removedSubscribes)) subscribeIDSet := make(map[int64]struct{}, len(removedSubscribes)) diff --git a/internal/logic/public/user/familyBindingHelper_test.go b/internal/logic/public/user/familyBindingHelper_test.go new file mode 100644 index 0000000..76a6df8 --- /dev/null +++ b/internal/logic/public/user/familyBindingHelper_test.go @@ -0,0 +1,288 @@ +package user + +import ( + "testing" + "time" + + modelOrder "github.com/perfect-panel/server/internal/model/order" + modelUser "github.com/perfect-panel/server/internal/model/user" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func newFamilyBindingTestDB(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + DisableForeignKeyConstraintWhenMigrating: true, + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + t.Fatalf("open sqlite db: %v", err) + } + if err = db.Exec(` +CREATE TABLE user_subscribe ( + id integer primary key, + user_id integer not null, + order_id integer not null, + subscribe_id integer not null, + start_time datetime, + expire_time datetime, + token text, + uuid text, + status integer, + created_at datetime, + updated_at datetime +)`).Error; err != nil { + t.Fatalf("create user_subscribe schema: %v", err) + } + if err = db.Exec(` +CREATE TABLE "order" ( + id integer primary key, + user_id integer not null, + subscription_user_id integer not null, + order_no text, + subscribe_token text, + status integer, + subscribe_id integer, + created_at datetime, + updated_at datetime +)`).Error; err != nil { + t.Fatalf("create order schema: %v", err) + } + return db +} + +func insertFamilyBindingTestOrder(t *testing.T, db *gorm.DB, order modelOrder.Order) { + t.Helper() + + if err := db.Exec(` +INSERT INTO "order" ( + id, user_id, subscription_user_id, order_no, subscribe_token, status, subscribe_id, created_at, updated_at +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + order.Id, + order.UserId, + order.SubscriptionUserId, + order.OrderNo, + order.SubscribeToken, + order.Status, + order.SubscribeId, + order.CreatedAt, + order.UpdatedAt, + ).Error; err != nil { + t.Fatalf("insert order %d: %v", order.Id, err) + } +} + +func insertFamilyBindingTestSubscribe(t *testing.T, db *gorm.DB, sub modelUser.Subscribe) { + t.Helper() + + if err := db.Exec(` +INSERT INTO user_subscribe ( + id, user_id, order_id, subscribe_id, start_time, expire_time, token, uuid, status, created_at, updated_at +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + sub.Id, + sub.UserId, + sub.OrderId, + sub.SubscribeId, + sub.StartTime, + sub.ExpireTime, + sub.Token, + sub.UUID, + sub.Status, + sub.CreatedAt, + sub.UpdatedAt, + ).Error; err != nil { + t.Fatalf("insert subscribe %d: %v", sub.Id, err) + } +} + +func TestMoveMemberSubscribesToOwnerMovesSubscribeAndCurrentOrder(t *testing.T) { + db := newFamilyBindingTestDB(t) + memberUserID := int64(1001) + ownerUserID := int64(2001) + now := time.Now().Add(-time.Hour) + expireAt := now.Add(30 * 24 * time.Hour) + + currentOrder := modelOrder.Order{ + Id: 9001, + UserId: memberUserID, + SubscriptionUserId: memberUserID, + OrderNo: "order-current", + SubscribeToken: "token-current", + Status: 5, + SubscribeId: 3001, + CreatedAt: now, + UpdatedAt: now, + } + unlinkedOrder := modelOrder.Order{ + Id: 9002, + UserId: memberUserID, + SubscriptionUserId: memberUserID, + OrderNo: "order-unlinked", + SubscribeToken: "token-unlinked", + Status: 5, + SubscribeId: 3001, + CreatedAt: now, + UpdatedAt: now, + } + sub := modelUser.Subscribe{ + Id: 7001, + UserId: memberUserID, + OrderId: currentOrder.Id, + SubscribeId: currentOrder.SubscribeId, + StartTime: now, + ExpireTime: expireAt, + Token: "token-current", + UUID: "uuid-current", + Status: 1, + CreatedAt: now, + UpdatedAt: now, + } + insertFamilyBindingTestOrder(t, db, currentOrder) + insertFamilyBindingTestOrder(t, db, unlinkedOrder) + insertFamilyBindingTestSubscribe(t, db, sub) + + var moved []modelUser.Subscribe + if err := db.Transaction(func(tx *gorm.DB) error { + var err error + moved, err = moveMemberSubscribesToOwner(tx, memberUserID, ownerUserID) + return err + }); err != nil { + t.Fatalf("move member subscribes: %v", err) + } + + if len(moved) != 1 { + t.Fatalf("moved subscribes length = %d, want 1", len(moved)) + } + if moved[0].UserId != memberUserID { + t.Fatalf("moved cache copy user_id = %d, want original member %d", moved[0].UserId, memberUserID) + } + + var gotSub modelUser.Subscribe + if err := db.First(&gotSub, "id = ?", sub.Id).Error; err != nil { + t.Fatalf("query moved subscribe: %v", err) + } + if gotSub.UserId != ownerUserID { + t.Fatalf("subscribe user_id = %d, want owner %d", gotSub.UserId, ownerUserID) + } + if gotSub.Token == "" || gotSub.Token == sub.Token { + t.Fatalf("subscribe token = %q, want regenerated from old token %q", gotSub.Token, sub.Token) + } + if gotSub.UUID == "" || gotSub.UUID == sub.UUID { + t.Fatalf("subscribe uuid = %q, want regenerated from old uuid %q", gotSub.UUID, sub.UUID) + } + + var gotOrder modelOrder.Order + if err := db.First(&gotOrder, "id = ?", currentOrder.Id).Error; err != nil { + t.Fatalf("query updated order: %v", err) + } + if gotOrder.SubscriptionUserId != ownerUserID { + t.Fatalf("current order subscription_user_id = %d, want owner %d", gotOrder.SubscriptionUserId, ownerUserID) + } + if gotOrder.SubscribeToken != gotSub.Token { + t.Fatalf("current order subscribe_token = %q, want regenerated subscribe token %q", gotOrder.SubscribeToken, gotSub.Token) + } + + var gotUnlinkedOrder modelOrder.Order + if err := db.First(&gotUnlinkedOrder, "id = ?", unlinkedOrder.Id).Error; err != nil { + t.Fatalf("query unlinked order: %v", err) + } + if gotUnlinkedOrder.SubscriptionUserId != memberUserID { + t.Fatalf("unlinked order subscription_user_id = %d, want unchanged member %d", gotUnlinkedOrder.SubscriptionUserId, memberUserID) + } + if gotUnlinkedOrder.SubscribeToken != unlinkedOrder.SubscribeToken { + t.Fatalf("unlinked order subscribe_token = %q, want unchanged token %q", gotUnlinkedOrder.SubscribeToken, unlinkedOrder.SubscribeToken) + } +} + +func TestMoveMemberSubscribesToOwnerNoSubscribesIsNoop(t *testing.T) { + db := newFamilyBindingTestDB(t) + + var moved []modelUser.Subscribe + if err := db.Transaction(func(tx *gorm.DB) error { + var err error + moved, err = moveMemberSubscribesToOwner(tx, 1001, 2001) + return err + }); err != nil { + t.Fatalf("move empty member subscribes: %v", err) + } + if len(moved) != 0 { + t.Fatalf("moved subscribes length = %d, want 0", len(moved)) + } +} + +func TestTransferMemberSubscribesToOwnerStillDiscardsMemberSubscribes(t *testing.T) { + db := newFamilyBindingTestDB(t) + memberUserID := int64(1001) + ownerUserID := int64(2001) + now := time.Now().Add(-time.Hour) + + order := modelOrder.Order{ + Id: 9001, + UserId: memberUserID, + SubscriptionUserId: memberUserID, + OrderNo: "order-discard", + SubscribeToken: "token-discard", + Status: 5, + SubscribeId: 3001, + CreatedAt: now, + UpdatedAt: now, + } + sub := modelUser.Subscribe{ + Id: 7001, + UserId: memberUserID, + OrderId: order.Id, + SubscribeId: order.SubscribeId, + StartTime: now, + ExpireTime: now.Add(30 * 24 * time.Hour), + Token: "token-discard", + UUID: "uuid-discard", + Status: 1, + CreatedAt: now, + UpdatedAt: now, + } + insertFamilyBindingTestOrder(t, db, order) + insertFamilyBindingTestSubscribe(t, db, sub) + + var removed []modelUser.Subscribe + if err := db.Transaction(func(tx *gorm.DB) error { + var err error + removed, err = transferMemberSubscribesToOwner(tx, memberUserID, ownerUserID) + return err + }); err != nil { + t.Fatalf("discard member subscribes: %v", err) + } + + if len(removed) != 1 { + t.Fatalf("removed subscribes length = %d, want 1", len(removed)) + } + if removed[0].UserId != memberUserID { + t.Fatalf("removed cache copy user_id = %d, want original member %d", removed[0].UserId, memberUserID) + } + + var memberCount int64 + if err := db.Model(&modelUser.Subscribe{}).Where("user_id = ?", memberUserID).Count(&memberCount).Error; err != nil { + t.Fatalf("count member subscribes: %v", err) + } + if memberCount != 0 { + t.Fatalf("member subscribe count = %d, want 0", memberCount) + } + + var ownerCount int64 + if err := db.Model(&modelUser.Subscribe{}).Where("user_id = ?", ownerUserID).Count(&ownerCount).Error; err != nil { + t.Fatalf("count owner subscribes: %v", err) + } + if ownerCount != 0 { + t.Fatalf("owner subscribe count = %d, want 0", ownerCount) + } + + var gotOrder modelOrder.Order + if err := db.First(&gotOrder, "id = ?", order.Id).Error; err != nil { + t.Fatalf("query order: %v", err) + } + if gotOrder.SubscriptionUserId != memberUserID { + t.Fatalf("discard path order subscription_user_id = %d, want unchanged member %d", gotOrder.SubscriptionUserId, memberUserID) + } +}