package repository import ( "errors" "fmt" "math/rand" "time" "gorm.io/gorm" "github.com/topfans/backend/pkg/database" "github.com/topfans/backend/pkg/models" ) // SocialRepository 社交数据访问接口 // 目前包含好友功能,后续可扩展关注、点赞等功能 type SocialRepository interface { // ========== 好友请求相关 ========== // CreateRequest 创建好友请求 CreateRequest(request *models.FriendRequest) error // GetRequestByID 根据ID获取好友请求 GetRequestByID(requestID int64) (*models.FriendRequest, error) // GetRequestsByUser 获取用户的好友请求列表 // userID: 用户ID // starID: 明星ID // requestType: 请求类型(received=收到的, sent=发出的) // status: 状态筛选(可选,空字符串表示所有状态) // page: 页码(从1开始) // pageSize: 每页数量 GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error) // GetLatestRequest 获取两个用户之间最近的请求记录(用于防骚扰机制) // fromUserID: 发送者ID // toUserID: 接收者ID // starID: 明星ID GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error) // UpdateRequestStatus 更新好友请求状态 // requestID: 请求ID // status: 新状态 // processedAt: 处理时间(可选,传nil表示不更新) UpdateRequestStatus(requestID int64, status string, processedAt *int64) error // ========== 好友关系相关 ========== // CreateFriendshipPair 创建双向好友关系(A→B 和 B→A) // 内部使用事务保证原子性 CreateFriendshipPair(userID, friendID, starID int64) error // CheckFriendship 检查是否为好友关系 CheckFriendship(userID, friendID, starID int64) (bool, error) // GetFriendsByUser 获取用户的好友列表(支持分页和关键词搜索) // userID: 用户ID // starID: 明星ID // keyword: 搜索关键词(搜索备注名或昵称,可选) // page: 页码(从1开始) // pageSize: 每页数量 GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error) // DeleteFriendshipPair 删除双向好友关系(A→B 和 B→A) // 内部使用事务保证原子性 DeleteFriendshipPair(userID, friendID, starID int64) error // UpdateRemark 更新好友备注 UpdateRemark(userID, friendID, starID int64, remark string) error // CountFriends 统计用户的好友数量 CountFriends(userID, starID int64) (int64, error) // ========== 随机用户相关 ========== // GetRandomUsersByStar 获取同一明星下的随机用户(基于偏移量算法) // starID: 明星ID // count: 返回数量(默认1,最大100) // 返回: 随机用户列表(user_id, nickname) GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error) // GetUsersByStarPaged 分页获取同一明星下的用户列表 // starID: 明星ID // excludeUserID: 要排除的用户ID(通常是当前用户,传0则不排除) // page: 页码(从1开始) // pageSize: 每页数量 // 返回: 用户列表(user_id, nickname, level, slot_limit)和总数 GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error) // GetSlotCountsByUsers 批量获取用户的展位占用数量 // userIDs: 用户ID列表 // starID: 明星ID // 返回: map[userID]occupiedSlots GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error) // GetFanProfileByUserIDAndStarID 根据userID和starID获取粉丝档案 GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error) } // RandomUserInfo 随机用户信息 type RandomUserInfo struct { UserID int64 Nickname string } // PagedUserInfo 分页用户信息 type PagedUserInfo struct { UserID int64 Nickname string Level int32 SlotLimit int32 // 槽位上限 } // socialRepositoryImpl SocialRepository 的实现 type socialRepositoryImpl struct { db *gorm.DB } // NewSocialRepository 创建 SocialRepository 实例 func NewSocialRepository() SocialRepository { return &socialRepositoryImpl{ db: database.GetDB(), } } // ========== 好友请求相关实现 ========== func (r *socialRepositoryImpl) CreateRequest(request *models.FriendRequest) error { if request == nil { return errors.New("request cannot be nil") } return r.db.Create(request).Error } func (r *socialRepositoryImpl) GetRequestByID(requestID int64) (*models.FriendRequest, error) { var request models.FriendRequest err := r.db.Where("id = ?", requestID).First(&request).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil // 返回 nil 而不是错误,便于调用方判断 } return nil, err } return &request, nil } func (r *socialRepositoryImpl) GetRequestsByUser(userID, starID int64, requestType string, status string, page, pageSize int) ([]*models.FriendRequest, int64, error) { var requests []*models.FriendRequest var total int64 query := r.db.Model(&models.FriendRequest{}).Where("star_id = ?", starID) // 根据请求类型筛选 if requestType == models.FriendRequestTypeReceived { query = query.Where("to_user_id = ?", userID) } else if requestType == models.FriendRequestTypeSent { query = query.Where("from_user_id = ?", userID) } else { return nil, 0, fmt.Errorf("invalid request type: %s", requestType) } // 根据状态筛选 if status != "" && status != "all" { query = query.Where("status = ?", status) } // 统计总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 分页查询 offset := (page - 1) * pageSize err := query.Order("created_at DESC"). Limit(pageSize). Offset(offset). Find(&requests).Error if err != nil { return nil, 0, err } return requests, total, nil } func (r *socialRepositoryImpl) GetLatestRequest(fromUserID, toUserID, starID int64) (*models.FriendRequest, error) { var request models.FriendRequest err := r.db.Where("from_user_id = ? AND to_user_id = ? AND star_id = ?", fromUserID, toUserID, starID). Order("created_at DESC"). First(&request).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil // 返回 nil 而不是错误,表示没有找到记录 } return nil, err } return &request, nil } func (r *socialRepositoryImpl) UpdateRequestStatus(requestID int64, status string, processedAt *int64) error { updates := map[string]interface{}{ "status": status, } if processedAt != nil { updates["processed_at"] = *processedAt } result := r.db.Model(&models.FriendRequest{}). Where("id = ?", requestID). Updates(updates) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return errors.New("friend request not found") } return nil } // ========== 好友关系相关实现 ========== func (r *socialRepositoryImpl) CreateFriendshipPair(userID, friendID, starID int64) error { // 使用事务创建双向好友关系 return r.db.Transaction(func(tx *gorm.DB) error { // 创建 A→B friendship1 := &models.Friendship{ UserID: userID, FriendID: friendID, StarID: starID, Status: models.FriendshipStatusAccepted, } if err := tx.Create(friendship1).Error; err != nil { return fmt.Errorf("failed to create friendship A→B: %w", err) } // 创建 B→A friendship2 := &models.Friendship{ UserID: friendID, FriendID: userID, StarID: starID, Status: models.FriendshipStatusAccepted, } if err := tx.Create(friendship2).Error; err != nil { return fmt.Errorf("failed to create friendship B→A: %w", err) } return nil }) } func (r *socialRepositoryImpl) CheckFriendship(userID, friendID, starID int64) (bool, error) { var count int64 err := r.db.Model(&models.Friendship{}). Where("user_id = ? AND friend_id = ? AND star_id = ? AND status = ?", userID, friendID, starID, models.FriendshipStatusAccepted). Count(&count).Error if err != nil { return false, err } return count > 0, nil } func (r *socialRepositoryImpl) GetFriendsByUser(userID, starID int64, keyword string, page, pageSize int) ([]*models.Friendship, int64, error) { var friendships []*models.Friendship var total int64 query := r.db.Model(&models.Friendship{}). Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted) // 如果有关键词,需要关联查询 fan_profiles 表 if keyword != "" { keyword = "%" + keyword + "%" // 关联查询昵称或备注 query = query.Where( r.db.Where("remark LIKE ?", keyword). Or("friend_id IN (?)", r.db.Model(&models.FanProfile{}). Select("user_id"). Where("star_id = ? AND nickname LIKE ?", starID, keyword), ), ) } // 统计总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 分页查询,按创建时间倒序 offset := (page - 1) * pageSize err := query.Order("created_at DESC"). Limit(pageSize). Offset(offset). Find(&friendships).Error if err != nil { return nil, 0, err } return friendships, total, nil } func (r *socialRepositoryImpl) DeleteFriendshipPair(userID, friendID, starID int64) error { // 使用事务删除双向好友关系 return r.db.Transaction(func(tx *gorm.DB) error { // 删除 A→B result1 := tx.Where("user_id = ? AND friend_id = ? AND star_id = ?", userID, friendID, starID). Delete(&models.Friendship{}) if result1.Error != nil { return fmt.Errorf("failed to delete friendship A→B: %w", result1.Error) } // 删除 B→A result2 := tx.Where("user_id = ? AND friend_id = ? AND star_id = ?", friendID, userID, starID). Delete(&models.Friendship{}) if result2.Error != nil { return fmt.Errorf("failed to delete friendship B→A: %w", result2.Error) } // 检查是否真的删除了记录 if result1.RowsAffected == 0 && result2.RowsAffected == 0 { return errors.New("friendship not found") } return nil }) } func (r *socialRepositoryImpl) UpdateRemark(userID, friendID, starID int64, remark string) error { result := r.db.Model(&models.Friendship{}). Where("user_id = ? AND friend_id = ? AND star_id = ?", userID, friendID, starID). Update("remark", remark) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return errors.New("friendship not found") } return nil } func (r *socialRepositoryImpl) CountFriends(userID, starID int64) (int64, error) { var count int64 err := r.db.Model(&models.Friendship{}). Where("user_id = ? AND star_id = ? AND status = ?", userID, starID, models.FriendshipStatusAccepted). Count(&count).Error return count, err } // ========== 随机用户相关实现 ========== // GetRandomUsersByStar 获取同一明星下的随机用户(基于偏移量算法) func (r *socialRepositoryImpl) GetRandomUsersByStar(starID int64, count int) ([]*RandomUserInfo, error) { // 1. 参数验证 if starID <= 0 { return nil, errors.New("star_id must be greater than 0") } if count <= 0 { count = 1 // 默认返回1个 } if count > 100 { count = 100 // 最大限制100个 } // 2. 查询总数 var total int64 err := r.db.Model(&models.FanProfile{}). Where("star_id = ? AND is_active = ?", starID, true). Count(&total).Error if err != nil { return nil, fmt.Errorf("failed to count fan profiles: %w", err) } if total == 0 { return []*RandomUserInfo{}, nil // 没有数据,返回空列表 } // 3. 生成随机偏移量并查询 // 使用当前时间作为随机种子,确保每次调用都有不同的随机性 rand.Seed(time.Now().UnixNano()) randomOffset := rand.Int63n(total) // [0, total-1] var profiles []models.FanProfile err = r.db.Model(&models.FanProfile{}). Select("user_id", "nickname"). Where("star_id = ? AND is_active = ?", starID, true). Order("id ASC"). Limit(count). Offset(int(randomOffset)). Find(&profiles).Error if err != nil { return nil, fmt.Errorf("failed to get random users: %w", err) } // 4. 转换为结果 result := make([]*RandomUserInfo, 0, len(profiles)) for _, profile := range profiles { result = append(result, &RandomUserInfo{ UserID: profile.UserID, Nickname: profile.Nickname, }) } return result, nil } // GetUsersByStarPaged 分页获取同一明星下的用户列表 func (r *socialRepositoryImpl) GetUsersByStarPaged(starID int64, excludeUserID int64, page, pageSize int) ([]*PagedUserInfo, int64, error) { offset := (page - 1) * pageSize query := r.db.Model(&models.FanProfile{}). Where("star_id = ?", starID) if excludeUserID > 0 { query = query.Where("user_id != ?", excludeUserID) } var total int64 if err := query.Count(&total).Error; err != nil { return nil, 0, fmt.Errorf("failed to count users: %w", err) } var profiles []models.FanProfile err := query. Select("user_id", "nickname", "level", "slot_limit"). Order("id DESC"). Limit(pageSize). Offset(offset). Find(&profiles).Error if err != nil { return nil, 0, fmt.Errorf("failed to get paged users: %w", err) } result := make([]*PagedUserInfo, 0, len(profiles)) for _, profile := range profiles { result = append(result, &PagedUserInfo{ UserID: profile.UserID, Nickname: profile.Nickname, Level: profile.Level, SlotLimit: profile.SlotLimit, }) } return result, total, nil } // GetSlotCountsByUsers 批量获取用户的共享展位已占用数量 // 共享展位 = public visibility 的 BoothSlot // 已占用 = BoothSlot 上有 Exhibition 记录 func (r *socialRepositoryImpl) GetSlotCountsByUsers(userIDs []int64, starID int64) (map[int64]int64, error) { if len(userIDs) == 0 { return make(map[int64]int64), nil } // 统计每个用户已占用的共享展位数量(通过 Exhibition 表 JOIN BoothSlot) var results []struct { UserID int64 Count int64 } err := r.db.Model(&models.Exhibition{}). Select("bs.user_id, COUNT(*) as count"). Joins("JOIN booth_slots bs ON bs.slot_id = exhibitions.slot_id"). Where("bs.user_id IN ? AND bs.star_id = ? AND bs.visibility = ? AND bs.is_enabled = ?", userIDs, starID, "public", true). Group("bs.user_id"). Find(&results).Error if err != nil { return nil, err } countMap := make(map[int64]int64) for _, res := range results { countMap[res.UserID] = res.Count } // 未找到记录的用户,默认为 0(不需要在 map 中明确设置) return countMap, nil } // GetFanProfileByUserIDAndStarID 根据userID和starID获取粉丝档案 func (r *socialRepositoryImpl) GetFanProfileByUserIDAndStarID(userID, starID int64) (*models.FanProfile, error) { var profile models.FanProfile err := r.db.Where("user_id = ? AND star_id = ?", userID, starID).First(&profile).Error if err != nil { return nil, err } return &profile, nil }