topfans/backend/services/notificationService/repository/notification_repository_test.go
2026-06-16 21:30:58 +08:00

177 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository_test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/topfans/backend/services/notificationService/model"
"github.com/topfans/backend/services/notificationService/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// setupTestDB opens a GORM connection to the PostgreSQL test database and
// returns it.
//
// The DSN comes from the TEST_DB_DSN environment variable; when that value is
// empty, a local default DSN is used instead. If the connection cannot be
// opened, the underlying sql.DB cannot be obtained, or the ping fails, the
// calling test is skipped rather than failed.
func setupTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	// Start from the local default and let the environment override it.
	dsn := "postgres://postgres:postgres@localhost:5432/top_fans_test?sslmode=disable"
	if fromEnv := os.Getenv("TEST_DB_DSN"); fromEnv != "" {
		dsn = fromEnv
	}

	conn, openErr := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if openErr != nil {
		t.Skipf("skipping: cannot open test DB: %v", openErr)
	}

	rawDB, dbErr := conn.DB()
	if dbErr != nil {
		t.Skipf("skipping: cannot get sql.DB from gorm: %v", dbErr)
	}

	// Ping explicitly so an unreachable database skips here instead of
	// surfacing later as a query error inside the test body.
	if pingErr := rawDB.Ping(); pingErr != nil {
		t.Skipf("skipping: cannot ping test DB: %v", pingErr)
	}

	return conn
}
// cleanupTestData deletes every row belonging to the given user/star pair
// from public.notifications and then from public.notification_stats.
//
// Cleanup is best-effort: a failed DELETE is logged through t.Logf and does
// not fail the test.
func cleanupTestData(t *testing.T, db *gorm.DB, userID, starID int64) {
	t.Helper()

	// Each step pairs a DELETE statement with the log format used on failure.
	steps := []struct {
		query   string
		failFmt string
	}{
		{
			query:   `DELETE FROM public.notifications WHERE user_id=$1 AND star_id=$2`,
			failFmt: "cleanup notifications failed: %v",
		},
		{
			query:   `DELETE FROM public.notification_stats WHERE user_id=$1 AND star_id=$2`,
			failFmt: "cleanup notification_stats failed: %v",
		},
	}

	ctx := context.Background()
	for _, step := range steps {
		if err := db.WithContext(ctx).Exec(step.query, userID, starID).Error; err != nil {
			t.Logf(step.failFmt, err)
		}
	}
}
// TestNotificationRepository_CreateAndList creates one "like" notification
// and increments the "like" stats counter inside a single transaction, then
// verifies that:
//   - the returned ID is positive,
//   - ListSystemActivity reports exactly one item carrying that ID,
//   - the stats row shows 1 for both the like and the total unread counts.
func TestNotificationRepository_CreateAndList(t *testing.T) {
	const (
		userID int64 = 990001
		starID int64 = 1
	)

	db := setupTestDB(t)
	notifRepo := repository.NewNotificationRepository(db)
	statsRepo := repository.NewNotificationStatsRepository(db)
	ctx := context.Background()

	// Remove leftovers from earlier runs now, and this run's rows on exit.
	cleanupTestData(t, db, userID, starID)
	defer cleanupTestData(t, db, userID, starID)

	createdAt := time.Now().UnixMilli()
	like := &model.Notification{
		UserID:    userID,
		StarID:    starID,
		Type:      "like",
		Title:     "新点赞",
		Content:   "test",
		Data:      `{}`,
		CreatedAt: createdAt,
	}

	var createdID int64
	createAndCount := func(tx *gorm.DB) error {
		newID, createErr := notifRepo.Create(ctx, tx, like)
		if createErr != nil {
			return createErr
		}
		createdID = newID
		return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", createdAt)
	}

	txErr := db.WithContext(ctx).Transaction(createAndCount)
	assert.NoError(t, txErr)
	assert.Greater(t, createdID, int64(0))

	items, total, listErr := notifRepo.ListSystemActivity(ctx, userID, starID, "like", "", 1, 20)
	assert.NoError(t, listErr)
	assert.Equal(t, int64(1), total)
	// Only index into items once the length assertion has passed.
	if assert.Len(t, items, 1) {
		assert.Equal(t, createdID, items[0].ID)
	}

	stats, statsErr := statsRepo.Get(ctx, userID, starID)
	assert.NoError(t, statsErr)
	if assert.NotNil(t, stats) {
		assert.Equal(t, 1, stats.LikeUnreadCount)
		assert.Equal(t, 1, stats.TotalUnreadCount)
	}
}
// TestNotificationRepository_LikeAggregation inserts five "like"
// notifications that share one target_id but have distinct actors, then
// verifies that ListLikesAggregated returns a single item with
// TotalCount 5, the shared TargetID, and five Actors.
func TestNotificationRepository_LikeAggregation(t *testing.T) {
	const (
		userID   int64 = 990002
		starID   int64 = 1
		targetID int64 = 8888

		likeCount = 5
	)

	db := setupTestDB(t)
	repo := repository.NewNotificationRepository(db)
	ctx := context.Background()

	// Remove leftovers from earlier runs now, and this run's rows on exit.
	cleanupTestData(t, db, userID, starID)
	defer cleanupTestData(t, db, userID, starID)

	base := time.Now().UnixMilli()
	for n := 0; n < likeCount; n++ {
		payload := fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "测试%d", "actor_avatar": "https://x/a.png"}`,
			targetID, 1000+n, n)
		like := &model.Notification{
			UserID: userID,
			StarID: starID,
			Type:   "like",
			Title:  "新点赞",
			Data:   payload,
			// Space the rows one second apart so each has its own timestamp.
			CreatedAt: base + int64(n)*1000,
		}

		insertErr := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
			_, createErr := repo.Create(ctx, tx, like)
			return createErr
		})
		assert.NoError(t, insertErr)
	}

	items, total, listErr := repo.ListLikesAggregated(ctx, userID, starID, "", 1, 20)
	assert.NoError(t, listErr)
	assert.Equal(t, int64(1), total)
	// Only index into items once the length assertion has passed.
	if assert.Len(t, items, 1) {
		aggregated := items[0]
		assert.Equal(t, int32(likeCount), aggregated.TotalCount)
		assert.Equal(t, targetID, aggregated.TargetID)
		assert.Len(t, aggregated.Actors, likeCount)
	}
}
// TestNotificationRepository_MarkAsReadByTarget inserts three "like"
// notifications for one target_id, incrementing the "like" stats counter for
// each, then marks the target as read and decrements the counter by the
// number of affected rows in a single transaction. It verifies that three
// rows were affected and that the like and total unread counts are both 0.
func TestNotificationRepository_MarkAsReadByTarget(t *testing.T) {
	const (
		userID   int64 = 990003
		starID   int64 = 1
		targetID int64 = 7777

		likeCount = 3
	)

	db := setupTestDB(t)
	notifRepo := repository.NewNotificationRepository(db)
	statsRepo := repository.NewNotificationStatsRepository(db)
	ctx := context.Background()

	// Remove leftovers from earlier runs now, and this run's rows on exit.
	cleanupTestData(t, db, userID, starID)
	defer cleanupTestData(t, db, userID, starID)

	now := time.Now().UnixMilli()
	for n := 0; n < likeCount; n++ {
		payload := fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "x", "actor_avatar": ""}`, targetID, n+1)
		// NOTE(review): CreatedAt is left unset here, unlike the other tests
		// in this file that pass an explicit timestamp; confirm this is
		// intentional.
		like := &model.Notification{
			UserID: userID,
			StarID: starID,
			Type:   "like",
			Title:  "x",
			Data:   payload,
		}

		insertErr := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
			if _, createErr := notifRepo.Create(ctx, tx, like); createErr != nil {
				return createErr
			}
			return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", now)
		})
		assert.NoError(t, insertErr)
	}

	var affected int32
	markAndDecrement := func(tx *gorm.DB) error {
		marked, markErr := notifRepo.MarkAsReadByTarget(ctx, tx, userID, starID, targetID, now)
		if markErr != nil {
			return markErr
		}
		affected = marked
		// Decrement by exactly the number of rows that were marked as read.
		return statsRepo.DecrementByType(ctx, tx, userID, starID, "like", int(marked), now)
	}

	txErr := db.WithContext(ctx).Transaction(markAndDecrement)
	assert.NoError(t, txErr)
	assert.Equal(t, int32(likeCount), affected)

	stats, statsErr := statsRepo.Get(ctx, userID, starID)
	assert.NoError(t, statsErr)
	if assert.NotNil(t, stats) {
		assert.Equal(t, 0, stats.LikeUnreadCount)
		assert.Equal(t, 0, stats.TotalUnreadCount)
	}
}