177 lines
5.5 KiB
Go
177 lines
5.5 KiB
Go
package repository_test
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"os"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/topfans/backend/services/notificationService/model"
|
||
"github.com/topfans/backend/services/notificationService/repository"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// setupTestDB 打开测试 DB;若 TEST_DB_DSN 未设置或连不上则 t.Skip。
|
||
func setupTestDB(t *testing.T) *gorm.DB {
|
||
t.Helper()
|
||
dsn := os.Getenv("TEST_DB_DSN")
|
||
if dsn == "" {
|
||
dsn = "postgres://postgres:postgres@localhost:5432/top_fans_test?sslmode=disable"
|
||
}
|
||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
if err != nil {
|
||
t.Skipf("skipping: cannot open test DB: %v", err)
|
||
}
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
t.Skipf("skipping: cannot get sql.DB from gorm: %v", err)
|
||
}
|
||
if err := sqlDB.Ping(); err != nil {
|
||
t.Skipf("skipping: cannot ping test DB: %v", err)
|
||
}
|
||
return db
|
||
}
|
||
|
||
// cleanupTestData 删除指定 user+star 的所有测试数据。
|
||
func cleanupTestData(t *testing.T, db *gorm.DB, userID, starID int64) {
|
||
t.Helper()
|
||
ctx := context.Background()
|
||
if err := db.WithContext(ctx).Exec(
|
||
`DELETE FROM public.notifications WHERE user_id=$1 AND star_id=$2`, userID, starID).Error; err != nil {
|
||
t.Logf("cleanup notifications failed: %v", err)
|
||
}
|
||
if err := db.WithContext(ctx).Exec(
|
||
`DELETE FROM public.notification_stats WHERE user_id=$1 AND star_id=$2`, userID, starID).Error; err != nil {
|
||
t.Logf("cleanup notification_stats failed: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestNotificationRepository_CreateAndList 验证:单条 like 创建、列表查询、统计 +1。
|
||
func TestNotificationRepository_CreateAndList(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
repo := repository.NewNotificationRepository(db)
|
||
statsRepo := repository.NewNotificationStatsRepository(db)
|
||
ctx := context.Background()
|
||
|
||
userID, starID := int64(990001), int64(1)
|
||
cleanupTestData(t, db, userID, starID)
|
||
defer cleanupTestData(t, db, userID, starID)
|
||
|
||
now := time.Now().UnixMilli()
|
||
|
||
var id int64
|
||
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
newID, err := repo.Create(ctx, tx, &model.Notification{
|
||
UserID: userID, StarID: starID, Type: "like",
|
||
Title: "新点赞", Content: "test", Data: `{}`,
|
||
CreatedAt: now,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
id = newID
|
||
return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", now)
|
||
})
|
||
assert.NoError(t, err)
|
||
assert.Greater(t, id, int64(0))
|
||
|
||
items, total, err := repo.ListSystemActivity(ctx, userID, starID, "like", "", 1, 20)
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, int64(1), total)
|
||
if assert.Len(t, items, 1) {
|
||
assert.Equal(t, id, items[0].ID)
|
||
}
|
||
|
||
stats, err := statsRepo.Get(ctx, userID, starID)
|
||
assert.NoError(t, err)
|
||
if assert.NotNil(t, stats) {
|
||
assert.Equal(t, 1, stats.LikeUnreadCount)
|
||
assert.Equal(t, 1, stats.TotalUnreadCount)
|
||
}
|
||
}
|
||
|
||
// TestNotificationRepository_LikeAggregation 验证:5 条 like(同一 target_id)聚合成 1 条。
|
||
func TestNotificationRepository_LikeAggregation(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
repo := repository.NewNotificationRepository(db)
|
||
ctx := context.Background()
|
||
|
||
userID, starID, targetID := int64(990002), int64(1), int64(8888)
|
||
cleanupTestData(t, db, userID, starID)
|
||
defer cleanupTestData(t, db, userID, starID)
|
||
|
||
now := time.Now().UnixMilli()
|
||
for i := 0; i < 5; i++ {
|
||
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
_, err := repo.Create(ctx, tx, &model.Notification{
|
||
UserID: userID, StarID: starID, Type: "like",
|
||
Title: "新点赞",
|
||
Data: fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "测试%d", "actor_avatar": "https://x/a.png"}`,
|
||
targetID, 1000+i, i),
|
||
CreatedAt: now + int64(i)*1000,
|
||
})
|
||
return err
|
||
})
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
items, total, err := repo.ListLikesAggregated(ctx, userID, starID, "", 1, 20)
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, int64(1), total)
|
||
if assert.Len(t, items, 1) {
|
||
assert.Equal(t, int32(5), items[0].TotalCount)
|
||
assert.Equal(t, targetID, items[0].TargetID)
|
||
assert.Len(t, items[0].Actors, 5)
|
||
}
|
||
}
|
||
|
||
// TestNotificationRepository_MarkAsReadByTarget 验证:标 target 已读影响 3 条,统计 -3。
|
||
func TestNotificationRepository_MarkAsReadByTarget(t *testing.T) {
|
||
db := setupTestDB(t)
|
||
repo := repository.NewNotificationRepository(db)
|
||
statsRepo := repository.NewNotificationStatsRepository(db)
|
||
ctx := context.Background()
|
||
|
||
userID, starID, targetID := int64(990003), int64(1), int64(7777)
|
||
cleanupTestData(t, db, userID, starID)
|
||
defer cleanupTestData(t, db, userID, starID)
|
||
|
||
now := time.Now().UnixMilli()
|
||
for i := 0; i < 3; i++ {
|
||
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
_, err := repo.Create(ctx, tx, &model.Notification{
|
||
UserID: userID, StarID: starID, Type: "like",
|
||
Title: "x",
|
||
Data: fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "x", "actor_avatar": ""}`, targetID, i+1),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", now)
|
||
})
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
var affected int32
|
||
err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
a, err := repo.MarkAsReadByTarget(ctx, tx, userID, starID, targetID, now)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
affected = a
|
||
return statsRepo.DecrementByType(ctx, tx, userID, starID, "like", int(a), now)
|
||
})
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, int32(3), affected)
|
||
|
||
stats, err := statsRepo.Get(ctx, userID, starID)
|
||
assert.NoError(t, err)
|
||
if assert.NotNil(t, stats) {
|
||
assert.Equal(t, 0, stats.LikeUnreadCount)
|
||
assert.Equal(t, 0, stats.TotalUnreadCount)
|
||
}
|
||
}
|