diff --git a/apis/public/subscribe.api b/apis/public/subscribe.api index 4c0d2aa..6c51122 100644 --- a/apis/public/subscribe.api +++ b/apis/public/subscribe.api @@ -15,47 +15,50 @@ type ( Language string `form:"language"` } - QueryUserSubscribeNodeListResponse { - List []UserSubscribeInfo `json:"list"` - } + QueryUserSubscribeNodeListResponse { + List []UserSubscribeInfo `json:"list"` + } - UserSubscribeInfo { - Id int64 `json:"id"` - UserId int64 `json:"user_id"` - OrderId int64 `json:"order_id"` - SubscribeId int64 `json:"subscribe_id"` - StartTime int64 `json:"start_time"` - ExpireTime int64 `json:"expire_time"` - FinishedAt int64 `json:"finished_at"` - ResetTime int64 `json:"reset_time"` - Traffic int64 `json:"traffic"` - Download int64 `json:"download"` - Upload int64 `json:"upload"` - Token string `json:"token"` - Status uint8 `json:"status"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - IsTryOut bool `json:"is_try_out"` - Nodes []*UserSubscribeNodeInfo `json:"nodes"` - } + UserSubscribeInfo { + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + OrderId int64 `json:"order_id"` + SubscribeId int64 `json:"subscribe_id"` + StartTime int64 `json:"start_time"` + ExpireTime int64 `json:"expire_time"` + FinishedAt int64 `json:"finished_at"` + ResetTime int64 `json:"reset_time"` + Traffic int64 `json:"traffic"` + Download int64 `json:"download"` + Upload int64 `json:"upload"` + Token string `json:"token"` + Status uint8 `json:"status"` + EntitlementSource string `json:"entitlement_source"` + EntitlementOwnerUserId int64 `json:"entitlement_owner_user_id"` + ReadOnly bool `json:"read_only"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + IsTryOut bool `json:"is_try_out"` + Nodes []*UserSubscribeNodeInfo `json:"nodes"` + } - UserSubscribeNodeInfo{ - Id int64 `json:"id"` - Name string `json:"name"` - Uuid string `json:"uuid"` - Protocol string `json:"protocol"` - Protocols string `json:"protocols"` - Port uint16 `json:"port"` - Address string `json:"address"` - Tags []string `json:"tags"` - Country string `json:"country"` - City string `json:"city"` - Longitude string `json:"longitude"` - Latitude string `json:"latitude"` - LatitudeCenter string `json:"latitude_center"` - LongitudeCenter string `json:"longitude_center"` - CreatedAt int64 `json:"created_at"` - } + UserSubscribeNodeInfo { + Id int64 `json:"id"` + Name string `json:"name"` + Uuid string `json:"uuid"` + Protocol string `json:"protocol"` + Protocols string `json:"protocols"` + Port uint16 `json:"port"` + Address string `json:"address"` + Tags []string `json:"tags"` + Country string `json:"country"` + City string `json:"city"` + Longitude string `json:"longitude"` + Latitude string `json:"latitude"` + LatitudeCenter string `json:"latitude_center"` + LongitudeCenter string `json:"longitude_center"` + CreatedAt int64 `json:"created_at"` + } ) @server ( @@ -68,8 +71,7 @@ service ppanel { @handler QuerySubscribeList get /list (QuerySubscribeListRequest) returns (QuerySubscribeListResponse) - @doc "Get user subscribe node info" - @handler QueryUserSubscribeNodeList - get /node/list returns (QueryUserSubscribeNodeListResponse) + @doc "Get user subscribe node info" + @handler QueryUserSubscribeNodeList + get /node/list returns (QueryUserSubscribeNodeListResponse) } - diff --git a/apis/types.api b/apis/types.api index 0a3552d..8f60427 100644 --- a/apis/types.api +++ b/apis/types.api @@ -507,6 +507,9 @@ type ( Upload int64 `json:"upload"` Token string `json:"token"` Status uint8 `json:"status"` + EntitlementSource string `json:"entitlement_source"` + EntitlementOwnerUserId int64 `json:"entitlement_owner_user_id"` + ReadOnly bool `json:"read_only"` Short string `json:"short"` CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` diff --git a/internal/logic/admin/user/getUserDetailLogic.go b/internal/logic/admin/user/getUserDetailLogic.go index c95e6af..5373ae5 100644 --- a/internal/logic/admin/user/getUserDetailLogic.go +++ b/internal/logic/admin/user/getUserDetailLogic.go @@ -2,7 +2,9 @@ package user import ( "context" + "strings" + logicCommon "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -33,6 +35,9 @@ func (l *GetUserDetailLogic) GetUserDetail(req *types.GetDetailRequest) (*types. return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "get user detail error: %v", err.Error()) } tool.DeepCopy(&resp, userInfo) + if referCode := strings.TrimSpace(resp.ReferCode); referCode != "" { + resp.ShareLink = logicCommon.NewInviteLinkResolver(l.ctx, l.svcCtx).ResolveInviteLink(referCode) + } type familyRelation struct { FamilyId int64 diff --git a/internal/logic/admin/user/getUserListLogic.go b/internal/logic/admin/user/getUserListLogic.go index ee6cf68..f7ac017 100644 --- a/internal/logic/admin/user/getUserListLogic.go +++ b/internal/logic/admin/user/getUserListLogic.go @@ -3,7 +3,10 @@ package user import ( "context" "fmt" + "strings" + "time" + logicCommon "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -45,6 +48,16 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (*types.Ge return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "GetUserListLogic failed: %v", err.Error()) } + referCodes := make([]string, 0, len(list)) + for _, item := range list { + referCode := strings.TrimSpace(item.ReferCode) + if referCode == "" { + continue + } + referCodes = append(referCodes, referCode) + } + inviteLinkMap := logicCommon.NewInviteLinkResolver(l.ctx, l.svcCtx).ResolveInviteLinksBatch(referCodes, 8, 3, 1500*time.Millisecond) + // Batch fetch active subscriptions userIds := make([]int64, 0, len(list)) for _, u := range list { @@ -171,6 +184,9 @@ func (l *GetUserListLogic) GetUserList(req *types.GetUserListRequest) (*types.Ge u.FamilyRoleName = fmt.Sprintf("role_%d", relation.Role) } } + if referCode := strings.TrimSpace(item.ReferCode); referCode != "" { + u.ShareLink = inviteLinkMap[referCode] + } userRespList = append(userRespList, u) } diff --git a/internal/logic/common/familyEntitlement.go b/internal/logic/common/familyEntitlement.go new file mode 100644 index 0000000..f6f15bc --- /dev/null +++ b/internal/logic/common/familyEntitlement.go @@ -0,0 +1,88 @@ +package common + +import ( + "context" + + modelUser "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +const ( + EntitlementSourceSelf = "self" + EntitlementSourceFamilyOwner = "family_owner" +) + +type EntitlementContext struct { + EffectiveUserID int64 + Source string + OwnerUserID int64 + ReadOnly bool +} + +type familyEntitlementRelation struct { + Role uint8 `gorm:"column:role"` + FamilyStatus uint8 `gorm:"column:family_status"` + OwnerUserID int64 `gorm:"column:owner_user_id"` +} + +func ResolveEntitlementUser(ctx context.Context, db *gorm.DB, currentUserID int64) (*EntitlementContext, error) { + entitlement := buildEntitlementContext(currentUserID, nil) + if currentUserID <= 0 { + return entitlement, nil + } + + var relation familyEntitlementRelation + err := db.WithContext(ctx). + Table("user_family_member"). + Select("user_family_member.role, user_family.status AS family_status, user_family.owner_user_id"). + Joins("JOIN user_family ON user_family.id = user_family_member.family_id AND user_family.deleted_at IS NULL"). + Where("user_family_member.user_id = ? AND user_family_member.deleted_at IS NULL AND user_family_member.status = ?", currentUserID, modelUser.FamilyMemberActive). + First(&relation).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return entitlement, nil + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query family entitlement relation failed") + } + + return buildEntitlementContext(currentUserID, &relation), nil +} + +func DenyIfFamilyMemberReadonly(ctx context.Context, db *gorm.DB, currentUserID int64) error { + entitlement, err := ResolveEntitlementUser(ctx, db, currentUserID) + if err != nil { + return err + } + return denyReadonlyEntitlement(entitlement) +} + +func buildEntitlementContext(currentUserID int64, relation *familyEntitlementRelation) *EntitlementContext { + entitlement := &EntitlementContext{ + EffectiveUserID: currentUserID, + Source: EntitlementSourceSelf, + } + if relation == nil { + return entitlement + } + if relation.Role == modelUser.FamilyRoleMember && + relation.FamilyStatus == modelUser.FamilyStatusActive && + relation.OwnerUserID > 0 && + relation.OwnerUserID != currentUserID { + return &EntitlementContext{ + EffectiveUserID: relation.OwnerUserID, + Source: EntitlementSourceFamilyOwner, + OwnerUserID: relation.OwnerUserID, + ReadOnly: true, + } + } + return entitlement +} + +func denyReadonlyEntitlement(entitlement *EntitlementContext) error { + if entitlement != nil && entitlement.ReadOnly { + return errors.Wrapf(xerr.NewErrCode(xerr.FamilyOwnerOperationForbidden), "family member operation is forbidden") + } + return nil +} diff --git a/internal/logic/common/familyEntitlement_test.go b/internal/logic/common/familyEntitlement_test.go new file mode 100644 index 0000000..9863d63 --- /dev/null +++ b/internal/logic/common/familyEntitlement_test.go @@ -0,0 +1,78 @@ +package common + +import ( + stderrors "errors" + "testing" + + modelUser "github.com/perfect-panel/server/internal/model/user" + "github.com/perfect-panel/server/pkg/xerr" + pkgerrors "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func extractFamilyEntitlementCode(err error) uint32 { + if err == nil { + return 0 + } + + var codeErr *xerr.CodeError + if stderrors.As(pkgerrors.Cause(err), &codeErr) { + return codeErr.GetErrCode() + } + return 0 +} + +func TestBuildEntitlementContext(t *testing.T) { + t.Run("default self entitlement", func(t *testing.T) { + entitlement := buildEntitlementContext(1001, nil) + require.Equal(t, int64(1001), entitlement.EffectiveUserID) + require.Equal(t, EntitlementSourceSelf, entitlement.Source) + require.Equal(t, int64(0), entitlement.OwnerUserID) + require.False(t, entitlement.ReadOnly) + }) + + t.Run("active family member uses owner entitlement", func(t *testing.T) { + entitlement := buildEntitlementContext(1001, &familyEntitlementRelation{ + Role: modelUser.FamilyRoleMember, + FamilyStatus: modelUser.FamilyStatusActive, + OwnerUserID: 2001, + }) + require.Equal(t, int64(2001), entitlement.EffectiveUserID) + require.Equal(t, EntitlementSourceFamilyOwner, entitlement.Source) + require.Equal(t, int64(2001), entitlement.OwnerUserID) + require.True(t, entitlement.ReadOnly) + }) + + t.Run("owner relation keeps self entitlement", func(t *testing.T) { + entitlement := buildEntitlementContext(2001, &familyEntitlementRelation{ + Role: modelUser.FamilyRoleOwner, + FamilyStatus: modelUser.FamilyStatusActive, + OwnerUserID: 2001, + }) + require.Equal(t, int64(2001), entitlement.EffectiveUserID) + require.Equal(t, EntitlementSourceSelf, entitlement.Source) + require.False(t, entitlement.ReadOnly) + }) + + t.Run("disabled family keeps self entitlement", func(t *testing.T) { + entitlement := buildEntitlementContext(1001, &familyEntitlementRelation{ + Role: modelUser.FamilyRoleMember, + FamilyStatus: 0, + OwnerUserID: 2001, + }) + require.Equal(t, int64(1001), entitlement.EffectiveUserID) + require.Equal(t, EntitlementSourceSelf, entitlement.Source) + require.False(t, entitlement.ReadOnly) + }) +} + +func TestDenyReadonlyEntitlement(t *testing.T) { + require.NoError(t, denyReadonlyEntitlement(&EntitlementContext{ReadOnly: false})) + + err := denyReadonlyEntitlement(&EntitlementContext{ + Source: EntitlementSourceFamilyOwner, + ReadOnly: true, + }) + require.Error(t, err) + require.Equal(t, xerr.FamilyOwnerOperationForbidden, extractFamilyEntitlementCode(err)) +} diff --git a/internal/logic/common/inviteLinkResolver.go b/internal/logic/common/inviteLinkResolver.go new file mode 100644 index 0000000..23e9329 --- /dev/null +++ b/internal/logic/common/inviteLinkResolver.go @@ -0,0 +1,289 @@ +package common + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strings" + "sync" + "time" + + "github.com/perfect-panel/server/internal/svc" + "github.com/perfect-panel/server/pkg/kutt" +) + +const inviteShortLinkCachePrefix = "cache:invite:short_link:" + +type inviteLinkCustomData struct { + ShareURL string `json:"shareUrl"` + Domain string `json:"domain"` +} + +type InviteLinkResolver struct { + ctx context.Context + svcCtx *svc.ServiceContext + createShortLink func(ctx context.Context, targetURL, domain string) (string, error) +} + +func NewInviteLinkResolver(ctx context.Context, svcCtx *svc.ServiceContext) *InviteLinkResolver { + resolver := &InviteLinkResolver{ + ctx: ctx, + svcCtx: svcCtx, + } + resolver.createShortLink = func(ctx context.Context, targetURL, domain string) (string, error) { + client := kutt.NewClient(svcCtx.Config.Kutt.ApiURL, svcCtx.Config.Kutt.ApiKey) + link, err := client.CreateShortLink(ctx, &kutt.CreateLinkRequest{ + Target: targetURL, + Reuse: true, + Domain: domain, + }) + if err != nil { + return "", err + } + + shortLink := strings.TrimSpace(link.Link) + if strings.HasPrefix(shortLink, "http://") { + shortLink = strings.Replace(shortLink, "http://", "https://", 1) + } + return shortLink, nil + } + return resolver +} + +func (r *InviteLinkResolver) ResolveInviteLink(referCode string) string { + normalizedCode := strings.TrimSpace(referCode) + if normalizedCode == "" { + return "" + } + + longLink := r.buildLongInviteLink(normalizedCode) + if !r.canUseKutt() || longLink == "" { + return longLink + } + + if cached := r.getCachedShortLink(normalizedCode); cached != "" { + return cached + } + + shortLink, err := r.generateShortLinkWithTimeout(normalizedCode, 1500*time.Millisecond) + if err != nil || strings.TrimSpace(shortLink) == "" { + return longLink + } + + r.cacheShortLink(normalizedCode, shortLink) + return shortLink +} + +func (r *InviteLinkResolver) ResolveInviteLinksBatch(referCodes []string, maxGenerate, maxConcurrency int, timeout time.Duration) map[string]string { + result := make(map[string]string) + uniqueCodes := uniqueReferCodes(referCodes) + if len(uniqueCodes) == 0 { + return result + } + + for _, referCode := range uniqueCodes { + result[referCode] = r.buildLongInviteLink(referCode) + } + + if !r.canUseKutt() { + return result + } + + toGenerate := make([]string, 0, len(uniqueCodes)) + for _, referCode := range uniqueCodes { + if cached := r.getCachedShortLink(referCode); cached != "" { + result[referCode] = cached + continue + } + toGenerate = append(toGenerate, referCode) + } + + if maxGenerate > 0 && len(toGenerate) > maxGenerate { + toGenerate = toGenerate[:maxGenerate] + } + + if len(toGenerate) == 0 { + return result + } + + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + if timeout <= 0 { + timeout = 1500 * time.Millisecond + } + + limiter := make(chan struct{}, maxConcurrency) + var waitGroup sync.WaitGroup + var mutex sync.Mutex + + for _, referCode := range toGenerate { + waitGroup.Add(1) + currentCode := referCode + + go func() { + defer waitGroup.Done() + limiter <- struct{}{} + defer func() { <-limiter }() + + shortLink, err := r.generateShortLinkWithTimeout(currentCode, timeout) + if err != nil || strings.TrimSpace(shortLink) == "" { + return + } + + mutex.Lock() + result[currentCode] = shortLink + mutex.Unlock() + + r.cacheShortLink(currentCode, shortLink) + }() + } + + waitGroup.Wait() + return result +} + +func (r *InviteLinkResolver) canUseKutt() bool { + if r == nil || r.svcCtx == nil { + return false + } + if !r.svcCtx.Config.Kutt.Enable { + return false + } + if strings.TrimSpace(r.svcCtx.Config.Kutt.ApiURL) == "" || strings.TrimSpace(r.svcCtx.Config.Kutt.ApiKey) == "" { + return false + } + return r.createShortLink != nil +} + +func (r *InviteLinkResolver) resolveShareURLAndDomain() (string, string) { + if r == nil || r.svcCtx == nil { + return "", "" + } + + shareURL := strings.TrimSpace(r.svcCtx.Config.Kutt.TargetURL) + domain := strings.TrimSpace(r.svcCtx.Config.Kutt.Domain) + + customData := strings.TrimSpace(r.svcCtx.Config.Site.CustomData) + if customData == "" { + return shareURL, domain + } + + var parsedData inviteLinkCustomData + if err := json.Unmarshal([]byte(customData), &parsedData); err != nil { + return shareURL, domain + } + + if strings.TrimSpace(parsedData.ShareURL) != "" { + shareURL = strings.TrimSpace(parsedData.ShareURL) + } + if strings.TrimSpace(parsedData.Domain) != "" { + domain = strings.TrimSpace(parsedData.Domain) + } + + return shareURL, domain +} + +func (r *InviteLinkResolver) buildLongInviteLink(referCode string) string { + normalizedCode := strings.TrimSpace(referCode) + if normalizedCode == "" { + return "" + } + + shareURL, _ := r.resolveShareURLAndDomain() + if shareURL == "" { + return "" + } + + parsedURL, err := url.Parse(shareURL) + if err != nil { + return fallbackLongInviteLink(shareURL, normalizedCode) + } + + queryValues := parsedURL.Query() + queryValues.Set("ic", normalizedCode) + parsedURL.RawQuery = queryValues.Encode() + + return parsedURL.String() +} + +func (r *InviteLinkResolver) generateShortLinkWithTimeout(referCode string, timeout time.Duration) (string, error) { + longLink := r.buildLongInviteLink(referCode) + if longLink == "" { + return "", nil + } + _, domain := r.resolveShareURLAndDomain() + + requestCtx := r.ctx + var cancel context.CancelFunc + if timeout > 0 { + requestCtx, cancel = context.WithTimeout(r.ctx, timeout) + defer cancel() + } + + shortLink, err := r.createShortLink(requestCtx, longLink, domain) + if err != nil { + return "", err + } + return strings.TrimSpace(shortLink), nil +} + +func (r *InviteLinkResolver) getCachedShortLink(referCode string) string { + if r == nil || r.svcCtx == nil || r.svcCtx.Redis == nil { + return "" + } + + cacheKey := inviteShortLinkCachePrefix + referCode + shortLink, err := r.svcCtx.Redis.Get(r.ctx, cacheKey).Result() + if err != nil { + return "" + } + return strings.TrimSpace(shortLink) +} + +func (r *InviteLinkResolver) cacheShortLink(referCode, shortLink string) { + if r == nil || r.svcCtx == nil || r.svcCtx.Redis == nil { + return + } + if strings.TrimSpace(referCode) == "" || strings.TrimSpace(shortLink) == "" { + return + } + + cacheKey := inviteShortLinkCachePrefix + referCode + _ = r.svcCtx.Redis.Set(r.ctx, cacheKey, shortLink, 0).Err() +} + +func uniqueReferCodes(referCodes []string) []string { + uniqueCodes := make([]string, 0, len(referCodes)) + seen := make(map[string]struct{}, len(referCodes)) + + for _, referCode := range referCodes { + normalizedCode := strings.TrimSpace(referCode) + if normalizedCode == "" { + continue + } + if _, exists := seen[normalizedCode]; exists { + continue + } + seen[normalizedCode] = struct{}{} + uniqueCodes = append(uniqueCodes, normalizedCode) + } + + return uniqueCodes +} + +func fallbackLongInviteLink(baseURL, referCode string) string { + normalizedBase := strings.TrimSpace(baseURL) + normalizedCode := strings.TrimSpace(referCode) + if normalizedBase == "" || normalizedCode == "" { + return "" + } + + separator := "?" + if strings.Contains(normalizedBase, "?") { + separator = "&" + } + trimmedBase := strings.TrimRight(normalizedBase, "?&") + return fmt.Sprintf("%s%sic=%s", trimmedBase, separator, url.QueryEscape(normalizedCode)) +} diff --git a/internal/logic/common/inviteLinkResolver_test.go b/internal/logic/common/inviteLinkResolver_test.go new file mode 100644 index 0000000..36f5fe9 --- /dev/null +++ b/internal/logic/common/inviteLinkResolver_test.go @@ -0,0 +1,145 @@ +package common + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/perfect-panel/server/internal/config" + "github.com/perfect-panel/server/internal/svc" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func buildInviteResolverForTest(t *testing.T, cfg config.Config) (*InviteLinkResolver, *miniredis.Miniredis) { + t.Helper() + + redisServer, err := miniredis.Run() + require.NoError(t, err) + t.Cleanup(func() { + redisServer.Close() + }) + + redisClient := redis.NewClient(&redis.Options{ + Addr: redisServer.Addr(), + DB: 0, + }) + t.Cleanup(func() { + _ = redisClient.Close() + }) + + serviceCtx := &svc.ServiceContext{ + Config: cfg, + Redis: redisClient, + } + + resolver := NewInviteLinkResolver(context.Background(), serviceCtx) + return resolver, redisServer +} + +func TestInviteLinkResolverResolveInviteLink(t *testing.T) { + t.Run("kutt disabled returns long link", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.TargetURL = "https://example.com/register" + + resolver, _ := buildInviteResolverForTest(t, cfg) + link := resolver.ResolveInviteLink("abc123") + require.Equal(t, "https://example.com/register?ic=abc123", link) + }) + + t.Run("cache hit returns cached short link", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.Enable = true + cfg.Kutt.ApiURL = "https://kutt.local/api/v2" + cfg.Kutt.ApiKey = "token" + cfg.Kutt.TargetURL = "https://example.com/register" + + resolver, redisServer := buildInviteResolverForTest(t, cfg) + redisServer.Set(inviteShortLinkCachePrefix+"abc123", "https://sho.rt/cached") + + called := 0 + resolver.createShortLink = func(ctx context.Context, targetURL, domain string) (string, error) { + called++ + return "", errors.New("should not call createShortLink on cache hit") + } + + link := resolver.ResolveInviteLink("abc123") + require.Equal(t, "https://sho.rt/cached", link) + require.Equal(t, 0, called) + }) + + t.Run("cache miss kutt success returns short link and writes cache", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.Enable = true + cfg.Kutt.ApiURL = "https://kutt.local/api/v2" + cfg.Kutt.ApiKey = "token" + cfg.Kutt.TargetURL = "https://example.com/register" + + resolver, _ := buildInviteResolverForTest(t, cfg) + resolver.createShortLink = func(ctx context.Context, targetURL, domain string) (string, error) { + return "https://sho.rt/new", nil + } + + link := resolver.ResolveInviteLink("abc123") + require.Equal(t, "https://sho.rt/new", link) + + cached := resolver.getCachedShortLink("abc123") + require.Equal(t, "https://sho.rt/new", cached) + }) + + t.Run("kutt failure falls back to long link", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.Enable = true + cfg.Kutt.ApiURL = "https://kutt.local/api/v2" + cfg.Kutt.ApiKey = "token" + cfg.Kutt.TargetURL = "https://example.com/register" + + resolver, _ := buildInviteResolverForTest(t, cfg) + resolver.createShortLink = func(ctx context.Context, targetURL, domain string) (string, error) { + return "", errors.New("kutt request failed") + } + + link := resolver.ResolveInviteLink("abc123") + require.Equal(t, "https://example.com/register?ic=abc123", link) + }) + + t.Run("long link preserves existing query string", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.TargetURL = "https://example.com/register?channel=ios" + + resolver, _ := buildInviteResolverForTest(t, cfg) + link := resolver.ResolveInviteLink("abc123") + parsed, err := url.Parse(link) + require.NoError(t, err) + require.Equal(t, "https", parsed.Scheme) + require.Equal(t, "example.com", parsed.Host) + require.Equal(t, "/register", parsed.Path) + require.Equal(t, "ios", parsed.Query().Get("channel")) + require.Equal(t, "abc123", parsed.Query().Get("ic")) + }) + + t.Run("kutt target preserves existing query string", func(t *testing.T) { + cfg := config.Config{} + cfg.Kutt.Enable = true + cfg.Kutt.ApiURL = "https://kutt.local/api/v2" + cfg.Kutt.ApiKey = "token" + cfg.Kutt.TargetURL = "https://example.com/register?channel=ios" + + resolver, _ := buildInviteResolverForTest(t, cfg) + capturedTargetURL := "" + resolver.createShortLink = func(ctx context.Context, targetURL, domain string) (string, error) { + capturedTargetURL = targetURL + return "https://sho.rt/query", nil + } + + link := resolver.ResolveInviteLink("abc123") + require.Equal(t, "https://sho.rt/query", link) + + parsed, err := url.Parse(capturedTargetURL) + require.NoError(t, err) + require.Equal(t, "ios", parsed.Query().Get("channel")) + require.Equal(t, "abc123", parsed.Query().Get("ic")) + }) +} diff --git a/internal/logic/public/iap/apple/attachTransactionLogic.go b/internal/logic/public/iap/apple/attachTransactionLogic.go index bceb8b6..8e8d926 100644 --- a/internal/logic/public/iap/apple/attachTransactionLogic.go +++ b/internal/logic/public/iap/apple/attachTransactionLogic.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/hibiken/asynq" + commonLogic "github.com/perfect-panel/server/internal/logic/common" iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" "github.com/perfect-panel/server/internal/model/subscribe" "github.com/perfect-panel/server/internal/model/user" @@ -47,6 +48,9 @@ func (l *AttachTransactionLogic) Attach(req *types.AttachAppleTransactionRequest l.Errorw("无效访问,用户信息缺失") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") } + if err := commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } txPayload, err := iapapple.VerifyTransactionJWS(req.SignedTransactionJWS) if err != nil { l.Errorw("JWS 验签失败", logger.Field("error", err.Error())) diff --git a/internal/logic/public/iap/apple/restoreLogic.go b/internal/logic/public/iap/apple/restoreLogic.go index 9364667..08c4175 100644 --- a/internal/logic/public/iap/apple/restoreLogic.go +++ b/internal/logic/public/iap/apple/restoreLogic.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + commonLogic "github.com/perfect-panel/server/internal/logic/common" iapmodel "github.com/perfect-panel/server/internal/model/iap/apple" "github.com/perfect-panel/server/internal/model/payment" "github.com/perfect-panel/server/internal/model/user" @@ -39,6 +40,9 @@ func (l *RestoreLogic) Restore(req *types.RestoreAppleTransactionsRequest) error if !ok || u == nil { return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "invalid access") } + if err := commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return err + } pm, _ := iapapple.ParseProductMap(l.svcCtx.Config.Site.CustomData) // Try to load payment config to get API credentials var apiCfg iapapple.ServerAPIConfig diff --git a/internal/logic/public/order/preCreateOrderLogic.go b/internal/logic/public/order/preCreateOrderLogic.go index 7b06174..8b8f324 100644 --- a/internal/logic/public/order/preCreateOrderLogic.go +++ b/internal/logic/public/order/preCreateOrderLogic.go @@ -45,6 +45,9 @@ func (l *PreCreateOrderLogic) PreCreateOrder(req *types.PurchaseOrderRequest) (r logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } if req.Quantity <= 0 { l.Debugf("[PreCreateOrder] Quantity is less than or equal to 0, setting to 1") diff --git a/internal/logic/public/order/purchaseLogic.go b/internal/logic/public/order/purchaseLogic.go index a98f18f..f3226e3 100644 --- a/internal/logic/public/order/purchaseLogic.go +++ b/internal/logic/public/order/purchaseLogic.go @@ -54,6 +54,9 @@ func (l *PurchaseLogic) Purchase(req *types.PurchaseOrderRequest) (resp *types.P logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } if req.Quantity <= 0 { l.Debugf("[Purchase] Quantity is less than or equal to 0, setting to 1") diff --git a/internal/logic/public/order/renewalLogic.go b/internal/logic/public/order/renewalLogic.go index 1898652..3643b7f 100644 --- a/internal/logic/public/order/renewalLogic.go +++ b/internal/logic/public/order/renewalLogic.go @@ -6,6 +6,7 @@ import ( "math" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" @@ -46,6 +47,9 @@ func (l *RenewalLogic) Renewal(req *types.RenewalOrderRequest) (resp *types.Rene logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } if req.Quantity <= 0 { l.Debugf("[Renewal] Quantity is less than or equal to 0, setting to 1") req.Quantity = 1 diff --git a/internal/logic/public/order/resetTrafficLogic.go b/internal/logic/public/order/resetTrafficLogic.go index a3a0669..4bde541 100644 --- a/internal/logic/public/order/resetTrafficLogic.go +++ b/internal/logic/public/order/resetTrafficLogic.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/xerr" @@ -43,6 +44,9 @@ func (l *ResetTrafficLogic) ResetTraffic(req *types.ResetTrafficOrderRequest) (r logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } // find user subscription userSubscribe, err := l.svcCtx.UserModel.FindOneUserSubscribe(l.ctx, req.UserSubscribeID) if err != nil { diff --git a/internal/logic/public/redemption/redeemCodeLogic.go b/internal/logic/public/redemption/redeemCodeLogic.go index da64d0e..fb3cc2f 100644 --- a/internal/logic/public/redemption/redeemCodeLogic.go +++ b/internal/logic/public/redemption/redeemCodeLogic.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hibiken/asynq" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/order" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/pkg/constant" @@ -43,6 +44,9 @@ func (l *RedeemCodeLogic) RedeemCode(req *types.RedeemCodeRequest) (resp *types. logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } // 使用Redis分布式锁防止并发重复兑换 lockKey := fmt.Sprintf("redemption_lock:%d:%s", u.Id, req.Code) @@ -221,4 +225,4 @@ func (l *RedeemCodeLogic) RedeemCode(req *types.RedeemCodeRequest) (resp *types. return &types.RedeemCodeResponse{ Message: "Redemption successful, processing...", }, nil -} \ No newline at end of file +} diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go index 7573d89..af88aa0 100644 --- a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic.go @@ -5,6 +5,7 @@ import ( "strings" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/node" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" @@ -38,7 +39,12 @@ func (l *QueryUserSubscribeNodeListLogic) QueryUserSubscribeNodeList() (resp *ty return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } - userSubscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, u.Id, 0, 1, 2, 3) + entitlement, err := commonLogic.ResolveEntitlementUser(l.ctx, l.svcCtx.DB, u.Id) + if err != nil { + return nil, err + } + + userSubscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, entitlement.EffectiveUserID, 0, 1, 2, 3) if err != nil { logger.Errorw("failed to query user subscribe", logger.Field("error", err.Error()), logger.Field("user_id", u.Id)) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "DB_ERROR") @@ -79,12 +85,22 @@ func (l *QueryUserSubscribeNodeListLogic) QueryUserSubscribeNodeList() (resp *ty if l.svcCtx.Config.Register.EnableTrial && l.svcCtx.Config.Register.TrialSubscribe == userSubscribe.SubscribeId { userSubscribeInfo.IsTryOut = true } + fillUserSubscribeInfoEntitlementFields(&userSubscribeInfo, entitlement) resp.List = append(resp.List, userSubscribeInfo) } return } +func fillUserSubscribeInfoEntitlementFields(sub *types.UserSubscribeInfo, entitlement *commonLogic.EntitlementContext) { + if sub == nil || entitlement == nil { + return + } + sub.EntitlementSource = entitlement.Source + sub.EntitlementOwnerUserId = entitlement.OwnerUserID + sub.ReadOnly = entitlement.ReadOnly +} + func (l *QueryUserSubscribeNodeListLogic) getServers(userSub *user.Subscribe) (userSubscribeNodes []*types.UserSubscribeNodeInfo, err error) { userSubscribeNodes = make([]*types.UserSubscribeNodeInfo, 0) if l.isSubscriptionExpired(userSub) { diff --git a/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic_test.go b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic_test.go new file mode 100644 index 0000000..b2cf915 --- /dev/null +++ b/internal/logic/public/subscribe/queryUserSubscribeNodeListLogic_test.go @@ -0,0 +1,25 @@ +package subscribe + +import ( + "testing" + + commonLogic "github.com/perfect-panel/server/internal/logic/common" + "github.com/perfect-panel/server/internal/types" + "github.com/stretchr/testify/require" +) + +func TestFillUserSubscribeInfoEntitlementFields(t *testing.T) { + sub := &types.UserSubscribeInfo{} + entitlement := &commonLogic.EntitlementContext{ + EffectiveUserID: 3001, + Source: commonLogic.EntitlementSourceFamilyOwner, + OwnerUserID: 3001, + ReadOnly: true, + } + + fillUserSubscribeInfoEntitlementFields(sub, entitlement) + + require.Equal(t, commonLogic.EntitlementSourceFamilyOwner, sub.EntitlementSource) + require.Equal(t, int64(3001), sub.EntitlementOwnerUserId) + require.True(t, sub.ReadOnly) +} diff --git a/internal/logic/public/user/familyBindingHelper.go b/internal/logic/public/user/familyBindingHelper.go index 0674d39..c8a8ed5 100644 --- a/internal/logic/public/user/familyBindingHelper.go +++ b/internal/logic/public/user/familyBindingHelper.go @@ -118,6 +118,7 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source result := &familyJoinResult{ OwnerUserId: ownerUserId, } + removedSubscribes := make([]user.Subscribe, 0) err := h.svcCtx.DB.WithContext(h.ctx).Transaction(func(tx *gorm.DB) error { ownerFamily, err := h.getOrCreateOwnerFamily(tx, ownerUserId) @@ -166,20 +167,24 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source if err = tx.Create(&memberRecord).Error; err != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseInsertError), "create family member failed") } - return nil + } else { + if memberRecord.FamilyId != ownerFamily.Id { + memberRecord.FamilyId = ownerFamily.Id + } + memberRecord.Status = user.FamilyMemberActive + memberRecord.Role = user.FamilyRoleMember + memberRecord.JoinSource = source + memberRecord.JoinedAt = now + memberRecord.LeftAt = nil + memberRecord.DeletedAt = gorm.DeletedAt{} + if err = tx.Unscoped().Save(&memberRecord).Error; err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update family member failed") + } } - if memberRecord.FamilyId != ownerFamily.Id { - memberRecord.FamilyId = ownerFamily.Id - } - memberRecord.Status = user.FamilyMemberActive - memberRecord.Role = user.FamilyRoleMember - memberRecord.JoinSource = source - memberRecord.JoinedAt = now - memberRecord.LeftAt = nil - memberRecord.DeletedAt = gorm.DeletedAt{} - if err = tx.Unscoped().Save(&memberRecord).Error; err != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "update family member failed") + removedSubscribes, err = clearMemberSubscribes(tx, memberUserId) + if err != nil { + return err } return nil }) @@ -187,9 +192,63 @@ func (h *familyBindingHelper) joinFamily(ownerUserId, memberUserId int64, source if err != nil { return nil, err } + + if err = h.clearRemovedMemberSubscribeCache(removedSubscribes); err != nil { + return nil, err + } return result, nil } +func clearMemberSubscribes(tx *gorm.DB, memberUserId int64) ([]user.Subscribe, error) { + var subscribes []user.Subscribe + if err := tx.Model(&user.Subscribe{}). + Where("user_id = ?", memberUserId). + Find(&subscribes).Error; err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "query member subscribe list failed") + } + if len(subscribes) == 0 { + return nil, nil + } + if err := tx.Where("user_id = ?", memberUserId).Delete(&user.Subscribe{}).Error; err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseUpdateError), "delete member subscribe list failed") + } + return subscribes, nil +} + +func (h *familyBindingHelper) clearRemovedMemberSubscribeCache(removedSubscribes []user.Subscribe) error { + if len(removedSubscribes) == 0 { + return nil + } + + subscribeModels, subscribeIDSet := buildRemovedSubscribeCacheMeta(removedSubscribes) + + if err := h.svcCtx.UserModel.ClearSubscribeCache(h.ctx, subscribeModels...); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear member subscribe cache failed") + } + for subscribeID := range subscribeIDSet { + if err := h.svcCtx.SubscribeModel.ClearCache(h.ctx, subscribeID); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear subscribe cache failed") + } + } + if err := h.svcCtx.NodeModel.ClearServerAllCache(h.ctx); err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.ERROR), "clear node cache failed") + } + + return nil +} + +func buildRemovedSubscribeCacheMeta(removedSubscribes []user.Subscribe) ([]*user.Subscribe, map[int64]struct{}) { + subscribeModels := make([]*user.Subscribe, 0, len(removedSubscribes)) + subscribeIDSet := make(map[int64]struct{}, len(removedSubscribes)) + for i := range removedSubscribes { + subscribeModels = append(subscribeModels, &removedSubscribes[i]) + if removedSubscribes[i].SubscribeId > 0 { + subscribeIDSet[removedSubscribes[i].SubscribeId] = struct{}{} + } + } + return subscribeModels, subscribeIDSet +} + func (h *familyBindingHelper) getOrCreateOwnerFamily(tx *gorm.DB, ownerUserId int64) (*user.UserFamily, error) { var ownerFamily user.UserFamily err := tx.Unscoped().Clauses(clause.Locking{Strength: "UPDATE"}). diff --git a/internal/logic/public/user/familyBindingHelper_test.go b/internal/logic/public/user/familyBindingHelper_test.go index c6ef4b3..40db128 100644 --- a/internal/logic/public/user/familyBindingHelper_test.go +++ b/internal/logic/public/user/familyBindingHelper_test.go @@ -105,3 +105,24 @@ func TestValidateMemberJoinConflict(t *testing.T) { }) } } + +func TestBuildRemovedSubscribeCacheMeta(t *testing.T) { + removed := []modelUser.Subscribe{ + {Id: 1, SubscribeId: 10, Token: "member-token-1"}, + {Id: 2, SubscribeId: 11, Token: "member-token-2"}, + {Id: 3, SubscribeId: 0, Token: "member-token-3"}, + } + + models, subscribeIDSet := buildRemovedSubscribeCacheMeta(removed) + + require.Len(t, models, 3) + require.Equal(t, int64(1), models[0].Id) + require.Equal(t, "member-token-2", models[1].Token) + require.Len(t, subscribeIDSet, 2) + _, has10 := subscribeIDSet[10] + _, has11 := subscribeIDSet[11] + _, has0 := subscribeIDSet[0] + require.True(t, has10) + require.True(t, has11) + require.False(t, has0) +} diff --git a/internal/logic/public/user/getSubscribeStatusLogic.go b/internal/logic/public/user/getSubscribeStatusLogic.go index 16f582a..5d08488 100644 --- a/internal/logic/public/user/getSubscribeStatusLogic.go +++ b/internal/logic/public/user/getSubscribeStatusLogic.go @@ -3,6 +3,7 @@ package user import ( "context" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" @@ -37,7 +38,8 @@ func (l *GetSubscribeStatusLogic) GetSubscribeStatus(req *types.GetSubscribeStat deviceStatus := false if len(u.UserDevices) > 0 { if dev, err := l.svcCtx.UserModel.FindOneDeviceByIdentifier(l.ctx, u.UserDevices[0].Identifier); err == nil && dev.Id > 0 { - subscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, dev.UserId) + effectiveUserID := l.resolveEntitlementUserID(dev.UserId) + subscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, effectiveUserID) if err == nil { deviceStatus = len(subscribes) > 0 } @@ -48,7 +50,8 @@ func (l *GetSubscribeStatusLogic) GetSubscribeStatus(req *types.GetSubscribeStat emailStatus := false if req.Email != "" { if auth, err := l.svcCtx.UserModel.FindUserAuthMethodByOpenID(l.ctx, "email", req.Email); err == nil && auth.Id > 0 { - subscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, auth.UserId) + effectiveUserID := l.resolveEntitlementUserID(auth.UserId) + subscribes, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, effectiveUserID) if err == nil { emailStatus = len(subscribes) > 0 } @@ -61,3 +64,12 @@ func (l *GetSubscribeStatusLogic) GetSubscribeStatus(req *types.GetSubscribeStat EmailStatus: emailStatus, }, nil } + +func (l *GetSubscribeStatusLogic) resolveEntitlementUserID(userID int64) int64 { + entitlement, err := commonLogic.ResolveEntitlementUser(l.ctx, l.svcCtx.DB, userID) + if err != nil { + l.Errorw("resolve family entitlement failed", logger.Field("user_id", userID), logger.Field("error", err.Error())) + return userID + } + return entitlement.EffectiveUserID +} diff --git a/internal/logic/public/user/preUnsubscribeLogic.go b/internal/logic/public/user/preUnsubscribeLogic.go index 729dcbc..602bdbb 100644 --- a/internal/logic/public/user/preUnsubscribeLogic.go +++ b/internal/logic/public/user/preUnsubscribeLogic.go @@ -3,9 +3,14 @@ package user import ( "context" + commonLogic "github.com/perfect-panel/server/internal/logic/common" + "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/internal/svc" "github.com/perfect-panel/server/internal/types" + "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/logger" + "github.com/perfect-panel/server/pkg/xerr" + "github.com/pkg/errors" ) type PreUnsubscribeLogic struct { @@ -24,6 +29,15 @@ func NewPreUnsubscribeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Pr } func (l *PreUnsubscribeLogic) PreUnsubscribe(req *types.PreUnsubscribeRequest) (resp *types.PreUnsubscribeResponse, err error) { + u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User) + if !ok { + logger.Error("current user is not found in context") + return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") + } + if err = commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return nil, err + } + remainingAmount, err := CalculateRemainingAmount(l.ctx, l.svcCtx, req.Id) if err != nil { l.Errorw("[PreUnsubscribeLogic] Calculate Remaining Amount Error:", logger.Field("err", err.Error())) diff --git a/internal/logic/public/user/queryUserSubscribeLogic.go b/internal/logic/public/user/queryUserSubscribeLogic.go index 7a1461f..fbbcaeb 100644 --- a/internal/logic/public/user/queryUserSubscribeLogic.go +++ b/internal/logic/public/user/queryUserSubscribeLogic.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/internal/model/user" @@ -37,7 +38,12 @@ func (l *QueryUserSubscribeLogic) QueryUserSubscribe() (resp *types.QueryUserSub logger.Error("current user is not found in context") return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } - data, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, u.Id, 0, 1, 2, 3) + entitlement, err := commonLogic.ResolveEntitlementUser(l.ctx, l.svcCtx.DB, u.Id) + if err != nil { + return nil, err + } + + data, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, entitlement.EffectiveUserID, 0, 1, 2, 3) if err != nil { l.Errorw("[QueryUserSubscribeLogic] Query User Subscribe Error:", logger.Field("err", err.Error())) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DatabaseQueryError), "Query User Subscribe Error") @@ -71,12 +77,22 @@ func (l *QueryUserSubscribeLogic) QueryUserSubscribe() (resp *types.QueryUserSub } } + fillUserSubscribeEntitlementFields(&sub, entitlement) sub.ResetTime = calculateNextResetTime(&sub) resp.List = append(resp.List, sub) } return } +func fillUserSubscribeEntitlementFields(sub *types.UserSubscribe, entitlement *commonLogic.EntitlementContext) { + if sub == nil || entitlement == nil { + return + } + sub.EntitlementSource = entitlement.Source + sub.EntitlementOwnerUserId = entitlement.OwnerUserID + sub.ReadOnly = entitlement.ReadOnly +} + // 计算下次重置时间 func calculateNextResetTime(sub *types.UserSubscribe) int64 { resetTime := time.UnixMilli(sub.ExpireTime) diff --git a/internal/logic/public/user/queryUserSubscribeLogic_test.go b/internal/logic/public/user/queryUserSubscribeLogic_test.go new file mode 100644 index 0000000..249f0af --- /dev/null +++ b/internal/logic/public/user/queryUserSubscribeLogic_test.go @@ -0,0 +1,25 @@ +package user + +import ( + "testing" + + commonLogic "github.com/perfect-panel/server/internal/logic/common" + "github.com/perfect-panel/server/internal/types" + "github.com/stretchr/testify/require" +) + +func TestFillUserSubscribeEntitlementFields(t *testing.T) { + sub := &types.UserSubscribe{} + entitlement := &commonLogic.EntitlementContext{ + EffectiveUserID: 2001, + Source: commonLogic.EntitlementSourceFamilyOwner, + OwnerUserID: 2001, + ReadOnly: true, + } + + fillUserSubscribeEntitlementFields(sub, entitlement) + + require.Equal(t, commonLogic.EntitlementSourceFamilyOwner, sub.EntitlementSource) + require.Equal(t, int64(2001), sub.EntitlementOwnerUserId) + require.True(t, sub.ReadOnly) +} diff --git a/internal/logic/public/user/resetUserSubscribeTokenLogic.go b/internal/logic/public/user/resetUserSubscribeTokenLogic.go index 56919f9..673d805 100644 --- a/internal/logic/public/user/resetUserSubscribeTokenLogic.go +++ b/internal/logic/public/user/resetUserSubscribeTokenLogic.go @@ -4,6 +4,7 @@ import ( "context" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/order" "github.com/perfect-panel/server/pkg/constant" @@ -40,6 +41,9 @@ func (l *ResetUserSubscribeTokenLogic) ResetUserSubscribeToken(req *types.ResetU logger.Error("current user is not found in context") return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err := commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return err + } userSub, err := l.svcCtx.UserModel.FindOneUserSubscribe(l.ctx, req.UserSubscribeId) if err != nil { l.Errorw("FindOneUserSubscribe failed:", logger.Field("error", err.Error())) diff --git a/internal/logic/public/user/unsubscribeLogic.go b/internal/logic/public/user/unsubscribeLogic.go index d3390fe..c87ca84 100644 --- a/internal/logic/public/user/unsubscribeLogic.go +++ b/internal/logic/public/user/unsubscribeLogic.go @@ -4,6 +4,7 @@ import ( "context" "time" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/log" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/tool" @@ -41,6 +42,9 @@ func (l *UnsubscribeLogic) Unsubscribe(req *types.UnsubscribeRequest) error { logger.Error("current user is not found in context") return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err := commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return err + } // find user subscription by ID userSub, err := l.svcCtx.UserModel.FindOneSubscribe(l.ctx, req.Id) diff --git a/internal/logic/public/user/updateUserSubscribeNoteLogic.go b/internal/logic/public/user/updateUserSubscribeNoteLogic.go index 3c43a8d..e4685b5 100644 --- a/internal/logic/public/user/updateUserSubscribeNoteLogic.go +++ b/internal/logic/public/user/updateUserSubscribeNoteLogic.go @@ -3,6 +3,7 @@ package user import ( "context" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/internal/model/user" @@ -35,6 +36,9 @@ func (l *UpdateUserSubscribeNoteLogic) UpdateUserSubscribeNote(req *types.Update logger.Error("current user is not found in context") return errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access") } + if err := commonLogic.DenyIfFamilyMemberReadonly(l.ctx, l.svcCtx.DB, u.Id); err != nil { + return err + } userSub, err := l.svcCtx.UserModel.FindOneUserSubscribe(l.ctx, req.UserSubscribeId) if err != nil { diff --git a/internal/logic/public/user/ws/deviceWsConnectLogic.go b/internal/logic/public/user/ws/deviceWsConnectLogic.go index 5024a7f..9e56e12 100644 --- a/internal/logic/public/user/ws/deviceWsConnectLogic.go +++ b/internal/logic/public/user/ws/deviceWsConnectLogic.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gin-gonic/gin" + commonLogic "github.com/perfect-panel/server/internal/logic/common" "github.com/perfect-panel/server/internal/model/user" "github.com/perfect-panel/server/pkg/constant" "github.com/perfect-panel/server/pkg/xerr" @@ -71,7 +72,12 @@ func (l *DeviceWsConnectLogic) DeviceWsConnect(c *gin.Context) error { } //默认在线设备1 maxDevice := 3 - subscribe, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, userInfo.Id, 1, 2) + entitlement, err := commonLogic.ResolveEntitlementUser(l.ctx, l.svcCtx.DB, userInfo.Id) + if err != nil { + return err + } + + subscribe, err := l.svcCtx.UserModel.QueryUserSubscribe(l.ctx, entitlement.EffectiveUserID, 1, 2) if err == nil { for _, sub := range subscribe { if time.Now().Before(sub.ExpireTime) { diff --git a/internal/types/types.go b/internal/types/types.go index e550b9f..7eeca2c 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -2826,24 +2826,27 @@ type UserStatisticsResponse struct { } type UserSubscribe struct { - Id int64 `json:"id"` - UserId int64 `json:"user_id"` - OrderId int64 `json:"order_id"` - SubscribeId int64 `json:"subscribe_id"` - Subscribe Subscribe `json:"subscribe"` - StartTime int64 `json:"start_time"` - ExpireTime int64 `json:"expire_time"` - FinishedAt int64 `json:"finished_at"` - ResetTime int64 `json:"reset_time"` - Traffic int64 `json:"traffic"` - Download int64 `json:"download"` - Upload int64 `json:"upload"` - Token string `json:"token"` - Status uint8 `json:"status"` - Short string `json:"short"` - IsGift bool `json:"is_gift"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + OrderId int64 `json:"order_id"` + SubscribeId int64 `json:"subscribe_id"` + Subscribe Subscribe `json:"subscribe"` + StartTime int64 `json:"start_time"` + ExpireTime int64 `json:"expire_time"` + FinishedAt int64 `json:"finished_at"` + ResetTime int64 `json:"reset_time"` + Traffic int64 `json:"traffic"` + Download int64 `json:"download"` + Upload int64 `json:"upload"` + Token string `json:"token"` + Status uint8 `json:"status"` + EntitlementSource string `json:"entitlement_source"` + EntitlementOwnerUserId int64 `json:"entitlement_owner_user_id"` + ReadOnly bool `json:"read_only"` + Short string `json:"short"` + IsGift bool `json:"is_gift"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` } type UserSubscribeDetail struct { @@ -2866,23 +2869,26 @@ type UserSubscribeDetail struct { } type UserSubscribeInfo struct { - Id int64 `json:"id"` - UserId int64 `json:"user_id"` - OrderId int64 `json:"order_id"` - SubscribeId int64 `json:"subscribe_id"` - StartTime int64 `json:"start_time"` - ExpireTime int64 `json:"expire_time"` - FinishedAt int64 `json:"finished_at"` - ResetTime int64 `json:"reset_time"` - Traffic int64 `json:"traffic"` - Download int64 `json:"download"` - Upload int64 `json:"upload"` - Token string `json:"token"` - Status uint8 `json:"status"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - IsTryOut bool `json:"is_try_out"` - Nodes []*UserSubscribeNodeInfo `json:"nodes"` + Id int64 `json:"id"` + UserId int64 `json:"user_id"` + OrderId int64 `json:"order_id"` + SubscribeId int64 `json:"subscribe_id"` + StartTime int64 `json:"start_time"` + ExpireTime int64 `json:"expire_time"` + FinishedAt int64 `json:"finished_at"` + ResetTime int64 `json:"reset_time"` + Traffic int64 `json:"traffic"` + Download int64 `json:"download"` + Upload int64 `json:"upload"` + Token string `json:"token"` + Status uint8 `json:"status"` + EntitlementSource string `json:"entitlement_source"` + EntitlementOwnerUserId int64 `json:"entitlement_owner_user_id"` + ReadOnly bool `json:"read_only"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + IsTryOut bool `json:"is_try_out"` + Nodes []*UserSubscribeNodeInfo `json:"nodes"` } type UserSubscribeLog struct {