package repository_test import ( "context" "fmt" "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/topfans/backend/services/notificationService/model" "github.com/topfans/backend/services/notificationService/repository" "gorm.io/driver/postgres" "gorm.io/gorm" ) // setupTestDB 打开测试 DB;若 TEST_DB_DSN 未设置或连不上则 t.Skip。 func setupTestDB(t *testing.T) *gorm.DB { t.Helper() dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { dsn = "postgres://postgres:postgres@localhost:5432/top_fans_test?sslmode=disable" } db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { t.Skipf("skipping: cannot open test DB: %v", err) } sqlDB, err := db.DB() if err != nil { t.Skipf("skipping: cannot get sql.DB from gorm: %v", err) } if err := sqlDB.Ping(); err != nil { t.Skipf("skipping: cannot ping test DB: %v", err) } return db } // cleanupTestData 删除指定 user+star 的所有测试数据。 func cleanupTestData(t *testing.T, db *gorm.DB, userID, starID int64) { t.Helper() ctx := context.Background() if err := db.WithContext(ctx).Exec( `DELETE FROM public.notifications WHERE user_id=$1 AND star_id=$2`, userID, starID).Error; err != nil { t.Logf("cleanup notifications failed: %v", err) } if err := db.WithContext(ctx).Exec( `DELETE FROM public.notification_stats WHERE user_id=$1 AND star_id=$2`, userID, starID).Error; err != nil { t.Logf("cleanup notification_stats failed: %v", err) } } // TestNotificationRepository_CreateAndList 验证:单条 like 创建、列表查询、统计 +1。 func TestNotificationRepository_CreateAndList(t *testing.T) { db := setupTestDB(t) repo := repository.NewNotificationRepository(db) statsRepo := repository.NewNotificationStatsRepository(db) ctx := context.Background() userID, starID := int64(990001), int64(1) cleanupTestData(t, db, userID, starID) defer cleanupTestData(t, db, userID, starID) now := time.Now().UnixMilli() var id int64 err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { newID, err := repo.Create(ctx, tx, &model.Notification{ UserID: userID, StarID: starID, Type: "like", Title: "新点赞", Content: "test", Data: `{}`, CreatedAt: now, }) if err != nil { return err } id = newID return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", now) }) assert.NoError(t, err) assert.Greater(t, id, int64(0)) items, total, err := repo.ListSystemActivity(ctx, userID, starID, "like", "", 1, 20) assert.NoError(t, err) assert.Equal(t, int64(1), total) if assert.Len(t, items, 1) { assert.Equal(t, id, items[0].ID) } stats, err := statsRepo.Get(ctx, userID, starID) assert.NoError(t, err) if assert.NotNil(t, stats) { assert.Equal(t, 1, stats.LikeUnreadCount) assert.Equal(t, 1, stats.TotalUnreadCount) } } // TestNotificationRepository_LikeAggregation 验证:5 条 like(同一 target_id)聚合成 1 条。 func TestNotificationRepository_LikeAggregation(t *testing.T) { db := setupTestDB(t) repo := repository.NewNotificationRepository(db) ctx := context.Background() userID, starID, targetID := int64(990002), int64(1), int64(8888) cleanupTestData(t, db, userID, starID) defer cleanupTestData(t, db, userID, starID) now := time.Now().UnixMilli() for i := 0; i < 5; i++ { err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { _, err := repo.Create(ctx, tx, &model.Notification{ UserID: userID, StarID: starID, Type: "like", Title: "新点赞", Data: fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "测试%d", "actor_avatar": "https://x/a.png"}`, targetID, 1000+i, i), CreatedAt: now + int64(i)*1000, }) return err }) assert.NoError(t, err) } items, total, err := repo.ListLikesAggregated(ctx, userID, starID, "", 1, 20) assert.NoError(t, err) assert.Equal(t, int64(1), total) if assert.Len(t, items, 1) { assert.Equal(t, int32(5), items[0].TotalCount) assert.Equal(t, targetID, items[0].TargetID) assert.Len(t, items[0].Actors, 5) } } // TestNotificationRepository_MarkAsReadByTarget 验证:标 target 已读影响 3 条,统计 -3。 func TestNotificationRepository_MarkAsReadByTarget(t *testing.T) { db := setupTestDB(t) repo := repository.NewNotificationRepository(db) statsRepo := repository.NewNotificationStatsRepository(db) ctx := context.Background() userID, starID, targetID := int64(990003), int64(1), int64(7777) cleanupTestData(t, db, userID, starID) defer cleanupTestData(t, db, userID, starID) now := time.Now().UnixMilli() for i := 0; i < 3; i++ { err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { _, err := repo.Create(ctx, tx, &model.Notification{ UserID: userID, StarID: starID, Type: "like", Title: "x", Data: fmt.Sprintf(`{"target_id": %d, "actor_id": %d, "actor_name": "x", "actor_avatar": ""}`, targetID, i+1), }) if err != nil { return err } return statsRepo.IncrementByType(ctx, tx, userID, starID, "like", now) }) assert.NoError(t, err) } var affected int32 err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { a, err := repo.MarkAsReadByTarget(ctx, tx, userID, starID, targetID, now) if err != nil { return err } affected = a return statsRepo.DecrementByType(ctx, tx, userID, starID, "like", int(a), now) }) assert.NoError(t, err) assert.Equal(t, int32(3), affected) stats, err := statsRepo.Get(ctx, userID, starID) assert.NoError(t, err) if assert.NotNil(t, stats) { assert.Equal(t, 0, stats.LikeUnreadCount) assert.Equal(t, 0, stats.TotalUnreadCount) } }