topfans/backend/services/notificationService/repository/notification_stats_repository.go
2026-06-16 21:30:58 +08:00

159 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/topfans/backend/pkg/logger"
	"github.com/topfans/backend/services/notificationService/model"
	"go.uber.org/zap"
	"gorm.io/gorm"
)
// NotificationStatsRepository stores unread-notification counters per
// (user_id, star_id) pair in the public.notification_stats table.
//
// Design: every write method takes an explicit transaction handle, so the
// service layer can write notifications and their stats inside the same
// transaction callback. Get is a single-row read and uses the repository's
// default db handle.
type NotificationStatsRepository struct {
	// db is the default handle; in this file only Get reads it, the write
	// methods use the tx they are given.
	db *gorm.DB
}

// NewNotificationStatsRepository returns a NotificationStatsRepository whose
// default handle is db.
func NewNotificationStatsRepository(db *gorm.DB) *NotificationStatsRepository {
	repo := new(NotificationStatsRepository)
	repo.db = db
	return repo
}
// IncrementByType records one more unread notification of type ntype for the
// (userID, starID) pair. The per-type counter and total_unread_count both grow
// by 1, and updated_at is set to now. It must run inside the caller's
// transaction tx.
//
// The write is an upsert (INSERT ... ON CONFLICT (user_id, star_id) DO UPDATE).
// The first notification creates the row with both counters at 1, and later
// ones add 1 to the stored values. It returns an error when tx is nil, when
// ntype is not a known type, or when the statement fails. Statement failures
// are also logged.
func (r *NotificationStatsRepository) IncrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
	if tx == nil {
		return errors.New("IncrementByType must be called within a transaction")
	}
	column, colErr := typeToColumn(ntype)
	if colErr != nil {
		return colErr
	}

	// PostgreSQL positional parameters: $1 = user_id, $2 = star_id, $3 = now
	// ($3 is reused by the conflict branch). The column name comes from
	// typeToColumn's fixed set of three names, so Sprintf cannot inject
	// arbitrary SQL here.
	upsert := fmt.Sprintf(`
		INSERT INTO public.notification_stats (user_id, star_id, %[1]s, total_unread_count, updated_at)
		VALUES ($1, $2, 1, 1, $3)
		ON CONFLICT (user_id, star_id) DO UPDATE SET
			%[1]s = public.notification_stats.%[1]s + 1,
			total_unread_count = public.notification_stats.total_unread_count + 1,
			updated_at = $3
	`, column)

	execErr := tx.WithContext(ctx).Exec(upsert, userID, starID, now).Error
	if execErr == nil {
		return nil
	}
	logger.Logger.Error("failed to increment notification stats",
		zap.Int64("user_id", userID),
		zap.Int64("star_id", starID),
		zap.String("type", ntype),
		zap.Error(execErr))
	return fmt.Errorf("increment stats: %w", execErr)
}
// DecrementByType lowers the unread counter for ntype by delta for the
// (userID, starID) pair. It also lowers total_unread_count by the amount
// actually removed from that counter, and sets updated_at to now. It must run
// inside the caller's transaction tx.
//
// Both columns are clamped at 0 with GREATEST(0, ...). The total is reduced by
// LEAST(<type column>, delta), not by delta, so clamping the per-type counter
// never takes unread counts away from other types. For example, with like=1,
// system=5, total=6, decrementing "like" by 3 now yields like=0, total=5
// (previously total=3, which no longer matched the per-type counters).
//
// A non-positive delta is a no-op and returns nil; this is checked before tx.
// If no stats row exists, the UPDATE matches nothing and nil is returned. It
// returns an error when tx is nil, when ntype is unknown, or when the statement
// fails. Statement failures are also logged.
func (r *NotificationStatsRepository) DecrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, delta int, now int64) error {
	if delta <= 0 {
		return nil
	}
	if tx == nil {
		return errors.New("DecrementByType must be called within a transaction")
	}
	col, err := typeToColumn(ntype)
	if err != nil {
		return err
	}
	// In a PostgreSQL UPDATE, every SET expression reads the row's pre-update
	// values. LEAST(<type column>, $3) therefore sees the old per-type count,
	// and the total drops by exactly what this type loses.
	query := fmt.Sprintf(`
		UPDATE public.notification_stats
		SET %[1]s = GREATEST(0, %[1]s - $3),
			total_unread_count = GREATEST(0, total_unread_count - LEAST(%[1]s, $3)),
			updated_at = $4
		WHERE user_id = $1 AND star_id = $2
	`, col)
	if err := tx.WithContext(ctx).Exec(query, userID, starID, delta, now).Error; err != nil {
		logger.Logger.Error("failed to decrement notification stats",
			zap.Int64("user_id", userID),
			zap.Int64("star_id", starID),
			zap.String("type", ntype),
			zap.Int("delta", delta),
			zap.Error(err))
		return fmt.Errorf("decrement stats: %w", err)
	}
	return nil
}
// ResetByType sets the unread counter for ntype to 0 for the (userID, starID)
// pair. It subtracts that counter's previous value from total_unread_count
// (clamped at 0) and sets updated_at to now. It must run inside the caller's
// transaction tx.
//
// An empty ntype resets every per-type counter and the total to 0. Any other
// value must be accepted by typeToColumn. If no stats row exists, the UPDATE
// matches nothing and nil is returned. Statement failures are returned wrapped
// in both paths: "reset all stats: ..." or "reset stats: ...".
func (r *NotificationStatsRepository) ResetByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
	if tx == nil {
		return errors.New("ResetByType must be called within a transaction")
	}
	gdb := tx.WithContext(ctx)
	if ntype == "" {
		err := gdb.Exec(`
			UPDATE public.notification_stats
			SET like_unread_count = 0, system_unread_count = 0,
				activity_unread_count = 0, total_unread_count = 0,
				updated_at = $3
			WHERE user_id = $1 AND star_id = $2
		`, userID, starID, now).Error
		if err != nil {
			// Wrapped like the single-type path so callers get context and can
			// still use errors.Is / errors.As on the underlying error.
			return fmt.Errorf("reset all stats: %w", err)
		}
		return nil
	}
	col, err := typeToColumn(ntype)
	if err != nil {
		return err
	}
	// Subtract only this type's value so the other types' counts stay in the
	// total. SET expressions read the pre-update row, so the total uses the
	// old value of the column being zeroed.
	query := fmt.Sprintf(`
		UPDATE public.notification_stats
		SET %[1]s = 0,
			total_unread_count = GREATEST(0, total_unread_count - %[1]s),
			updated_at = $3
		WHERE user_id = $1 AND star_id = $2
	`, col)
	if err := gdb.Exec(query, userID, starID, now).Error; err != nil {
		return fmt.Errorf("reset stats: %w", err)
	}
	return nil
}
// Get loads the stats row for the (userID, starID) pair through the
// repository's default db handle.
//
// When no row exists, it returns a NotificationStats with only UserID and
// StarID set (all counters zero) and a nil error. Any other failure is
// returned wrapped as "get stats: ...", with a nil result.
func (r *NotificationStatsRepository) Get(ctx context.Context, userID, starID int64) (*model.NotificationStats, error) {
	s := &model.NotificationStats{}
	err := r.db.WithContext(ctx).Raw(`
		SELECT user_id, star_id, like_unread_count, system_unread_count,
			activity_unread_count, total_unread_count, updated_at
		FROM public.notification_stats
		WHERE user_id = $1 AND star_id = $2
	`, userID, starID).Row().Scan(
		&s.UserID, &s.StarID, &s.LikeUnreadCount, &s.SystemUnreadCount,
		&s.ActivityUnreadCount, &s.TotalUnreadCount, &s.UpdatedAt,
	)
	// Row() hands back a *sql.Row, whose Scan reports a missing row as
	// sql.ErrNoRows rather than gorm.ErrRecordNotFound. Checking only the gorm
	// sentinel made the "not found" branch unreachable. Both are accepted so
	// the function stays correct if the query path changes.
	switch {
	case err == nil:
		return s, nil
	case errors.Is(err, sql.ErrNoRows), errors.Is(err, gorm.ErrRecordNotFound):
		return &model.NotificationStats{UserID: userID, StarID: starID}, nil
	default:
		return nil, fmt.Errorf("get stats: %w", err)
	}
}
// typeToColumn 将通知类型映射到 stats 表的列名。
func typeToColumn(ntype string) (string, error) {
switch ntype {
case "like":
return "like_unread_count", nil
case "system":
return "system_unread_count", nil
case "activity":
return "activity_unread_count", nil
default:
return "", fmt.Errorf("invalid notification type: %s", ntype)
}
}