package repository import ( "errors" "time" "github.com/topfans/backend/pkg/models" "github.com/topfans/backend/services/galleryService/config" "gorm.io/gorm" "gorm.io/gorm/clause" ) // GalleryRepository 展馆数据访问层接口 type GalleryRepository interface { // 展位相关 GetSlotsByUser(userID, starID int64) ([]*models.BoothSlot, error) GetSlotByID(slotID int64) (*models.BoothSlot, error) GetSlotCount(userID, starID int64) (int64, error) CreateInitialSlots(userID, starID, hostProfileID int64) error CreateSlot(slot *models.BoothSlot) error UnlockSlot(slotID int64) error // 展品相关 GetExhibitionByAsset(assetID int64) (*models.Exhibition, error) GetExhibitionBySlot(slotID int64) (*models.Exhibition, error) GetExhibitionsByUser(userID, starID int64) ([]*models.Exhibition, error) CreateExhibition(exhibition *models.Exhibition) error DeleteExhibition(exhibitionID int64) error DeleteExhibitionByAsset(assetID int64) error GetExpiredExhibitions(beforeTime int64) ([]*models.Exhibition, error) } // galleryRepository Repository实现 type galleryRepository struct { db *gorm.DB } // NewGalleryRepository 创建Repository实例 func NewGalleryRepository(db *gorm.DB) GalleryRepository { return &galleryRepository{db: db} } // ==================== 展位相关 ==================== // GetSlotsByUser 获取用户的所有展位 func (r *galleryRepository) GetSlotsByUser(userID, starID int64) ([]*models.BoothSlot, error) { var slots []*models.BoothSlot err := r.db.Where("user_id = ? AND star_id = ?", userID, starID). Order("slot_index ASC"). Find(&slots).Error return slots, err } // GetSlotByID 根据ID获取展位 func (r *galleryRepository) GetSlotByID(slotID int64) (*models.BoothSlot, error) { var slot models.BoothSlot err := r.db.Where("slot_id = ?", slotID).First(&slot).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("展位不存在") } return nil, err } return &slot, nil } // GetSlotCount 获取用户的展位数量 func (r *galleryRepository) GetSlotCount(userID, starID int64) (int64, error) { var count int64 err := r.db.Model(&models.BoothSlot{}). Where("user_id = ? AND star_id = ?", userID, starID). Count(&count).Error return count, err } // CreateInitialSlots 创建初始展位(懒加载,支持并发安全) func (r *galleryRepository) CreateInitialSlots(userID, starID int64, hostProfileID int64) error { // 使用 PostgreSQL 的 ON CONFLICT 保证并发安全性 return r.db.Transaction(func(tx *gorm.DB) error { now := time.Now().UnixMilli() initialSlotCount := config.GalleryRules.InitialSlotCount for i := 1; i <= initialSlotCount; i++ { vis := "public" if i > 3 { vis = "private" } slot := &models.BoothSlot{ HostProfileID: hostProfileID, UserID: userID, StarID: starID, SlotIndex: i, Visibility: vis, IsEnabled: true, UnlockType: "free", UnlockValue: 0, CreatedAt: now, UpdatedAt: now, } // 使用 Clause 处理冲突,确保幂等性 err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "host_profile_id"}, {Name: "slot_index"}}, DoNothing: true, }).Create(slot).Error if err != nil { return err } } return nil }) } // CreateSlot 创建新展位(用于解锁) func (r *galleryRepository) CreateSlot(slot *models.BoothSlot) error { now := time.Now().UnixMilli() slot.CreatedAt = now slot.UpdatedAt = now return r.db.Create(slot).Error } // UnlockSlot 解锁展位 func (r *galleryRepository) UnlockSlot(slotID int64) error { now := time.Now().UnixMilli() return r.db.Model(&models.BoothSlot{}). Where("slot_id = ?", slotID). Updates(map[string]interface{}{ "is_enabled": true, "updated_at": now, }).Error } // ==================== 展品相关 ==================== // GetExhibitionByAsset 根据资产ID获取展品展示记录 func (r *galleryRepository) GetExhibitionByAsset(assetID int64) (*models.Exhibition, error) { var exhibition models.Exhibition err := r.db.Where("asset_id = ?", assetID).First(&exhibition).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil // 未找到记录,返回 nil(不是错误) } return nil, err } return &exhibition, nil } // GetExhibitionBySlot 根据展位ID获取展品展示记录 func (r *galleryRepository) GetExhibitionBySlot(slotID int64) (*models.Exhibition, error) { var exhibition models.Exhibition err := r.db.Where("slot_id = ?", slotID).First(&exhibition).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil // 未找到记录,返回 nil(不是错误) } return nil, err } return &exhibition, nil } // GetExhibitionsByUser 获取用户的所有展品展示记录 func (r *galleryRepository) GetExhibitionsByUser(userID, starID int64) ([]*models.Exhibition, error) { var exhibitions []*models.Exhibition err := r.db.Where("occupier_uid = ? AND occupier_star_id = ?", userID, starID). Find(&exhibitions).Error return exhibitions, err } // CreateExhibition 创建展品展示记录 func (r *galleryRepository) CreateExhibition(exhibition *models.Exhibition) error { now := time.Now().UnixMilli() exhibition.CreatedAt = now exhibition.UpdatedAt = now return r.db.Create(exhibition).Error } // DeleteExhibition 删除展品展示记录(根据ID) func (r *galleryRepository) DeleteExhibition(exhibitionID int64) error { return r.db.Where("id = ?", exhibitionID).Delete(&models.Exhibition{}).Error } // DeleteExhibitionByAsset 删除展品展示记录(根据资产ID) func (r *galleryRepository) DeleteExhibitionByAsset(assetID int64) error { return r.db.Where("asset_id = ?", assetID).Delete(&models.Exhibition{}).Error } // GetExpiredExhibitions 获取过期的展品展示记录 func (r *galleryRepository) GetExpiredExhibitions(beforeTime int64) ([]*models.Exhibition, error) { var exhibitions []*models.Exhibition err := r.db.Where("expire_at <= ?", beforeTime).Find(&exhibitions).Error return exhibitions, err } // ==================== 辅助函数 ==================== // generateHostProfileID 生成 host_profile_id // 注意:这里使用简单的生成逻辑,实际应该与 fan_profiles 表的逻辑一致 func generateHostProfileID(userID, starID int64) int64 { // 使用简单的组合方式:userID * 1000000 + starID // 实际项目中应该使用与 User Service 一致的逻辑 return userID*1000000 + starID }