topfans/backend/services/notificationService/repository/notification_repository.go
2026-06-16 21:30:58 +08:00

334 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/topfans/backend/pkg/logger"
	"github.com/topfans/backend/services/notificationService/model"
	"go.uber.org/zap"
	"gorm.io/gorm"
)
// NotificationRepository is the data-access layer for the
// public.notifications table.
//
// Conventions:
//   - Methods that participate in transactions accept a *gorm.DB supplied by
//     the service layer. It may be r.db itself (non-transactional use) or the
//     tx handed to a db.Transaction(...) callback, so the same method can be
//     called standalone or reused inside a transaction. Several write methods
//     in this file reject a nil tx outright.
//   - Complex aggregate queries use raw SQL (PostgreSQL JSONB expressions);
//     everything else should prefer the GORM API.
type NotificationRepository struct {
	// db is the default handle used when no transaction handle is supplied.
	db *gorm.DB
}

// NewNotificationRepository builds a NotificationRepository backed by db.
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
	repo := &NotificationRepository{}
	repo.db = db
	return repo
}
// execDB returns a context-bound GORM handle to run a statement on.
//
// A non-nil tx (for example the tx passed into a db.Transaction callback)
// takes precedence; otherwise the repository's default r.db is used.
func (r *NotificationRepository) execDB(tx *gorm.DB, ctx context.Context) *gorm.DB {
	handle := tx
	if handle == nil {
		handle = r.db
	}
	return handle.WithContext(ctx)
}
// Create inserts a notification row and returns its generated id.
//
// tx must be non-nil: the service layer wraps this insert together with its
// statistics update inside db.Transaction(...). If n.CreatedAt is zero it is
// set in place to the current Unix time in milliseconds before insertion.
//
// It returns an error when n is nil, when tx is nil, or when the insert fails.
func (r *NotificationRepository) Create(ctx context.Context, tx *gorm.DB, n *model.Notification) (int64, error) {
	if n == nil {
		return 0, errors.New("notification is nil")
	}
	if tx == nil {
		return 0, errors.New("Create must be called within a transaction")
	}
	if n.CreatedAt == 0 {
		n.CreatedAt = time.Now().UnixMilli()
	}
	// INSERT ... RETURNING yields the generated id in the same round trip.
	// A follow-up currval(pg_get_serial_sequence(...)) lookup costs a second
	// query and yields NULL when the id column's default is not a sequence
	// owned by that column.
	var id int64
	err := tx.WithContext(ctx).Raw(`
INSERT INTO public.notifications
(user_id, star_id, type, title, content, data, is_read, is_deleted, created_at, read_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id
`, n.UserID, n.StarID, n.Type, n.Title, n.Content, n.Data, n.IsRead, n.IsDeleted, n.CreatedAt, n.ReadAt).Row().Scan(&id)
	if err != nil {
		logger.Logger.Error("failed to insert notification", zap.Error(err))
		return 0, fmt.Errorf("insert notification: %w", err)
	}
	return id, nil
}
// ListSystemActivity returns one page of non-aggregated notifications of type
// ntype (e.g. "system" / "activity") for (userID, starID), newest first,
// together with the total number of matching rows.
//
// tab narrows the time window:
//   - "today" keeps rows created at or after local midnight.
//   - "history" keeps rows created before local midnight.
//   - Any other value applies no time filter.
//
// page is 1-based; values below 1 are treated as 1. A negative pageSize is
// rejected with an error.
func (r *NotificationRepository) ListSystemActivity(ctx context.Context, userID, starID int64, ntype, tab string, page, pageSize int) ([]*model.Notification, int64, error) {
	if pageSize < 0 {
		return nil, 0, fmt.Errorf("invalid page size %d", pageSize)
	}
	if page < 1 {
		// A page below 1 would yield a negative OFFSET, which PostgreSQL rejects.
		page = 1
	}

	// where is assembled only from constant fragments; all values are bound
	// as positional parameters.
	args := []interface{}{userID, starID, ntype}
	where := "user_id = $1 AND star_id = $2 AND type = $3 AND is_deleted = FALSE"
	switch tab {
	case "today":
		where += " AND created_at >= $4"
		args = append(args, startOfTodayMs())
	case "history":
		where += " AND created_at < $4"
		args = append(args, startOfTodayMs())
	}

	gdb := r.db.WithContext(ctx)

	var total int64
	if err := gdb.Raw("SELECT COUNT(*) FROM public.notifications WHERE "+where, args...).Scan(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("count notifications: %w", err)
	}

	offset := (page - 1) * pageSize
	args = append(args, pageSize, offset)
	// LIMIT/OFFSET placeholders follow whatever parameters the filter used.
	limitIdx, offsetIdx := len(args)-1, len(args)
	// id breaks ties between equal created_at values so pages do not overlap
	// or skip rows.
	query := fmt.Sprintf(`
SELECT id, user_id, star_id, type, title, COALESCE(content,'') AS content, data,
is_read, is_deleted, created_at, COALESCE(read_at, 0) AS read_at
FROM public.notifications
WHERE %s
ORDER BY created_at DESC, id DESC
LIMIT $%d OFFSET $%d
`, where, limitIdx, offsetIdx)

	var items []*model.Notification
	if err := gdb.Raw(query, args...).Scan(&items).Error; err != nil {
		return nil, 0, fmt.Errorf("list notifications: %w", err)
	}
	return items, total, nil
}
// ListLikesAggregated returns one page of "like" notifications for
// (userID, starID), aggregated by data->>'target_id' with the most recently
// liked target first. It also returns the total number of distinct targets.
//
// Each returned item carries:
//   - TotalCount: the number of likes for the target.
//   - CreatedAt: the newest like time.
//   - IsRead: whether every like for the target is read.
//   - ID/Title/Content/Data/ReadAt: taken from the newest like.
//   - Actors: a newest-first list of actor previews built from
//     data->>'actor_id' / 'actor_name' / 'actor_avatar'.
//
// page is 1-based; values below 1 are treated as 1. A negative pageSize is
// rejected with an error.
//
// NOTE(review): tab is accepted but not applied to either query; confirm
// whether likes should support the today/history split.
func (r *NotificationRepository) ListLikesAggregated(ctx context.Context, userID, starID int64, tab string, page, pageSize int) ([]*model.AggregatedNotification, int64, error) {
	if pageSize < 0 {
		return nil, 0, fmt.Errorf("invalid page size %d", pageSize)
	}
	if page < 1 {
		// A page below 1 would yield a negative OFFSET, which PostgreSQL rejects.
		page = 1
	}
	gdb := r.db.WithContext(ctx)

	// Total number of distinct like targets (independent of paging).
	var total int64
	countQuery := `
SELECT COUNT(*) FROM (
SELECT (data->>'target_id')::bigint AS target_id
FROM public.notifications
WHERE user_id=$1 AND star_id=$2 AND type='like' AND is_deleted=FALSE
GROUP BY (data->>'target_id')
) t
`
	if err := gdb.Raw(countQuery, userID, starID).Scan(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("count likes aggregated: %w", err)
	}

	offset := (page - 1) * pageSize
	// Parameters: $1 userID, $2 starID, $3 LIMIT, $4 OFFSET.
	// content/read_at are COALESCEd (as in ListSystemActivity) so a NULL does
	// not fail rows.Scan. target_id breaks ties on latest_at so paging is
	// stable.
	query := `
WITH agg AS (
SELECT
(data->>'target_id')::bigint AS target_id,
COUNT(*) AS total_count,
MAX(created_at) AS latest_at,
BOOL_AND(is_read) AS all_read
FROM public.notifications
WHERE user_id=$1 AND star_id=$2 AND type='like' AND is_deleted=FALSE
GROUP BY (data->>'target_id')
),
first_notif AS (
SELECT DISTINCT ON ((data->>'target_id')::bigint)
(data->>'target_id')::bigint AS target_id,
id, title, content, data, read_at
FROM public.notifications
WHERE user_id=$1 AND star_id=$2 AND type='like' AND is_deleted=FALSE
ORDER BY (data->>'target_id')::bigint, created_at DESC
),
actors AS (
SELECT (data->>'target_id')::bigint AS target_id,
json_agg(json_build_object(
'user_id', (data->>'actor_id')::bigint,
'nickname', COALESCE(data->>'actor_name', ''),
'avatar', COALESCE(data->>'actor_avatar', ''),
'liked_at', created_at
) ORDER BY created_at DESC) AS actor_previews
FROM public.notifications
WHERE user_id=$1 AND star_id=$2 AND type='like' AND is_deleted=FALSE
GROUP BY (data->>'target_id')
)
SELECT
a.target_id, a.total_count, a.latest_at, a.all_read,
f.id, f.title, COALESCE(f.content,'') AS content, f.data, COALESCE(f.read_at, 0) AS read_at,
COALESCE(act.actor_previews, '[]'::json) AS actor_previews
FROM agg a
JOIN first_notif f ON f.target_id = a.target_id
LEFT JOIN actors act ON act.target_id = a.target_id
ORDER BY a.latest_at DESC, a.target_id DESC
LIMIT $3 OFFSET $4
`
	rows, err := gdb.Raw(query, userID, starID, pageSize, offset).Rows()
	if err != nil {
		return nil, 0, fmt.Errorf("list likes aggregated: %w", err)
	}
	defer rows.Close()

	items := make([]*model.AggregatedNotification, 0, pageSize)
	for rows.Next() {
		var item model.AggregatedNotification
		var actorPreviewsJSON []byte
		// Column order must match the final SELECT above.
		if err := rows.Scan(
			&item.TargetID, &item.TotalCount, &item.CreatedAt, &item.IsRead,
			&item.ID, &item.Title, &item.Content, &item.Data, &item.ReadAt,
			&actorPreviewsJSON,
		); err != nil {
			return nil, 0, fmt.Errorf("scan aggregated row: %w", err)
		}
		item.UserID = userID
		item.StarID = starID
		item.Type = "like"
		item.Actors = parseActorLikes(actorPreviewsJSON)
		items = append(items, &item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("iterate aggregated rows: %w", err)
	}
	return items, total, nil
}
// MarkAsReadByID marks a single unread, non-deleted notification owned by
// (userID, starID) as read and sets its read_at to now.
//
// tx must be non-nil. The returned count is the number of rows updated:
// 0 when the row does not exist, belongs to someone else, or is already read.
func (r *NotificationRepository) MarkAsReadByID(ctx context.Context, tx *gorm.DB, userID, starID, id, now int64) (int32, error) {
	if tx == nil {
		return 0, errors.New("MarkAsReadByID must be called within a transaction")
	}
	const stmt = `
UPDATE public.notifications
SET is_read = TRUE, read_at = $4
WHERE id = $1 AND user_id = $2 AND star_id = $3 AND is_read = FALSE AND is_deleted = FALSE
`
	result := tx.WithContext(ctx).Exec(stmt, id, userID, starID, now)
	if err := result.Error; err != nil {
		return 0, fmt.Errorf("mark as read by id: %w", err)
	}
	return int32(result.RowsAffected), nil
}
// MarkAsReadByTarget marks every unread, non-deleted "like" notification of
// (userID, starID) whose data->>'target_id' equals targetID as read, setting
// read_at to now.
//
// tx must be non-nil. It returns the number of rows updated.
func (r *NotificationRepository) MarkAsReadByTarget(ctx context.Context, tx *gorm.DB, userID, starID, targetID, now int64) (int32, error) {
	if tx == nil {
		return 0, errors.New("MarkAsReadByTarget must be called within a transaction")
	}
	// Exactly four arguments are bound, so read_at must reference $4 (now).
	// Referencing a non-existent $5 leaves $4 unused, and PostgreSQL then
	// rejects the statement.
	res := tx.WithContext(ctx).Exec(`
UPDATE public.notifications
SET is_read = TRUE, read_at = $4
WHERE user_id=$1 AND star_id=$2 AND type='like'
AND (data->>'target_id')::bigint = $3
AND is_read = FALSE AND is_deleted = FALSE
`, userID, starID, targetID, now)
	if res.Error != nil {
		return 0, fmt.Errorf("mark as read by target: %w", res.Error)
	}
	return int32(res.RowsAffected), nil
}
// MarkAllAsRead marks every unread, non-deleted notification of type ntype
// ("like" / "system" / "activity") for (userID, starID) as read, setting
// read_at to now.
//
// tx must be non-nil. It returns the number of rows updated.
func (r *NotificationRepository) MarkAllAsRead(ctx context.Context, tx *gorm.DB, userID, starID int64, ntype string, now int64) (int32, error) {
	if tx == nil {
		return 0, errors.New("MarkAllAsRead must be called within a transaction")
	}
	const stmt = `
UPDATE public.notifications
SET is_read = TRUE, read_at = $4
WHERE user_id=$1 AND star_id=$2 AND type=$3 AND is_read=FALSE AND is_deleted=FALSE
`
	result := tx.WithContext(ctx).Exec(stmt, userID, starID, ntype, now)
	if err := result.Error; err != nil {
		return 0, fmt.Errorf("mark all as read: %w", err)
	}
	return int32(result.RowsAffected), nil
}
// SoftDeleteByID soft-deletes one notification owned by (userID, starID) by
// setting is_deleted = TRUE. Rows that are already deleted are left alone.
//
// tx must be non-nil. It returns the number of rows updated (0 or 1).
func (r *NotificationRepository) SoftDeleteByID(ctx context.Context, tx *gorm.DB, userID, starID, id int64) (int32, error) {
	if tx == nil {
		return 0, errors.New("SoftDeleteByID must be called within a transaction")
	}
	const stmt = `
UPDATE public.notifications
SET is_deleted = TRUE
WHERE id = $1 AND user_id = $2 AND star_id = $3 AND is_deleted = FALSE
`
	result := tx.WithContext(ctx).Exec(stmt, id, userID, starID)
	if err := result.Error; err != nil {
		return 0, fmt.Errorf("soft delete by id: %w", err)
	}
	return int32(result.RowsAffected), nil
}
// SoftDeleteByTarget soft-deletes (sets is_deleted = TRUE) every not-yet-deleted
// "like" notification of (userID, starID) whose data->>'target_id' equals
// targetID.
//
// tx must be non-nil. It returns the number of rows updated.
func (r *NotificationRepository) SoftDeleteByTarget(ctx context.Context, tx *gorm.DB, userID, starID, targetID int64) (int32, error) {
	if tx == nil {
		return 0, errors.New("SoftDeleteByTarget must be called within a transaction")
	}
	const stmt = `
UPDATE public.notifications
SET is_deleted = TRUE
WHERE user_id=$1 AND star_id=$2 AND type='like'
AND (data->>'target_id')::bigint = $3 AND is_deleted = FALSE
`
	result := tx.WithContext(ctx).Exec(stmt, userID, starID, targetID)
	if err := result.Error; err != nil {
		return 0, fmt.Errorf("soft delete by target: %w", err)
	}
	return int32(result.RowsAffected), nil
}
// GetTypeByID returns the type and read flag of one non-deleted notification
// owned by (userID, starID). The service layer uses it for safety checks.
//
// tx may be nil, in which case the repository's default handle is used.
// When no matching row exists it returns ("", false, nil) rather than an
// error, so callers should treat an empty type as "not found".
func (r *NotificationRepository) GetTypeByID(ctx context.Context, tx *gorm.DB, id, userID, starID int64) (string, bool, error) {
	gdb := r.execDB(tx, ctx)
	var ntype string
	var isRead bool
	err := gdb.Raw(`
SELECT type, is_read FROM public.notifications
WHERE id=$1 AND user_id=$2 AND star_id=$3 AND is_deleted=FALSE
`, id, userID, starID).Row().Scan(&ntype, &isRead)
	// (*sql.Row).Scan reports a missing row as sql.ErrNoRows. GORM's
	// ErrRecordNotFound only comes from First/Take/Last, so checking it alone
	// never matched here. Both are accepted as "not found".
	if errors.Is(err, sql.ErrNoRows) || errors.Is(err, gorm.ErrRecordNotFound) {
		return "", false, nil
	}
	if err != nil {
		return "", false, fmt.Errorf("get type by id: %w", err)
	}
	return ntype, isRead, nil
}
// startOfTodayMs returns the Unix time in milliseconds of 00:00 today in the
// process's local time zone. It is the boundary between the "today" and
// "history" tabs.
func startOfTodayMs() int64 {
	current := time.Now()
	year, month, day := current.Date()
	midnight := time.Date(year, month, day, 0, 0, 0, 0, current.Location())
	return midnight.UnixMilli()
}
// parseActorLikes decodes the actor_previews JSON array built by the
// aggregated-likes query into model.ActorPreview values, preserving order.
//
// Decoding failures are deliberately swallowed: it returns nil when data
// cannot be decoded into an array of {user_id, nickname, avatar, liked_at}
// objects.
func parseActorLikes(data []byte) []model.ActorPreview {
	var decoded []struct {
		UserID   int64  `json:"user_id"`
		Nickname string `json:"nickname"`
		Avatar   string `json:"avatar"`
		LikedAt  int64  `json:"liked_at"`
	}
	if json.Unmarshal(data, &decoded) != nil {
		return nil
	}
	previews := make([]model.ActorPreview, len(decoded))
	for i, a := range decoded {
		previews[i] = model.ActorPreview{
			UserID:   a.UserID,
			Nickname: a.Nickname,
			Avatar:   a.Avatar,
			LikedAt:  a.LikedAt,
		}
	}
	return previews
}