All checks were successful
Build docker and publish / build (20.15.1) (push) Successful in 7m34s
337 lines
9.9 KiB
Go
337 lines
9.9 KiB
Go
package user
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"sort"
|
||
"strings"
|
||
|
||
"github.com/perfect-panel/server/pkg/constant"
|
||
"github.com/perfect-panel/server/pkg/kutt"
|
||
"github.com/perfect-panel/server/pkg/uuidx"
|
||
"github.com/perfect-panel/server/pkg/xerr"
|
||
"github.com/pkg/errors"
|
||
|
||
"github.com/perfect-panel/server/internal/model/user"
|
||
"github.com/perfect-panel/server/internal/svc"
|
||
"github.com/perfect-panel/server/internal/types"
|
||
"github.com/perfect-panel/server/pkg/logger"
|
||
"github.com/perfect-panel/server/pkg/phone"
|
||
"github.com/perfect-panel/server/pkg/tool"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type QueryUserInfoLogic struct {
|
||
logger.Logger
|
||
ctx context.Context
|
||
svcCtx *svc.ServiceContext
|
||
}
|
||
|
||
// NewQueryUserInfoLogic creates the handler logic for the "query user info"
// endpoint, with its logger bound to ctx.
func NewQueryUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *QueryUserInfoLogic {
	l := &QueryUserInfoLogic{
		ctx:    ctx,
		svcCtx: svcCtx,
	}
	l.Logger = logger.WithContext(ctx)
	return l
}
|
||
|
||
func (l *QueryUserInfoLogic) QueryUserInfo() (resp *types.User, err error) {
|
||
resp = &types.User{}
|
||
u, ok := l.ctx.Value(constant.CtxKeyUser).(*user.User)
|
||
if !ok {
|
||
logger.Error("current user is not found in context")
|
||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.InvalidAccess), "Invalid Access")
|
||
}
|
||
tool.DeepCopy(resp, u)
|
||
for i, d := range u.UserDevices {
|
||
if i < len(resp.UserDevices) {
|
||
resp.UserDevices[i].DeviceID = tool.DeviceIdToHash(d.Id)
|
||
}
|
||
}
|
||
// refer_code 为空时自动生成
|
||
if resp.ReferCode == "" {
|
||
resp.ReferCode = uuidx.UserInviteCode(u.Id)
|
||
if err := l.svcCtx.DB.Model(&user.User{}).Where("id = ?", u.Id).Update("refer_code", resp.ReferCode).Error; err != nil {
|
||
l.Errorw("auto generate refer_code failed", logger.Field("user_id", u.Id), logger.Field("error", err.Error()))
|
||
} else {
|
||
_ = l.svcCtx.UserModel.ClearUserCache(l.ctx, u)
|
||
}
|
||
}
|
||
|
||
ownerEmailMethod := l.fillFamilyContext(resp, u.Id)
|
||
|
||
var userMethods []types.UserAuthMethod
|
||
for _, method := range resp.AuthMethods {
|
||
var item types.UserAuthMethod
|
||
tool.DeepCopy(&item, method)
|
||
|
||
switch method.AuthType {
|
||
case "mobile":
|
||
item.AuthIdentifier = phone.MaskPhoneNumber(method.AuthIdentifier)
|
||
case "email":
|
||
default:
|
||
item.AuthIdentifier = maskOpenID(method.AuthIdentifier)
|
||
}
|
||
userMethods = append(userMethods, item)
|
||
}
|
||
userMethods = appendFamilyOwnerEmailIfNeeded(userMethods, resp.FamilyJoined, ownerEmailMethod)
|
||
|
||
sortUserAuthMethodsByPriority(userMethods)
|
||
|
||
resp.AuthMethods = userMethods
|
||
|
||
// 生成邀请短链接
|
||
if l.svcCtx.Config.Kutt.Enable && resp.ReferCode != "" {
|
||
shortLink := l.generateInviteShortLink(resp.ReferCode)
|
||
if shortLink != "" {
|
||
resp.ShareLink = shortLink
|
||
}
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
func (l *QueryUserInfoLogic) fillFamilyContext(resp *types.User, userId int64) *user.AuthMethods {
|
||
type familyRelation struct {
|
||
FamilyId int64
|
||
Role uint8
|
||
FamilyStatus uint8
|
||
OwnerUserId int64
|
||
MaxMembers int64
|
||
}
|
||
|
||
var relation familyRelation
|
||
relationErr := l.svcCtx.DB.WithContext(l.ctx).
|
||
Table("user_family_member").
|
||
Select("user_family_member.family_id, user_family_member.role, user_family.status as family_status, user_family.owner_user_id, user_family.max_members").
|
||
Joins("JOIN user_family ON user_family.id = user_family_member.family_id AND user_family.deleted_at IS NULL").
|
||
Where("user_family_member.user_id = ? AND user_family_member.deleted_at IS NULL AND user_family_member.status = ?", userId, user.FamilyMemberActive).
|
||
First(&relation).Error
|
||
if relationErr != nil {
|
||
if !errors.Is(relationErr, gorm.ErrRecordNotFound) {
|
||
l.Errorw("query family relation failed", logger.Field("user_id", userId), logger.Field("error", relationErr.Error()))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
resp.FamilyJoined = true
|
||
resp.FamilyId = relation.FamilyId
|
||
resp.FamilyRole = relation.Role
|
||
resp.FamilyRoleName = getFamilyRoleName(relation.Role)
|
||
resp.FamilyOwnerUserId = relation.OwnerUserId
|
||
resp.FamilyStatus = getFamilyStatusName(relation.FamilyStatus)
|
||
resp.FamilyMaxMembers = relation.MaxMembers
|
||
|
||
var activeMemberCount int64
|
||
countErr := l.svcCtx.DB.WithContext(l.ctx).
|
||
Table("user_family_member").
|
||
Where("family_id = ? AND status = ? AND deleted_at IS NULL", relation.FamilyId, user.FamilyMemberActive).
|
||
Count(&activeMemberCount).Error
|
||
if countErr != nil {
|
||
l.Errorw("count family members failed", logger.Field("family_id", relation.FamilyId), logger.Field("error", countErr.Error()))
|
||
} else {
|
||
resp.FamilyMemberCount = activeMemberCount
|
||
}
|
||
|
||
ownerEmailMethod, ownerEmailErr := l.svcCtx.UserModel.FindUserAuthMethodByUserId(l.ctx, "email", relation.OwnerUserId)
|
||
if ownerEmailErr != nil {
|
||
if !errors.Is(ownerEmailErr, gorm.ErrRecordNotFound) {
|
||
l.Errorw("query family owner email failed", logger.Field("owner_user_id", relation.OwnerUserId), logger.Field("error", ownerEmailErr.Error()))
|
||
}
|
||
return nil
|
||
}
|
||
return ownerEmailMethod
|
||
}
|
||
|
||
// appendFamilyOwnerEmailIfNeeded adds the family owner's email as an "email"
// auth method (carrying the owner's Verified flag) when all of these hold:
// the user joined a family, the owner email method is present with a
// non-blank identifier, and methods has no non-blank email of its own.
// Otherwise methods is returned unchanged.
func appendFamilyOwnerEmailIfNeeded(methods []types.UserAuthMethod, familyJoined bool, ownerEmailMethod *user.AuthMethods) []types.UserAuthMethod {
	if familyJoined && ownerEmailMethod != nil {
		email := strings.TrimSpace(ownerEmailMethod.AuthIdentifier)
		if email != "" && !hasEmailAuthMethod(methods) {
			methods = append(methods, types.UserAuthMethod{
				AuthType:       "email",
				AuthIdentifier: email,
				Verified:       ownerEmailMethod.Verified,
			})
		}
	}
	return methods
}
|
||
|
||
// hasEmailAuthMethod reports whether methods contains an auth method of type
// "email" (compared case-insensitively, ignoring surrounding whitespace) whose
// identifier is not blank.
func hasEmailAuthMethod(methods []types.UserAuthMethod) bool {
	for i := range methods {
		isEmail := strings.EqualFold(strings.TrimSpace(methods[i].AuthType), "email")
		if isEmail && strings.TrimSpace(methods[i].AuthIdentifier) != "" {
			return true
		}
	}
	return false
}
|
||
|
||
// sortUserAuthMethodsByPriority orders methods in place by ascending
// getAuthTypePriority. The sort is stable, so methods with equal priority keep
// their relative order.
func sortUserAuthMethodsByPriority(methods []types.UserAuthMethod) {
	less := func(a, b int) bool {
		left := getAuthTypePriority(methods[a].AuthType)
		right := getAuthTypePriority(methods[b].AuthType)
		return left < right
	}
	sort.SliceStable(methods, less)
}
|
||
|
||
// getFamilyRoleName maps a family role code to its display name: "owner",
// "member", or "role_<n>" for any other value.
func getFamilyRoleName(role uint8) string {
	if role == user.FamilyRoleOwner {
		return "owner"
	}
	if role == user.FamilyRoleMember {
		return "member"
	}
	return fmt.Sprintf("role_%d", role)
}
|
||
|
||
func getFamilyStatusName(status uint8) string {
|
||
if status == user.FamilyStatusActive {
|
||
return "active"
|
||
}
|
||
return "disabled"
|
||
}
|
||
|
||
// customData is the subset of the site's CustomData JSON (Config.Site.CustomData)
// that this file reads.
type customData struct {
	// ShareUrl is the landing-page URL that invite short links point to.
	ShareUrl string `json:"shareUrl"`
	// Domain is the domain to use for generated short links.
	Domain string `json:"domain"`
}
|
||
|
||
// siteCustomData decodes Config.Site.CustomData into a customData.
//
// It reports false when CustomData is empty or is not valid JSON. A decode
// error is logged so that falling back to the Kutt settings does not silently
// hide a broken site configuration.
func (l *QueryUserInfoLogic) siteCustomData() (customData, bool) {
	raw := l.svcCtx.Config.Site.CustomData
	if raw == "" {
		return customData{}, false
	}
	var data customData
	if err := json.Unmarshal([]byte(raw), &data); err != nil {
		l.Sloww("Site.CustomData is not valid JSON, falling back to Kutt config",
			logger.Field("error", err.Error()))
		return customData{}, false
	}
	return data, true
}

// getShareUrl returns the landing-page URL for invite links: the "shareUrl"
// entry of Site.CustomData when it is set, otherwise Kutt.TargetURL.
func (l *QueryUserInfoLogic) getShareUrl() string {
	if data, ok := l.siteCustomData(); ok && data.ShareUrl != "" {
		return data.ShareUrl
	}
	return l.svcCtx.Config.Kutt.TargetURL
}

// getDomain returns the short-link domain: the "domain" entry of
// Site.CustomData when it is set, otherwise Kutt.Domain.
func (l *QueryUserInfoLogic) getDomain() string {
	if data, ok := l.siteCustomData(); ok && data.Domain != "" {
		return data.Domain
	}
	return l.svcCtx.Config.Kutt.Domain
}
|
||
|
||
// generateInviteShortLink 生成邀请短链接(带 Redis 缓存)
|
||
//
|
||
// 参数:
|
||
// - inviteCode: 邀请码
|
||
//
|
||
// 返回:
|
||
// - string: 短链接 URL,失败时返回空字符串
|
||
func (l *QueryUserInfoLogic) generateInviteShortLink(inviteCode string) string {
|
||
cfg := l.svcCtx.Config.Kutt
|
||
shareUrl := l.getShareUrl()
|
||
domain := l.getDomain()
|
||
|
||
// 检查必要配置
|
||
if cfg.ApiURL == "" || cfg.ApiKey == "" {
|
||
l.Sloww("Kutt config incomplete",
|
||
logger.Field("api_url", cfg.ApiURL != ""),
|
||
logger.Field("api_key", cfg.ApiKey != ""))
|
||
return ""
|
||
}
|
||
if shareUrl == "" {
|
||
l.Sloww("ShareUrl not configured in CustomData or Kutt.TargetURL")
|
||
return ""
|
||
}
|
||
|
||
// Redis 缓存 key
|
||
cacheKey := "cache:invite:short_link:" + inviteCode
|
||
|
||
// 1. 尝试从 Redis 缓存读取
|
||
cachedLink, err := l.svcCtx.Redis.Get(l.ctx, cacheKey).Result()
|
||
if err == nil && cachedLink != "" {
|
||
l.Debugw("Hit cache for invite short link",
|
||
logger.Field("invite_code", inviteCode),
|
||
logger.Field("short_link", cachedLink))
|
||
return cachedLink
|
||
}
|
||
|
||
// 2. 缓存未命中,调用 Kutt API 创建短链接
|
||
client := kutt.NewClient(cfg.ApiURL, cfg.ApiKey)
|
||
shortLink, err := client.CreateInviteShortLink(l.ctx, shareUrl, inviteCode, domain)
|
||
if err != nil {
|
||
l.Errorw("Failed to create short link",
|
||
logger.Field("error", err.Error()),
|
||
logger.Field("invite_code", inviteCode),
|
||
logger.Field("share_url", shareUrl))
|
||
return ""
|
||
}
|
||
|
||
// 3. 写入 Redis 缓存(永不过期,因为邀请码不变短链接也不会变)
|
||
if err := l.svcCtx.Redis.Set(l.ctx, cacheKey, shortLink, 0).Err(); err != nil {
|
||
l.Errorw("Failed to cache short link",
|
||
logger.Field("error", err.Error()),
|
||
logger.Field("invite_code", inviteCode))
|
||
// 缓存失败不影响返回
|
||
}
|
||
|
||
l.Infow("Created and cached invite short link",
|
||
logger.Field("invite_code", inviteCode),
|
||
logger.Field("short_link", shortLink),
|
||
logger.Field("share_url", shareUrl))
|
||
|
||
return shortLink
|
||
}
|
||
|
||
// getAuthTypePriority returns the sort rank of an auth type; lower sorts first.
// "email" ranks 1, "mobile" ranks 2, and every other type ranks 100. The match
// is exact and case-sensitive.
func getAuthTypePriority(authType string) int {
	rank := 100
	switch authType {
	case "email":
		rank = 1
	case "mobile":
		rank = 2
	}
	return rank
}
|
||
|
||
// maskOpenID 脱敏 OpenID,只保留前 3 和后 3 位
|
||
func maskOpenID(openID string) string {
|
||
length := len(openID)
|
||
if length <= 6 {
|
||
return "***" // 如果 ID 太短,直接返回 "***"
|
||
}
|
||
|
||
// 计算中间需要被替换的 `*` 数量
|
||
maskLength := length - 6
|
||
mask := make([]byte, maskLength)
|
||
for i := range mask {
|
||
mask[i] = '*'
|
||
}
|
||
|
||
// 组合脱敏后的 OpenID
|
||
return openID[:3] + string(mask) + openID[length-3:]
|
||
}
|