package repository

import (
	"errors"

	"github.com/topfans/backend/pkg/database"
	"github.com/topfans/backend/pkg/models"
	"gorm.io/gorm"
)

// ActivityRepository defines persistence operations for activities, their
// items, user contributions and per-user contribution statistics.
type ActivityRepository interface {
	// CreateActivity inserts a new activity.
	CreateActivity(activity *models.Activity) error
	// GetActivityByID returns the activity with the given ID with its active
	// items preloaded, or (nil, nil) when no such activity exists.
	GetActivityByID(id int64) (*models.Activity, error)
	// GetActivitiesByStar returns one page of a star's activities, optionally
	// filtered by status, plus the total number of matching activities.
	GetActivitiesByStar(starID int64, status string, page, pageSize int) ([]*models.Activity, int64, error)
	// UpdateActivityProgress sets the current_progress column of an activity.
	UpdateActivityProgress(id int64, progress int64) error
	// GetActivityItems returns an activity's active items ordered by sort_order.
	GetActivityItems(activityID int64) ([]*models.ActivityItem, error)
	// GetActivityItemByType returns the active item of the given type for an
	// activity, or (nil, nil) when there is none.
	GetActivityItemByType(activityID int64, itemType string) (*models.ActivityItem, error)
	// CreateContribution inserts a contribution record.
	CreateContribution(contribution *models.ActivityContribution) error
	// GetUserStats returns a user's statistics for an activity and star, or
	// (nil, nil) when the user has none.
	GetUserStats(activityID, userID, starID int64) (*models.ActivityUserStats, error)
	// UpdateUserStats inserts stats when no row exists for its
	// (activity_id, user_id, star_id) key and updates the existing row otherwise.
	UpdateUserStats(stats *models.ActivityUserStats) error
	// GetRanking returns one page of user statistics ordered by total
	// contribution (highest first), plus the total number of ranked rows.
	// A starID <= 0 disables the star filter.
	GetRanking(activityID, starID int64, page, pageSize int) ([]*models.ActivityUserStats, int64, error)
	// GetUserRank returns the user's 1-based rank within an activity and star,
	// or 0 when the user has no statistics.
	GetUserRank(userID, activityID, starID int64) (int, error)
}

// Compile-time check that activityRepository satisfies ActivityRepository.
var _ ActivityRepository = (*activityRepository)(nil)

// activityRepository is the GORM-backed implementation of ActivityRepository.
type activityRepository struct {
	db *gorm.DB
}

// NewActivityRepository creates an ActivityRepository backed by the shared
// database handle returned by database.GetDB.
func NewActivityRepository() ActivityRepository {
	return &activityRepository{
		db: database.GetDB(),
	}
}

// pagination normalizes page and pageSize (non-positive values fall back to
// page 1 and 10 rows per page) and returns the resulting offset and limit.
func (r *activityRepository) pagination(page, pageSize int) (offset, limit int) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 10
	}
	return (page - 1) * pageSize, pageSize
}

// activeItems scopes an activity-item query to active items in display order.
// It is used both as a Preload condition and as a query scope so that
// preloaded Items are ordered the same way as GetActivityItems results.
func (r *activityRepository) activeItems(db *gorm.DB) *gorm.DB {
	return db.Where("is_active = ?", true).Order("sort_order ASC")
}

// CreateActivity inserts a new activity.
func (r *activityRepository) CreateActivity(activity *models.Activity) error {
	if activity == nil {
		return errors.New("activity cannot be nil")
	}
	return r.db.Create(activity).Error
}

