package repository

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/topfans/backend/pkg/logger"
	"github.com/topfans/backend/services/notificationService/model"
	"go.uber.org/zap"
	"gorm.io/gorm"
)

// NotificationStatsRepository is the unread-counter repository; it operates on
// the public.notification_stats table.
//
// Design: every write method requires a transaction handle (the service layer
// writes notifications and stats inside the same transaction callback). Get is
// a single-row read and uses the repository's default db.
type NotificationStatsRepository struct {
	db *gorm.DB
}

// NewNotificationStatsRepository creates an unread-counter repository backed by db.
func NewNotificationStatsRepository(db *gorm.DB) *NotificationStatsRepository {
	return &NotificationStatsRepository{db: db}
}

// IncrementByType adds 1 to the unread count of ntype (and to the total) for
// (userID, starID), using tx.
//
// The row is upserted: the first write inserts it with the type counter and the
// total both set to 1; later writes increment the existing values. now is
// written to updated_at. It returns an error if tx is nil, if ntype is not a
// known notification type, or if the statement fails.
func (r *NotificationStatsRepository) IncrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
	if tx == nil {
		return errors.New("IncrementByType must be called within a transaction")
	}
	col, err := typeToColumn(ntype)
	if err != nil {
		return err
	}
	// col comes from the typeToColumn whitelist, so interpolating it is safe.
	// ON CONFLICT turns this into an upsert: insert 1 the first time, add 1 afterwards.
	// GORM Exec accepts PostgreSQL's native $1/$2/$3 placeholders.
	query := fmt.Sprintf(`
		INSERT INTO public.notification_stats (user_id, star_id, %s, total_unread_count, updated_at)
		VALUES ($1, $2, 1, 1, $3)
		ON CONFLICT (user_id, star_id) DO UPDATE SET
			%[1]s = public.notification_stats.%[1]s + 1,
			total_unread_count = public.notification_stats.total_unread_count + 1,
			updated_at = $3
	`, col)
	if err := tx.WithContext(ctx).Exec(query, userID, starID, now).Error; err != nil {
		logger.Logger.Error("failed to increment notification stats",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
			zap.String("type", ntype),
			zap.Error(err))
		return fmt.Errorf("increment stats: %w", err)
	}
	return nil
}

// DecrementByType subtracts delta from the unread count of ntype (and from the
// total) for (userID, starID), using tx.
//
// GREATEST(0, ...) keeps both counters from going negative. A non-positive
// delta is a no-op and returns nil before tx is even checked. If no stats row
// exists, the UPDATE matches nothing and no error is returned.
func (r *NotificationStatsRepository) DecrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, delta int, now int64) error {
	if delta <= 0 {
		return nil
	}
	if tx == nil {
		return errors.New("DecrementByType must be called within a transaction")
	}
	col, err := typeToColumn(ntype)
	if err != nil {
		return err
	}
	// col is whitelisted by typeToColumn; $3 is delta, $4 is the timestamp.
	query := fmt.Sprintf(`
		UPDATE public.notification_stats SET
			%s = GREATEST(0, %[1]s - $3),
			total_unread_count = GREATEST(0, total_unread_count - $3),
			updated_at = $4
		WHERE user_id = $1 AND star_id = $2
	`, col)
	if err := tx.WithContext(ctx).Exec(query, userID, starID, delta, now).Error; err != nil {
		logger.Logger.Error("failed to decrement notification stats",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
			zap.String("type", ntype),
			zap.Int("delta", delta),
			zap.Error(err))
		return fmt.Errorf("decrement stats: %w", err)
	}
	return nil
}

// ResetByType sets the unread count of ntype to 0 for (userID, starID), using
// tx, and subtracts that type's previous value from the total.
//
// An empty ntype resets everything: every per-type counter and the total. If no
// stats row exists, the UPDATE matches nothing and no error is returned.
func (r *NotificationStatsRepository) ResetByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
	if tx == nil {
		return errors.New("ResetByType must be called within a transaction")
	}
	gdb := tx.WithContext(ctx)

	if ntype == "" {
		if err := gdb.Exec(`
			UPDATE public.notification_stats SET
				like_unread_count = 0,
				system_unread_count = 0,
				activity_unread_count = 0,
				total_unread_count = 0,
				updated_at = $3
			WHERE user_id = $1 AND star_id = $2
		`, userID, starID, now).Error; err != nil {
			// Wrapped like every other write path so callers see a consistent prefix.
			return fmt.Errorf("reset stats: %w", err)
		}
		return nil
	}

	col, err := typeToColumn(ntype)
	if err != nil {
		return err
	}
	// Subtract only this type's count so other types' contributions to the total
	// are untouched. PostgreSQL evaluates SET right-hand sides against the
	// pre-update row, so %[1]s below is the old per-type value.
	query := fmt.Sprintf(`
		UPDATE public.notification_stats SET
			%s = 0,
			total_unread_count = GREATEST(0, total_unread_count - %[1]s),
			updated_at = $3
		WHERE user_id = $1 AND star_id = $2
	`, col)
	if err := gdb.Exec(query, userID, starID, now).Error; err != nil {
		return fmt.Errorf("reset stats: %w", err)
	}
	return nil
}

// Get returns the stats row for (userID, starID).
//
// When no row exists it returns a NotificationStats with only UserID and StarID
// set (all counters zero) and a nil error. Any other failure is returned
// wrapped as "get stats: ...".
func (r *NotificationStatsRepository) Get(ctx context.Context, userID, starID int64) (*model.NotificationStats, error) {
	row := r.db.WithContext(ctx).Raw(`
		SELECT user_id, star_id, like_unread_count, system_unread_count,
			activity_unread_count, total_unread_count, updated_at
		FROM public.notification_stats
		WHERE user_id = $1 AND star_id = $2
	`, userID, starID).Row()
	// GORM's Row() can yield a nil *sql.Row when the statement never executed;
	// calling Scan on it would panic.
	if row == nil {
		return nil, errors.New("get stats: query produced no result row")
	}

	s := &model.NotificationStats{}
	err := row.Scan(
		&s.UserID, &s.StarID,
		&s.LikeUnreadCount, &s.SystemUnreadCount, &s.ActivityUnreadCount,
		&s.TotalUnreadCount, &s.UpdatedAt,
	)
	switch {
	case err == nil:
		return s, nil
	// (*sql.Row).Scan reports a missing row as sql.ErrNoRows, not
	// gorm.ErrRecordNotFound; both are treated as "no stats yet".
	case errors.Is(err, sql.ErrNoRows), errors.Is(err, gorm.ErrRecordNotFound):
		return &model.NotificationStats{UserID: userID, StarID: starID}, nil
	default:
		return nil, fmt.Errorf("get stats: %w", err)
	}
}

// typeToColumn maps a notification type to its unread-count column in
// notification_stats. It is the whitelist that makes the column interpolation
// in the SQL above safe; unknown types return an error.
func typeToColumn(ntype string) (string, error) {
	switch ntype {
	case "like":
		return "like_unread_count", nil
	case "system":
		return "system_unread_count", nil
	case "activity":
		return "activity_unread_count", nil
	default:
		return "", fmt.Errorf("invalid notification type: %s", ntype)
	}
}