// File: topfans/backend/services/socialService/repository/social_repository.go
// (529 lines, 15 KiB, Go)
//
// Note: this file contains non-ASCII Unicode characters (for example the "→"
// arrow used in comments and error messages) that some viewers flag as
// possibly confusable with other characters.

package repository
import (
"errors"
"fmt"
"math/rand"
"time"
"gorm.io/gorm"
"github.com/topfans/backend/pkg/database"
"github.com/topfans/backend/pkg/models"
)
// SocialRepository is the data-access interface for social features.
// It currently covers friend requests and friendships (plus fan-profile
// lookups used by them); it can later be extended with follows, likes, etc.
type SocialRepository interface {
	// ========== Friend requests ==========

	// CreateRequest creates a friend request.
	CreateRequest(request *models.FriendRequest) error
	// GetRequestByID fetches a friend request by its ID.
	GetRequestByID(requestID int64) (*models.FriendRequest, error)
	// GetRequestsByUser lists a user's friend requests.
	//   userID:      user ID
	//   starID:      star ID
	//   requestType: request type (received = requests received by userID,
	//                sent = requests sent by userID)
	//   status:      optional status filter; an empty string means all statuses
	//   page:        page number, starting at 1
	//   pageSize:    number of items per page
	// It returns the requests on the requested page and the total match count.
	GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error)
	// GetLatestRequest returns the most recent request between two users
	// (used by the anti-harassment mechanism).
	//   fromUserID: sender ID
	//   toUserID:   recipient ID
	//   starID:     star ID
	GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error)
	// UpdateRequestStatus updates the status of a friend request.
	//   requestID:   request ID
	//   status:      new status
	//   processedAt: processing time (optional; pass nil to leave it unchanged)
	UpdateRequestStatus(requestID int64, status string, processedAt *int64) error

	// ========== Friendships ==========

	// CreateFriendshipPair creates the bidirectional friendship (A→B and B→A).
	// A transaction is used internally so both rows are created atomically.
	CreateFriendshipPair(userID, friendID, starID int64) error
	// CheckFriendship reports whether userID and friendID are friends.
	CheckFriendship(userID, friendID, starID int64) (bool, error)
	// GetFriendsByUser lists a user's friends (paginated, optional keyword search).
	//   userID:   user ID
	//   starID:   star ID
	//   keyword:  optional search keyword (matches the remark or the nickname)
	//   page:     page number, starting at 1
	//   pageSize: number of items per page
	GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error)
	// DeleteFriendshipPair deletes the bidirectional friendship (A→B and B→A).
	// A transaction is used internally so both rows are removed atomically.
	DeleteFriendshipPair(userID, friendID, starID int64) error
	// UpdateRemark updates the remark userID keeps for friendID.
	UpdateRemark(userID, friendID, starID int64, remark string) error
	// CountFriends returns how many friends the user has.
	CountFriends(userID, starID int64) (int64, error)

	// ========== Random users ==========

	// GetRandomUsersByStar returns random users following the same star
	// (offset-based algorithm).
	//   starID: star ID
	//   count:  number of users to return (default 1, maximum 100)
	// Returns a list of random users (user_id, nickname).
	GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error)
	// GetUsersByStarPaged returns a page of users following the same star.
	//   starID:        star ID
	//   excludeUserID: user ID to exclude (usually the current user); 0 means
	//                  no exclusion
	//   page:          page number, starting at 1
	//   pageSize:      number of items per page
	// Returns the users (user_id, nickname, level, slot_limit) and the total count.
	GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error)
	// GetSlotCountsByUsers returns the occupied booth-slot count for several
	// users in one call.
	//   userIDs: user IDs
	//   starID:  star ID
	// Returns map[userID]occupiedSlots.
	GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error)
	// GetFanProfileByUserIDAndStarID fetches a fan profile by userID and starID.
	GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error)
	// GetFanProfilesByNickname fuzzy-searches fan profiles by nickname within
	// the same star.
	//   starID:   star ID
	//   nickname: nickname keyword
	//   limit:    maximum number of results
	GetFanProfilesByNickname(starID int64, nickname string, limit int) ([]*models.FanProfile, error)
}
// RandomUserInfo is the minimal user record returned by GetRandomUsersByStar.
type RandomUserInfo struct {
	UserID   int64  // user ID (fan_profiles.user_id)
	Nickname string // nickname from the fan profile
}
// PagedUserInfo is one user row returned by GetUsersByStarPaged.
type PagedUserInfo struct {
	UserID    int64  // user ID (fan_profiles.user_id)
	Nickname  string // nickname from the fan profile
	Level     int32  // fan profile level
	SlotLimit int32  // maximum number of slots (fan_profiles.slot_limit)
}
// socialRepositoryImpl is the GORM-backed implementation of SocialRepository.
type socialRepositoryImpl struct {
	db *gorm.DB // database handle used for every query
}
// NewSocialRepository builds a SocialRepository that runs its queries on the
// handle returned by database.GetDB().
func NewSocialRepository() SocialRepository {
	repo := new(socialRepositoryImpl)
	repo.db = database.GetDB()
	return repo
}
// ========== Friend request implementation ==========
// CreateRequest inserts request as a new friend_requests row.
// A nil request is rejected before the database is touched.
func (r *socialRepositoryImpl) CreateRequest(request *models.FriendRequest) error {
	if request != nil {
		return r.db.Create(request).Error
	}
	return errors.New("request cannot be nil")
}
// GetRequestByID loads the friend request whose primary key is requestID.
//
// A missing row is reported as (nil, nil) rather than an error, so callers can
// test the pointer. Any other database error is returned unchanged.
func (r *socialRepositoryImpl) GetRequestByID(requestID int64) (*models.FriendRequest, error) {
	var found models.FriendRequest
	switch err := r.db.Where("id = ?", requestID).First(&found).Error; {
	case err == nil:
		return &found, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	default:
		return nil, err
	}
}
// GetRequestsByUser returns one page of friend requests involving userID under
// starID, plus the total number of matching requests.
//
// requestType selects the direction:
//   - models.FriendRequestTypeReceived: requests addressed to userID (to_user_id)
//   - models.FriendRequestTypeSent:     requests sent by userID (from_user_id)
//
// Any other value is rejected before the database is queried.
// status filters on the request status unless it is "" or "all".
// page is 1-based; values below 1 are treated as the first page.
//
// Results are ordered newest first. id is used as a secondary key so that
// requests sharing the same created_at value keep a deterministic order.
// Without it, offset pagination could repeat or skip rows between pages.
func (r *socialRepositoryImpl) GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error) {
	query := r.db.Model(&models.FriendRequest{}).Where("star_id = ?", starID)

	// Filter by direction.
	switch requestType {
	case models.FriendRequestTypeReceived:
		query = query.Where("to_user_id = ?", userID)
	case models.FriendRequestTypeSent:
		query = query.Where("from_user_id = ?", userID)
	default:
		return nil, 0, fmt.Errorf("invalid request type: %s", requestType)
	}

	// Optional status filter; "all" is accepted as an explicit "no filter".
	if status != "" && status != "all" {
		query = query.Where("status = ?", status)
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	var requests []*models.FriendRequest
	err := query.Order("created_at DESC, id DESC").
		Limit(pageSize).
		Offset(offset).
		Find(&requests).Error
	if err != nil {
		return nil, 0, err
	}
	return requests, total, nil
}
// GetLatestRequest returns the most recent friend request sent from fromUserID
// to toUserID under starID. It returns (nil, nil) when no such request exists,
// so callers (e.g. the anti-harassment check) can test the pointer directly.
//
// Ordering is created_at DESC, then id DESC. The explicit id tie-breaker is
// required. gorm's First always appends an ascending primary-key ORDER BY, so
// with only "created_at DESC" two requests sharing a created_at value would
// resolve to the OLDER one instead of the latest.
func (r *socialRepositoryImpl) GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error) {
	var request models.FriendRequest
	err := r.db.Where("from_user_id = ? AND to_user_id = ? AND star_id = ?", fromUserID, toUserID, starID).
		Order("created_at DESC, id DESC").
		First(&request).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil // no request between these users yet
		}
		return nil, err
	}
	return &request, nil
}
// UpdateRequestStatus sets the status of friend request requestID and, when
// processedAt is non-nil, its processed_at value. A nil processedAt leaves
// processed_at untouched.
//
// Returns "friend request not found" only when no row with that id exists.
func (r *socialRepositoryImpl) UpdateRequestStatus(requestID int64, status string, processedAt *int64) error {
	updates := map[string]interface{}{
		"status": status,
	}
	if processedAt != nil {
		updates["processed_at"] = *processedAt
	}
	result := r.db.Model(&models.FriendRequest{}).
		Where("id = ?", requestID).
		Updates(updates)
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected != 0 {
		return nil
	}
	// Zero affected rows is ambiguous. MySQL, by default, counts only rows
	// whose values actually changed, so re-writing the current status would
	// look like a missing row. Confirm the row is really absent before failing.
	var matched int64
	if err := r.db.Model(&models.FriendRequest{}).
		Where("id = ?", requestID).
		Count(&matched).Error; err != nil {
		return err
	}
	if matched == 0 {
		return errors.New("friend request not found")
	}
	return nil
}
// ========== Friendship implementation ==========
// CreateFriendshipPair inserts both directed friendship rows, userID→friendID
// and friendID→userID, with status models.FriendshipStatusAccepted.
// Both inserts run in one transaction: if either fails, neither is kept. The
// error is wrapped with the direction that failed.
func (r *socialRepositoryImpl) CreateFriendshipPair(userID, friendID, starID int64) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		edges := []struct {
			owner, other int64
			errFormat    string
		}{
			{userID, friendID, "failed to create friendship A→B: %w"},
			{friendID, userID, "failed to create friendship B→A: %w"},
		}
		for _, edge := range edges {
			row := &models.Friendship{
				UserID:   edge.owner,
				FriendID: edge.other,
				StarID:   starID,
				Status:   models.FriendshipStatusAccepted,
			}
			if err := tx.Create(row).Error; err != nil {
				return fmt.Errorf(edge.errFormat, err)
			}
		}
		return nil
	})
}
// CheckFriendship reports whether an accepted userID→friendID friendship row
// exists under starID. Only the single direction given is inspected.
func (r *socialRepositoryImpl) CheckFriendship(userID, friendID, starID int64) (bool, error) {
	var matches int64
	res := r.db.Model(&models.Friendship{}).
		Where("user_id = ? AND friend_id = ? AND star_id = ? AND status = ?",
			userID, friendID, starID, models.FriendshipStatusAccepted).
		Count(&matches)
	if res.Error != nil {
		return false, res.Error
	}
	return matches > 0, nil
}
// GetFriendsByUser returns one page of userID's accepted friendships under
// starID, ordered by created_at descending, plus the total match count.
//
// When keyword is non-empty, a friendship matches if its remark contains the
// keyword, or if the friend's fan_profiles nickname (same star) contains it.
// NOTE(review): keyword goes into a LIKE pattern unescaped, so '%' and '_'
// typed by the user act as wildcards.
func (r *socialRepositoryImpl) GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error) {
	base := r.db.Model(&models.Friendship{}).
		Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted)

	if keyword != "" {
		pattern := "%" + keyword + "%"
		// Subquery: IDs of fans of this star whose nickname matches.
		nicknameHits := r.db.Model(&models.FanProfile{}).
			Select("user_id").
			Where("star_id = ? AND nickname LIKE ?", starID, pattern)
		// Grouped condition: (remark LIKE ... OR friend_id IN (subquery)).
		remarkOrNickname := r.db.Where("remark LIKE ?", pattern).
			Or("friend_id IN (?)", nicknameHits)
		base = base.Where(remarkOrNickname)
	}

	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var friendships []*models.Friendship
	if err := base.Order("created_at DESC").
		Limit(pageSize).
		Offset((page - 1) * pageSize).
		Find(&friendships).Error; err != nil {
		return nil, 0, err
	}
	return friendships, total, nil
}
// DeleteFriendshipPair removes both directed friendship rows, userID→friendID
// and friendID→userID, under starID inside one transaction.
//
// A database error on either delete aborts the transaction and is wrapped with
// the direction that failed. If neither direction deleted anything, the
// transaction is rolled back with "friendship not found". If only one
// direction existed, its deletion counts as success.
func (r *socialRepositoryImpl) DeleteFriendshipPair(userID, friendID, starID int64) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		directions := []struct {
			owner, other int64
			errFormat    string
		}{
			{userID, friendID, "failed to delete friendship A→B: %w"},
			{friendID, userID, "failed to delete friendship B→A: %w"},
		}
		removedAny := false
		for _, dir := range directions {
			res := tx.Where("user_id = ? AND friend_id = ? AND star_id = ?", dir.owner, dir.other, starID).
				Delete(&models.Friendship{})
			if res.Error != nil {
				return fmt.Errorf(dir.errFormat, res.Error)
			}
			if res.RowsAffected != 0 {
				removedAny = true
			}
		}
		if !removedAny {
			return errors.New("friendship not found")
		}
		return nil
	})
}
// UpdateRemark sets the remark that userID keeps for friendID under starID.
// Only the userID→friendID row is modified.
//
// Returns "friendship not found" only when that row does not exist.
func (r *socialRepositoryImpl) UpdateRemark(userID, friendID, starID int64, remark string) error {
	const cond = "user_id = ? AND friend_id = ? AND star_id = ?"
	result := r.db.Model(&models.Friendship{}).
		Where(cond, userID, friendID, starID).
		Update("remark", remark)
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected != 0 {
		return nil
	}
	// Zero affected rows is ambiguous. MySQL, by default, counts only rows
	// whose values actually changed, so re-saving the current remark would
	// otherwise be reported as "not found". Confirm the row is really absent.
	var existing int64
	if err := r.db.Model(&models.Friendship{}).
		Where(cond, userID, friendID, starID).
		Count(&existing).Error; err != nil {
		return err
	}
	if existing == 0 {
		return errors.New("friendship not found")
	}
	return nil
}
// CountFriends returns the number of accepted friendships owned by userID
// under starID. On error the count is returned alongside the error unchanged.
func (r *socialRepositoryImpl) CountFriends(userID, starID int64) (int64, error) {
	var n int64
	res := r.db.Model(&models.Friendship{}).
		Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted).
		Count(&n)
	return n, res.Error
}
// ========== Random user implementation ==========
// GetRandomUsersByStar returns up to count active fans of starID (user_id and
// nickname). It takes a window of rows that starts at a random offset in the
// id-ordered list of active fan profiles.
//
// starID must be positive. count defaults to 1 when <= 0 and is capped at 100.
// An empty (non-nil) slice is returned when the star has no active fans.
//
// The result is a contiguous window (offset-based sampling), not an independent
// random sample. The window start is drawn from [0, total-count], not
// [0, total-1]. That way a start near the end cannot truncate the page, and
// min(count, total) users come back as long as the table does not change
// between the count and the fetch.
func (r *socialRepositoryImpl) GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error) {
	// 1. Validate and normalize arguments.
	if starID <= 0 {
		return nil, errors.New("star_id must be greater than 0")
	}
	if count <= 0 {
		count = 1 // default: one user
	}
	if count > 100 {
		count = 100 // hard upper bound
	}

	// 2. Count candidate rows.
	var total int64
	err := r.db.Model(&models.FanProfile{}).
		Where("star_id = ? AND is_active = ?", starID, true).
		Count(&total).Error
	if err != nil {
		return nil, fmt.Errorf("failed to count fan profiles: %w", err)
	}
	if total == 0 {
		return []*RandomUserInfo{}, nil // nothing to sample from
	}

	// 3. Pick a window start that still leaves room for a full window.
	maxStart := total - int64(count)
	if maxStart < 0 {
		maxStart = 0
	}
	// Use a per-call generator instead of re-seeding math/rand's process-wide
	// source. rand.Seed is deprecated since Go 1.20, ignored by default since
	// Go 1.24, and re-seeding would affect every other user of the global source.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	randomOffset := rng.Int63n(maxStart + 1) // [0, maxStart]

	var profiles []models.FanProfile
	err = r.db.Model(&models.FanProfile{}).
		Select("user_id", "nickname").
		Where("star_id = ? AND is_active = ?", starID, true).
		Order("id ASC").
		Limit(count).
		Offset(int(randomOffset)).
		Find(&profiles).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get random users: %w", err)
	}

	// 4. Project to the lightweight result type.
	result := make([]*RandomUserInfo, 0, len(profiles))
	for _, profile := range profiles {
		result = append(result, &RandomUserInfo{
			UserID:   profile.UserID,
			Nickname: profile.Nickname,
		})
	}
	return result, nil
}
// GetUsersByStarPaged returns one page of fan profiles for starID, newest id
// first, plus the total number of matching profiles.
//
// When excludeUserID > 0 that user is left out (typically the caller).
// Only user_id, nickname, level and slot_limit are loaded. Inactive profiles
// are NOT filtered out here.
func (r *socialRepositoryImpl) GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error) {
	scoped := r.db.Model(&models.FanProfile{}).Where("star_id = ?", starID)
	if excludeUserID > 0 {
		scoped = scoped.Where("user_id != ?", excludeUserID)
	}

	var total int64
	if err := scoped.Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to count users: %w", err)
	}

	var rows []models.FanProfile
	if err := scoped.
		Select("user_id", "nickname", "level", "slot_limit").
		Order("id DESC").
		Limit(pageSize).
		Offset((page - 1) * pageSize).
		Find(&rows).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to get paged users: %w", err)
	}

	users := make([]*PagedUserInfo, len(rows))
	for i := range rows {
		users[i] = &PagedUserInfo{
			UserID:    rows[i].UserID,
			Nickname:  rows[i].Nickname,
			Level:     rows[i].Level,
			SlotLimit: rows[i].SlotLimit,
		}
	}
	return users, total, nil
}
// GetSlotCountsByUsers returns, for each of userIDs, how many shared booth
// slots under starID currently hold an exhibition.
//
// A "shared" slot is a booth_slots row with visibility "public" and
// is_enabled true. A slot counts as occupied when an exhibitions row
// references it. Users with no occupied shared slots are absent from the map
// (read them as 0). An empty userIDs yields an empty map without querying.
func (r *socialRepositoryImpl) GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error) {
	if len(userIDs) == 0 {
		return make(map[int64]int64), nil
	}

	// One row per user: (user_id, number of occupied shared slots).
	var perUser []struct {
		UserID int64
		Count  int64
	}
	query := r.db.Model(&models.Exhibition{}).
		Select("bs.user_id, COUNT(*) as count").
		Joins("JOIN booth_slots bs ON bs.slot_id = exhibitions.slot_id").
		Where("bs.user_id IN ? AND bs.star_id = ? AND bs.visibility = ? AND bs.is_enabled = ?", userIDs, starID, "public", true).
		Group("bs.user_id")
	if err := query.Find(&perUser).Error; err != nil {
		return nil, err
	}

	occupied := make(map[int64]int64, len(perUser))
	for i := range perUser {
		occupied[perUser[i].UserID] = perUser[i].Count
	}
	return occupied, nil
}
// GetFanProfileByUserIDAndStarID loads the fan profile for (userID, starID).
//
// Unlike GetRequestByID, a missing row is NOT mapped to (nil, nil). gorm's
// error (gorm.ErrRecordNotFound) is returned as-is, so callers should check it
// with errors.Is.
func (r *socialRepositoryImpl) GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error) {
	profile := new(models.FanProfile)
	if err := r.db.Where("user_id = ? AND star_id = ?", userID, starID).First(profile).Error; err != nil {
		return nil, err
	}
	return profile, nil
}
// GetFanProfilesByNickname searches active fan profiles of starID whose
// nickname contains nickname (SQL LIKE '%nickname%').
//
// Results are ordered by level, then updated_at, both descending. limit
// defaults to 10 when <= 0 and is capped at 50.
// NOTE(review): nickname is not escaped, so '%' and '_' act as wildcards.
func (r *socialRepositoryImpl) GetFanProfilesByNickname(starID int64, nickname string, limit int) ([]*models.FanProfile, error) {
	switch {
	case limit <= 0:
		limit = 10 // default page size
	case limit > 50:
		limit = 50 // hard upper bound
	}

	pattern := "%" + nickname + "%"
	var profiles []*models.FanProfile
	if err := r.db.Where("star_id = ? AND is_active = ? AND nickname LIKE ?", starID, true, pattern).
		Order("level DESC, updated_at DESC").
		Limit(limit).
		Find(&profiles).Error; err != nil {
		return nil, err
	}
	return profiles, nil
}