// GetActivityByID returns the activity with the given ID with its active items
// preloaded, or (nil, nil) when it does not exist.
func (r *activityRepository) GetActivityByID(id int64) (*models.Activity, error) {
	if id <= 0 {
		return nil, errors.New("activity id must be greater than 0")
	}

	var activity models.Activity
	if err := r.db.Preload("Items", r.activeItems).First(&activity, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &activity, nil
}

// GetActivitiesByStar returns one page of a star's activities (newest
// start_time first), optionally filtered by status, together with the total
// number of matching activities.
func (r *activityRepository) GetActivitiesByStar(starID int64, status string, page, pageSize int) ([]*models.Activity, int64, error) {
	if starID <= 0 {
		return nil, 0, errors.New("star_id must be greater than 0")
	}
	offset, limit := r.pagination(page, pageSize)

	query := r.db.Model(&models.Activity{}).Where("star_id = ?", starID)
	if status != "" {
		query = query.Where("status = ?", status)
	}
	// A new session lets the same base query be reused for Count and Find
	// without the two calls sharing (and mutating) one Statement.
	query = query.Session(&gorm.Session{})

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var activities []*models.Activity
	if err := query.Preload("Items", r.activeItems).
		// id breaks ties between equal start_time values so pages are stable.
		Order("start_time DESC, id DESC").
		Offset(offset).
		Limit(limit).
		Find(&activities).Error; err != nil {
		return nil, 0, err
	}
	return activities, total, nil
}

// UpdateActivityProgress sets the current_progress column of an activity.
func (r *activityRepository) UpdateActivityProgress(id int64, progress int64) error {
	if id <= 0 {
		return errors.New("activity id must be greater than 0")
	}
	return r.db.Model(&models.Activity{}).
		Where("id = ?", id).
		Update("current_progress", progress).Error
}

// GetActivityItems returns an activity's active items ordered by sort_order.
func (r *activityRepository) GetActivityItems(activityID int64) ([]*models.ActivityItem, error) {
	if activityID <= 0 {
		return nil, errors.New("activity_id must be greater than 0")
	}

	var items []*models.ActivityItem
	if err := r.db.Where("activity_id = ?", activityID).
		Scopes(r.activeItems).
		Find(&items).Error; err != nil {
		return nil, err
	}
	return items, nil
}

// GetActivityItemByType returns the active item of the given type for an
// activity, or (nil, nil) when there is none.
func (r *activityRepository) GetActivityItemByType(activityID int64, itemType string) (*models.ActivityItem, error) {
	if activityID <= 0 {
		return nil, errors.New("activity_id must be greater than 0")
	}
	if itemType == "" {
		return nil, errors.New("item_type cannot be empty")
	}

	var item models.ActivityItem
	if err := r.db.Where("activity_id = ? AND item_type = ? AND is_active = ?", activityID, itemType, true).
		First(&item).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &item, nil
}

// CreateContribution inserts a contribution record.
func (r *activityRepository) CreateContribution(contribution *models.ActivityContribution) error {
	if contribution == nil {
		return errors.New("contribution cannot be nil")
	}
	return r.db.Create(contribution).Error
}

// GetUserStats returns a user's statistics for an activity and star, or
// (nil, nil) when the user has none.
func (r *activityRepository) GetUserStats(activityID, userID, starID int64) (*models.ActivityUserStats, error) {
	if activityID <= 0 || userID <= 0 || starID <= 0 {
		return nil, errors.New("activity_id, user_id, star_id must be greater than 0")
	}

	var stats models.ActivityUserStats
	if err := r.db.Where("activity_id = ? AND user_id = ? AND star_id = ?", activityID, userID, starID).
		First(&stats).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return &stats, nil
}

// UpdateUserStats inserts stats when no row exists for its
// (activity_id, user_id, star_id) key and otherwise overwrites the existing
// row's aggregate columns with the values from stats.
//
// NOTE(review): the lookup-then-insert below is not atomic; two concurrent
// calls for the same key can both miss and both insert. Unless a unique index
// on (activity_id, user_id, star_id) exists (not visible here — TODO confirm),
// consider a database-level upsert.
func (r *activityRepository) UpdateUserStats(stats *models.ActivityUserStats) error {
	if stats == nil {
		return errors.New("stats cannot be nil")
	}
	// Same key validation as GetUserStats: a zero key would otherwise create a
	// stats row that no lookup can ever find.
	if stats.ActivityID <= 0 || stats.UserID <= 0 || stats.StarID <= 0 {
		return errors.New("activity_id, user_id, star_id must be greater than 0")
	}

	var existing models.ActivityUserStats
	err := r.db.Where("activity_id = ? AND user_id = ? AND star_id = ?",
		stats.ActivityID, stats.UserID, stats.StarID).First(&existing).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return r.db.Create(stats).Error
	}
	if err != nil {
		return err
	}

	// updated_at is written verbatim from stats.UpdatedAt, so callers are
	// expected to set it before calling.
	return r.db.Model(&existing).
		Updates(map[string]interface{}{
			"total_contribution":  stats.TotalContribution,
			"total_crystal_spent": stats.TotalCrystalSpent,
			"total_items":         stats.TotalItems,
			"last_contribute_at":  stats.LastContributeAt,
			"updated_at":          stats.UpdatedAt,
		}).Error
}

// GetRanking returns one page of an activity's user statistics ordered by
// total contribution (highest first), plus the total number of ranked rows.
// A starID <= 0 disables the star filter.
func (r *activityRepository) GetRanking(activityID, starID int64, page, pageSize int) ([]*models.ActivityUserStats, int64, error) {
	if activityID <= 0 {
		return nil, 0, errors.New("activity_id must be greater than 0")
	}
	offset, limit := r.pagination(page, pageSize)

	query := r.db.Model(&models.ActivityUserStats{}).Where("activity_id = ?", activityID)
	if starID > 0 {
		query = query.Where("star_id = ?", starID)
	}
	// A new session lets the same base query be reused for Count and Find
	// without the two calls sharing (and mutating) one Statement.
	query = query.Session(&gorm.Session{})

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var stats []*models.ActivityUserStats
	if err := query.
		// user_id breaks ties between equal totals so that offset pagination
		// neither repeats nor skips users across pages.
		Order("total_contribution DESC, user_id ASC").
		Offset(offset).
		Limit(limit).
		Find(&stats).Error; err != nil {
		return nil, 0, err
	}
	return stats, total, nil
}

// GetUserRank returns the user's 1-based rank within an activity and star, or
// 0 when the user has no statistics. Users with equal totals share a rank.
func (r *activityRepository) GetUserRank(userID, activityID, starID int64) (int, error) {
	if userID <= 0 || activityID <= 0 || starID <= 0 {
		return 0, errors.New("user_id, activity_id, star_id must be greater than 0")
	}

	userStats, err := r.GetUserStats(activityID, userID, starID)
	if err != nil {
		return 0, err
	}
	if userStats == nil {
		// The user has no contribution record, so they are unranked.
		return 0, nil
	}

	// Rank = 1 + number of users in the same activity/star with a strictly
	// higher total contribution.
	var higher int64
	if err := r.db.Model(&models.ActivityUserStats{}).
		Where("activity_id = ? AND star_id = ? AND total_contribution > ?",
			activityID, starID, userStats.TotalContribution).
		Count(&higher).Error; err != nil {
		return 0, err
	}
	return int(higher) + 1, nil
}