159 lines
5.2 KiB
Go
159 lines
5.2 KiB
Go
package repository
|
||
|
||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/topfans/backend/pkg/logger"
	"github.com/topfans/backend/services/notificationService/model"
	"go.uber.org/zap"
	"gorm.io/gorm"
)
|
||
|
||
// NotificationStatsRepository is the unread-notification counter repository;
// it operates on the public.notification_stats table.
//
// Design: every write method requires a transaction handle (the service layer
// writes notifications and stats together inside one transaction callback);
// Get is a single-row read and uses the repository's default db.
type NotificationStatsRepository struct {
	// db is the repository's default handle. Among the methods in this file
	// only Get reads it; the write methods use their tx argument instead.
	db *gorm.DB
}
|
||
|
||
// NewNotificationStatsRepository 创建未读计数仓储。
|
||
func NewNotificationStatsRepository(db *gorm.DB) *NotificationStatsRepository {
|
||
return &NotificationStatsRepository{db: db}
|
||
}
|
||
|
||
// IncrementByType 在事务内对指定 type 未读数 +1(total +1)。
|
||
func (r *NotificationStatsRepository) IncrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
|
||
if tx == nil {
|
||
return errors.New("IncrementByType must be called within a transaction")
|
||
}
|
||
col, err := typeToColumn(ntype)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 使用 ON CONFLICT 做 upsert:首次写入时插入 1,后续累加。
|
||
// GORM Exec 支持 PostgreSQL 占位符 $1/$2/$3。
|
||
query := fmt.Sprintf(`
|
||
INSERT INTO public.notification_stats (user_id, star_id, %s, total_unread_count, updated_at)
|
||
VALUES ($1, $2, 1, 1, $3)
|
||
ON CONFLICT (user_id, star_id) DO UPDATE SET
|
||
%[1]s = public.notification_stats.%[1]s + 1,
|
||
total_unread_count = public.notification_stats.total_unread_count + 1,
|
||
updated_at = $3
|
||
`, col)
|
||
if err := tx.WithContext(ctx).Exec(query, userID, starID, now).Error; err != nil {
|
||
logger.Logger.Error("failed to increment notification stats",
|
||
zap.Int64("user_id", userID),
|
||
zap.Int64("star_id", starID),
|
||
zap.String("type", ntype),
|
||
zap.Error(err))
|
||
return fmt.Errorf("increment stats: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DecrementByType 在事务内对指定 type 未读数 -N(total -N)。
|
||
//
|
||
// 使用 GREATEST(0, ...) 防止计数被减为负数。
|
||
func (r *NotificationStatsRepository) DecrementByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, delta int, now int64) error {
|
||
if delta <= 0 {
|
||
return nil
|
||
}
|
||
if tx == nil {
|
||
return errors.New("DecrementByType must be called within a transaction")
|
||
}
|
||
col, err := typeToColumn(ntype)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
query := fmt.Sprintf(`
|
||
UPDATE public.notification_stats
|
||
SET %s = GREATEST(0, %[1]s - $3),
|
||
total_unread_count = GREATEST(0, total_unread_count - $3),
|
||
updated_at = $4
|
||
WHERE user_id = $1 AND star_id = $2
|
||
`, col)
|
||
if err := tx.WithContext(ctx).Exec(query, userID, starID, delta, now).Error; err != nil {
|
||
logger.Logger.Error("failed to decrement notification stats",
|
||
zap.Int64("user_id", userID),
|
||
zap.Int64("star_id", starID),
|
||
zap.String("type", ntype),
|
||
zap.Int("delta", delta),
|
||
zap.Error(err))
|
||
return fmt.Errorf("decrement stats: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ResetByType 把指定 type 未读数置 0(同时从 total 中减去)。
|
||
//
|
||
// ntype 为空字符串表示重置全部(所有 type + total)。
|
||
func (r *NotificationStatsRepository) ResetByType(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) error {
|
||
if tx == nil {
|
||
return errors.New("ResetByType must be called within a transaction")
|
||
}
|
||
gdb := tx.WithContext(ctx)
|
||
if ntype == "" {
|
||
return gdb.Exec(`
|
||
UPDATE public.notification_stats
|
||
SET like_unread_count = 0, system_unread_count = 0,
|
||
activity_unread_count = 0, total_unread_count = 0,
|
||
updated_at = $3
|
||
WHERE user_id = $1 AND star_id = $2
|
||
`, userID, starID, now).Error
|
||
}
|
||
col, err := typeToColumn(ntype)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 仅减去本类型的数值,避免误扣其它类型的累计。
|
||
query := fmt.Sprintf(`
|
||
UPDATE public.notification_stats
|
||
SET %s = 0,
|
||
total_unread_count = GREATEST(0, total_unread_count - %[1]s),
|
||
updated_at = $3
|
||
WHERE user_id = $1 AND star_id = $2
|
||
`, col)
|
||
if err := gdb.Exec(query, userID, starID, now).Error; err != nil {
|
||
return fmt.Errorf("reset stats: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Get 拉取 user+star 的统计行;记录不存在时返回零值。
|
||
func (r *NotificationStatsRepository) Get(ctx context.Context, userID, starID int64) (*model.NotificationStats, error) {
|
||
s := &model.NotificationStats{}
|
||
err := r.db.WithContext(ctx).Raw(`
|
||
SELECT user_id, star_id, like_unread_count, system_unread_count,
|
||
activity_unread_count, total_unread_count, updated_at
|
||
FROM public.notification_stats
|
||
WHERE user_id = $1 AND star_id = $2
|
||
`, userID, starID).Row().Scan(
|
||
&s.UserID, &s.StarID, &s.LikeUnreadCount, &s.SystemUnreadCount,
|
||
&s.ActivityUnreadCount, &s.TotalUnreadCount, &s.UpdatedAt,
|
||
)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return &model.NotificationStats{UserID: userID, StarID: starID}, nil
|
||
}
|
||
return nil, fmt.Errorf("get stats: %w", err)
|
||
}
|
||
return s, nil
|
||
}
|
||
|
||
// typeToColumn maps a notification type to its unread-counter column in
// public.notification_stats:
//
//	"like"     -> like_unread_count
//	"system"   -> system_unread_count
//	"activity" -> activity_unread_count
//
// The write methods interpolate the returned name directly into SQL via
// fmt.Sprintf, so this fixed whitelist is what keeps those statements
// injection-safe. Any other value, including "", yields an empty column and
// an error.
func typeToColumn(ntype string) (string, error) {
	switch ntype {
	case "like":
		return "like_unread_count", nil
	case "system":
		return "system_unread_count", nil
	case "activity":
		return "activity_unread_count", nil
	default:
		// %q keeps empty or whitespace-only input visible in the message.
		return "", fmt.Errorf("invalid notification type: %q", ntype)
	}
}
|