501 lines
14 KiB
Go
501 lines
14 KiB
Go
package repository
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"math/rand"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
|
||
"github.com/topfans/backend/pkg/database"
|
||
"github.com/topfans/backend/pkg/models"
|
||
)
|
||
|
||
// SocialRepository 社交数据访问接口
|
||
// 目前包含好友功能,后续可扩展关注、点赞等功能
|
||
type SocialRepository interface {
|
||
// ========== 好友请求相关 ==========
|
||
|
||
// CreateRequest 创建好友请求
|
||
CreateRequest(request *models.FriendRequest) error
|
||
|
||
// GetRequestByID 根据ID获取好友请求
|
||
GetRequestByID(requestID int64) (*models.FriendRequest, error)
|
||
|
||
// GetRequestsByUser 获取用户的好友请求列表
|
||
// userID: 用户ID
|
||
// starID: 明星ID
|
||
// requestType: 请求类型(received=收到的, sent=发出的)
|
||
// status: 状态筛选(可选,空字符串表示所有状态)
|
||
// page: 页码(从1开始)
|
||
// pageSize: 每页数量
|
||
GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error)
|
||
|
||
// GetLatestRequest 获取两个用户之间最近的请求记录(用于防骚扰机制)
|
||
// fromUserID: 发送者ID
|
||
// toUserID: 接收者ID
|
||
// starID: 明星ID
|
||
GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error)
|
||
|
||
// UpdateRequestStatus 更新好友请求状态
|
||
// requestID: 请求ID
|
||
// status: 新状态
|
||
// processedAt: 处理时间(可选,传nil表示不更新)
|
||
UpdateRequestStatus(requestID int64, status string, processedAt *int64) error
|
||
|
||
// ========== 好友关系相关 ==========
|
||
|
||
// CreateFriendshipPair 创建双向好友关系(A→B 和 B→A)
|
||
// 内部使用事务保证原子性
|
||
CreateFriendshipPair(userID, friendID, starID int64) error
|
||
|
||
// CheckFriendship 检查是否为好友关系
|
||
CheckFriendship(userID, friendID, starID int64) (bool, error)
|
||
|
||
// GetFriendsByUser 获取用户的好友列表(支持分页和关键词搜索)
|
||
// userID: 用户ID
|
||
// starID: 明星ID
|
||
// keyword: 搜索关键词(搜索备注名或昵称,可选)
|
||
// page: 页码(从1开始)
|
||
// pageSize: 每页数量
|
||
GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error)
|
||
|
||
// DeleteFriendshipPair 删除双向好友关系(A→B 和 B→A)
|
||
// 内部使用事务保证原子性
|
||
DeleteFriendshipPair(userID, friendID, starID int64) error
|
||
|
||
// UpdateRemark 更新好友备注
|
||
UpdateRemark(userID, friendID, starID int64, remark string) error
|
||
|
||
// CountFriends 统计用户的好友数量
|
||
CountFriends(userID, starID int64) (int64, error)
|
||
|
||
// ========== 随机用户相关 ==========
|
||
|
||
// GetRandomUsersByStar 获取同一明星下的随机用户(基于偏移量算法)
|
||
// starID: 明星ID
|
||
// count: 返回数量(默认1,最大100)
|
||
// 返回: 随机用户列表(user_id, nickname)
|
||
GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error)
|
||
|
||
// GetUsersByStarPaged 分页获取同一明星下的用户列表
|
||
// starID: 明星ID
|
||
// excludeUserID: 要排除的用户ID(通常是当前用户,传0则不排除)
|
||
// page: 页码(从1开始)
|
||
// pageSize: 每页数量
|
||
// 返回: 用户列表(user_id, nickname, level, slot_limit)和总数
|
||
GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error)
|
||
|
||
// GetSlotCountsByUsers 批量获取用户的展位占用数量
|
||
// userIDs: 用户ID列表
|
||
// starID: 明星ID
|
||
// 返回: map[userID]occupiedSlots
|
||
GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error)
|
||
|
||
// GetFanProfileByUserIDAndStarID 根据userID和starID获取粉丝档案
|
||
GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error)
|
||
}
|
||
|
||
// RandomUserInfo is a minimal user summary returned by random-user sampling.
type RandomUserInfo struct {
	UserID   int64  // the sampled user's ID
	Nickname string // the user's nickname in the sampled star's fan profile
}
|
||
|
||
// PagedUserInfo is the per-user row returned by paged user listings.
type PagedUserInfo struct {
	UserID    int64  // the user's ID
	Nickname  string // nickname from the user's fan profile
	Level     int32  // fan-profile level
	SlotLimit int32  // maximum number of slots the user may have
}
|
||
|
||
// socialRepositoryImpl SocialRepository 的实现
|
||
type socialRepositoryImpl struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewSocialRepository 创建 SocialRepository 实例
|
||
func NewSocialRepository() SocialRepository {
|
||
return &socialRepositoryImpl{
|
||
db: database.GetDB(),
|
||
}
|
||
}
|
||
|
||
// ========== 好友请求相关实现 ==========
|
||
|
||
func (r *socialRepositoryImpl) CreateRequest(request *models.FriendRequest) error {
|
||
if request == nil {
|
||
return errors.New("request cannot be nil")
|
||
}
|
||
return r.db.Create(request).Error
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) GetRequestByID(requestID int64) (*models.FriendRequest, error) {
|
||
var request models.FriendRequest
|
||
err := r.db.Where("id = ?", requestID).First(&request).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil // 返回 nil 而不是错误,便于调用方判断
|
||
}
|
||
return nil, err
|
||
}
|
||
return &request, nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error) {
|
||
var requests []*models.FriendRequest
|
||
var total int64
|
||
|
||
query := r.db.Model(&models.FriendRequest{}).Where("star_id = ?", starID)
|
||
|
||
// 根据请求类型筛选
|
||
if requestType == models.FriendRequestTypeReceived {
|
||
query = query.Where("to_user_id = ?", userID)
|
||
} else if requestType == models.FriendRequestTypeSent {
|
||
query = query.Where("from_user_id = ?", userID)
|
||
} else {
|
||
return nil, 0, fmt.Errorf("invalid request type: %s", requestType)
|
||
}
|
||
|
||
// 根据状态筛选
|
||
if status != "" && status != "all" {
|
||
query = query.Where("status = ?", status)
|
||
}
|
||
|
||
// 统计总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 分页查询
|
||
offset := (page - 1) * pageSize
|
||
err := query.Order("created_at DESC").
|
||
Limit(pageSize).
|
||
Offset(offset).
|
||
Find(&requests).Error
|
||
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return requests, total, nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error) {
|
||
var request models.FriendRequest
|
||
err := r.db.Where("from_user_id = ? AND to_user_id = ? AND star_id = ?", fromUserID, toUserID, starID).
|
||
Order("created_at DESC").
|
||
First(&request).Error
|
||
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil // 返回 nil 而不是错误,表示没有找到记录
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &request, nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) UpdateRequestStatus(requestID int64, status string, processedAt *int64) error {
|
||
updates := map[string]interface{}{
|
||
"status": status,
|
||
}
|
||
|
||
if processedAt != nil {
|
||
updates["processed_at"] = *processedAt
|
||
}
|
||
|
||
result := r.db.Model(&models.FriendRequest{}).
|
||
Where("id = ?", requestID).
|
||
Updates(updates)
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return errors.New("friend request not found")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ========== 好友关系相关实现 ==========
|
||
|
||
func (r *socialRepositoryImpl) CreateFriendshipPair(userID, friendID, starID int64) error {
|
||
// 使用事务创建双向好友关系
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
// 创建 A→B
|
||
friendship1 := &models.Friendship{
|
||
UserID: userID,
|
||
FriendID: friendID,
|
||
StarID: starID,
|
||
Status: models.FriendshipStatusAccepted,
|
||
}
|
||
if err := tx.Create(friendship1).Error; err != nil {
|
||
return fmt.Errorf("failed to create friendship A→B: %w", err)
|
||
}
|
||
|
||
// 创建 B→A
|
||
friendship2 := &models.Friendship{
|
||
UserID: friendID,
|
||
FriendID: userID,
|
||
StarID: starID,
|
||
Status: models.FriendshipStatusAccepted,
|
||
}
|
||
if err := tx.Create(friendship2).Error; err != nil {
|
||
return fmt.Errorf("failed to create friendship B→A: %w", err)
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) CheckFriendship(userID, friendID, starID int64) (bool, error) {
|
||
var count int64
|
||
err := r.db.Model(&models.Friendship{}).
|
||
Where("user_id = ? AND friend_id = ? AND star_id = ? AND status = ?",
|
||
userID, friendID, starID, models.FriendshipStatusAccepted).
|
||
Count(&count).Error
|
||
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error) {
|
||
var friendships []*models.Friendship
|
||
var total int64
|
||
|
||
query := r.db.Model(&models.Friendship{}).
|
||
Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted)
|
||
|
||
// 如果有关键词,需要关联查询 fan_profiles 表
|
||
if keyword != "" {
|
||
keyword = "%" + keyword + "%"
|
||
// 关联查询昵称或备注
|
||
query = query.Where(
|
||
r.db.Where("remark LIKE ?", keyword).
|
||
Or("friend_id IN (?)",
|
||
r.db.Model(&models.FanProfile{}).
|
||
Select("user_id").
|
||
Where("star_id = ? AND nickname LIKE ?", starID, keyword),
|
||
),
|
||
)
|
||
}
|
||
|
||
// 统计总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 分页查询,按创建时间倒序
|
||
offset := (page - 1) * pageSize
|
||
err := query.Order("created_at DESC").
|
||
Limit(pageSize).
|
||
Offset(offset).
|
||
Find(&friendships).Error
|
||
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return friendships, total, nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) DeleteFriendshipPair(userID, friendID, starID int64) error {
|
||
// 使用事务删除双向好友关系
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
// 删除 A→B
|
||
result1 := tx.Where("user_id = ? AND friend_id = ? AND star_id = ?", userID, friendID, starID).
|
||
Delete(&models.Friendship{})
|
||
if result1.Error != nil {
|
||
return fmt.Errorf("failed to delete friendship A→B: %w", result1.Error)
|
||
}
|
||
|
||
// 删除 B→A
|
||
result2 := tx.Where("user_id = ? AND friend_id = ? AND star_id = ?", friendID, userID, starID).
|
||
Delete(&models.Friendship{})
|
||
if result2.Error != nil {
|
||
return fmt.Errorf("failed to delete friendship B→A: %w", result2.Error)
|
||
}
|
||
|
||
// 检查是否真的删除了记录
|
||
if result1.RowsAffected == 0 && result2.RowsAffected == 0 {
|
||
return errors.New("friendship not found")
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) UpdateRemark(userID, friendID, starID int64, remark string) error {
|
||
result := r.db.Model(&models.Friendship{}).
|
||
Where("user_id = ? AND friend_id = ? AND star_id = ?", userID, friendID, starID).
|
||
Update("remark", remark)
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return errors.New("friendship not found")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *socialRepositoryImpl) CountFriends(userID, starID int64) (int64, error) {
|
||
var count int64
|
||
err := r.db.Model(&models.Friendship{}).
|
||
Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted).
|
||
Count(&count).Error
|
||
|
||
return count, err
|
||
}
|
||
|
||
// ========== 随机用户相关实现 ==========
|
||
|
||
// GetRandomUsersByStar 获取同一明星下的随机用户(基于偏移量算法)
|
||
func (r *socialRepositoryImpl) GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error) {
|
||
// 1. 参数验证
|
||
if starID <= 0 {
|
||
return nil, errors.New("star_id must be greater than 0")
|
||
}
|
||
if count <= 0 {
|
||
count = 1 // 默认返回1个
|
||
}
|
||
if count > 100 {
|
||
count = 100 // 最大限制100个
|
||
}
|
||
|
||
// 2. 查询总数
|
||
var total int64
|
||
err := r.db.Model(&models.FanProfile{}).
|
||
Where("star_id = ? AND is_active = ?", starID, true).
|
||
Count(&total).Error
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to count fan profiles: %w", err)
|
||
}
|
||
|
||
if total == 0 {
|
||
return []*RandomUserInfo{}, nil // 没有数据,返回空列表
|
||
}
|
||
|
||
// 3. 生成随机偏移量并查询
|
||
// 使用当前时间作为随机种子,确保每次调用都有不同的随机性
|
||
rand.Seed(time.Now().UnixNano())
|
||
randomOffset := rand.Int63n(total) // [0, total-1]
|
||
|
||
var profiles []models.FanProfile
|
||
err = r.db.Model(&models.FanProfile{}).
|
||
Select("user_id", "nickname").
|
||
Where("star_id = ? AND is_active = ?", starID, true).
|
||
Order("id ASC").
|
||
Limit(count).
|
||
Offset(int(randomOffset)).
|
||
Find(&profiles).Error
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get random users: %w", err)
|
||
}
|
||
|
||
// 4. 转换为结果
|
||
result := make([]*RandomUserInfo, 0, len(profiles))
|
||
for _, profile := range profiles {
|
||
result = append(result, &RandomUserInfo{
|
||
UserID: profile.UserID,
|
||
Nickname: profile.Nickname,
|
||
})
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetUsersByStarPaged 分页获取同一明星下的用户列表
|
||
func (r *socialRepositoryImpl) GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error) {
|
||
offset := (page - 1) * pageSize
|
||
|
||
query := r.db.Model(&models.FanProfile{}).
|
||
Where("star_id = ?", starID)
|
||
|
||
if excludeUserID > 0 {
|
||
query = query.Where("user_id != ?", excludeUserID)
|
||
}
|
||
|
||
var total int64
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, fmt.Errorf("failed to count users: %w", err)
|
||
}
|
||
|
||
var profiles []models.FanProfile
|
||
err := query.
|
||
Select("user_id", "nickname", "level", "slot_limit").
|
||
Order("id DESC").
|
||
Limit(pageSize).
|
||
Offset(offset).
|
||
Find(&profiles).Error
|
||
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("failed to get paged users: %w", err)
|
||
}
|
||
|
||
result := make([]*PagedUserInfo, 0, len(profiles))
|
||
for _, profile := range profiles {
|
||
result = append(result, &PagedUserInfo{
|
||
UserID: profile.UserID,
|
||
Nickname: profile.Nickname,
|
||
Level: profile.Level,
|
||
SlotLimit: profile.SlotLimit,
|
||
})
|
||
}
|
||
|
||
return result, total, nil
|
||
}
|
||
|
||
// GetSlotCountsByUsers 批量获取用户的共享展位已占用数量
|
||
// 共享展位 = public visibility 的 BoothSlot
|
||
// 已占用 = BoothSlot 上有 Exhibition 记录
|
||
func (r *socialRepositoryImpl) GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error) {
|
||
if len(userIDs) == 0 {
|
||
return make(map[int64]int64), nil
|
||
}
|
||
|
||
// 统计每个用户已占用的共享展位数量(通过 Exhibition 表 JOIN BoothSlot)
|
||
var results []struct {
|
||
UserID int64
|
||
Count int64
|
||
}
|
||
err := r.db.Model(&models.Exhibition{}).
|
||
Select("bs.user_id, COUNT(*) as count").
|
||
Joins("JOIN booth_slots bs ON bs.slot_id = exhibitions.slot_id").
|
||
Where("bs.user_id IN ? AND bs.star_id = ? AND bs.visibility = ? AND bs.is_enabled = ?", userIDs, starID, "public", true).
|
||
Group("bs.user_id").
|
||
Find(&results).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
countMap := make(map[int64]int64)
|
||
for _, res := range results {
|
||
countMap[res.UserID] = res.Count
|
||
}
|
||
// 未找到记录的用户,默认为 0(不需要在 map 中明确设置)
|
||
return countMap, nil
|
||
}
|
||
|
||
// GetFanProfileByUserIDAndStarID 根据userID和starID获取粉丝档案
|
||
func (r *socialRepositoryImpl) GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error) {
|
||
var profile models.FanProfile
|
||
err := r.db.Where("user_id = ? AND star_id = ?", userID, starID).First(&profile).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &profile, nil
|
||
}
|