219 lines
6.2 KiB
Go
219 lines
6.2 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"os"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/topfans/backend/pkg/proto/notification"
|
||
"github.com/topfans/backend/services/notificationService/model"
|
||
"google.golang.org/protobuf/types/known/structpb"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// setupTestDB 打开测试 DB;若 TEST_DB_DSN 未设置或连不上则返回 nil + false,由调用方 skip。
|
||
func setupTestDB(t *testing.T) (*gorm.DB, bool) {
|
||
t.Helper()
|
||
dsn := os.Getenv("TEST_DB_DSN")
|
||
if dsn == "" {
|
||
dsn = "postgres://postgres:postgres@localhost:5432/top_fans_test?sslmode=disable"
|
||
}
|
||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
if err != nil {
|
||
return nil, false
|
||
}
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, false
|
||
}
|
||
if err := sqlDB.Ping(); err != nil {
|
||
return nil, false
|
||
}
|
||
return db, true
|
||
}
|
||
|
||
// TestCreateNotification_Validation 覆盖 CreateNotification 的参数校验分支:
|
||
// 不需要 DB,所有失败路径应该在 service 层就拦截。
|
||
func TestCreateNotification_Validation(t *testing.T) {
|
||
svc := NewNotificationService(nil, nil, nil) // nil db:参数校验失败不进入 DB;device/pusher nil:跳过推送
|
||
|
||
tests := []struct {
|
||
name string
|
||
req *notification.CreateNotificationRequest
|
||
wantErr bool
|
||
errMsg string
|
||
}{
|
||
{
|
||
name: "missing user_id",
|
||
req: ¬ification.CreateNotificationRequest{StarId: 1, Type: "system", Title: "hi"},
|
||
wantErr: true,
|
||
errMsg: "user_id",
|
||
},
|
||
{
|
||
name: "missing star_id",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, Type: "system", Title: "hi"},
|
||
wantErr: true,
|
||
errMsg: "star_id",
|
||
},
|
||
{
|
||
name: "missing type",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, StarId: 1, Title: "hi"},
|
||
wantErr: true,
|
||
errMsg: "type",
|
||
},
|
||
{
|
||
name: "invalid type",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, StarId: 1, Type: "garbage", Title: "hi"},
|
||
wantErr: true,
|
||
errMsg: "type",
|
||
},
|
||
{
|
||
name: "empty title",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, StarId: 1, Type: "system", Title: " "},
|
||
wantErr: true,
|
||
errMsg: "title",
|
||
},
|
||
{
|
||
name: "title too long",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, StarId: 1, Type: "system", Title: strings.Repeat("a", 201)},
|
||
wantErr: true,
|
||
errMsg: "title",
|
||
},
|
||
{
|
||
name: "content too long",
|
||
req: ¬ification.CreateNotificationRequest{UserId: 1, StarId: 1, Type: "system", Title: "ok", Content: strings.Repeat("a", 501)},
|
||
wantErr: true,
|
||
errMsg: "content",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
resp, err := svc.CreateNotification(context.Background(), tt.req)
|
||
assert.Error(t, err, "expected validation error for case %s", tt.name)
|
||
assert.Nil(t, resp, "response should be nil on validation error")
|
||
if tt.errMsg != "" && err != nil {
|
||
assert.Contains(t, err.Error(), tt.errMsg,
|
||
"error message should mention %s, got: %v", tt.errMsg, err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCreateNotification_TransactionRollback 需要真实 DB;缺 DB 时 skip。
|
||
// 验证:参数全部合法时,事务被打开、能完成 insert。
|
||
func TestCreateNotification_TransactionRollback(t *testing.T) {
|
||
db, ok := setupTestDB(t)
|
||
if !ok {
|
||
t.Skip("skipping: test DB not available")
|
||
}
|
||
svc := NewNotificationService(db, nil, nil)
|
||
ctx := context.Background()
|
||
|
||
data, _ := structpb.NewStruct(map[string]interface{}{"foo": "bar"})
|
||
req := ¬ification.CreateNotificationRequest{
|
||
UserId: 880001,
|
||
StarId: 1,
|
||
Type: "system",
|
||
Title: "rollback-test",
|
||
Data: data,
|
||
}
|
||
|
||
resp, err := svc.CreateNotification(ctx, req)
|
||
if err != nil {
|
||
t.Fatalf("CreateNotification failed: %v", err)
|
||
}
|
||
assert.NotNil(t, resp)
|
||
assert.NotZero(t, resp.Id)
|
||
|
||
// cleanup
|
||
t.Cleanup(func() {
|
||
_ = db.WithContext(ctx).Exec(
|
||
`DELETE FROM public.notifications WHERE user_id=$1`, req.UserId).Error
|
||
_ = db.WithContext(ctx).Exec(
|
||
`DELETE FROM public.notification_stats WHERE user_id=$1`, req.UserId).Error
|
||
})
|
||
}
|
||
|
||
// TestBuildAggregatedLikeTitle 是纯函数测试,不依赖 DB。
|
||
// 直接覆盖 buildAggregatedLikeTitle 的 6 个分支(package service 才能访问私有函数)。
|
||
func TestBuildAggregatedLikeTitle(t *testing.T) {
|
||
cases := []struct {
|
||
name string
|
||
actors []model.ActorPreview
|
||
total int32
|
||
assetTitle string
|
||
want string
|
||
}{
|
||
{
|
||
name: "0 actors with total",
|
||
actors: nil,
|
||
total: 3,
|
||
assetTitle: "藏品A",
|
||
want: "有 3 人赞了你的《藏品A》",
|
||
},
|
||
{
|
||
name: "1 actor",
|
||
actors: []model.ActorPreview{{UserID: 1, Nickname: "张三"}},
|
||
total: 1,
|
||
assetTitle: "藏品A",
|
||
want: "张三 赞了你的《藏品A》",
|
||
},
|
||
{
|
||
name: "2 actors",
|
||
actors: []model.ActorPreview{
|
||
{UserID: 1, Nickname: "张三"},
|
||
{UserID: 2, Nickname: "李四"},
|
||
},
|
||
total: 2,
|
||
assetTitle: "藏品A",
|
||
want: "张三、李四 赞了你的《藏品A》",
|
||
},
|
||
{
|
||
name: "3 actors with total=10",
|
||
actors: []model.ActorPreview{
|
||
{UserID: 1, Nickname: "张三"},
|
||
{UserID: 2, Nickname: "李四"},
|
||
{UserID: 3, Nickname: "王五"},
|
||
},
|
||
total: 10,
|
||
assetTitle: "藏品A",
|
||
want: "张三、李四 等 10 人赞了你的《藏品A》",
|
||
},
|
||
{
|
||
name: "actor nickname empty fallback to 用户{id}",
|
||
actors: []model.ActorPreview{{UserID: 42, Nickname: ""}},
|
||
total: 1,
|
||
assetTitle: "藏品A",
|
||
want: "用户42 赞了你的《藏品A》",
|
||
},
|
||
{
|
||
name: "asset title empty fallback to 你的藏品",
|
||
actors: []model.ActorPreview{{UserID: 1, Nickname: "张三"}},
|
||
total: 1,
|
||
assetTitle: "",
|
||
want: "张三 赞了你的《你的藏品》",
|
||
},
|
||
{
|
||
name: "actor with only whitespace nickname",
|
||
actors: []model.ActorPreview{
|
||
{UserID: 99, Nickname: " "},
|
||
{UserID: 100, Nickname: "小李"},
|
||
},
|
||
total: 2,
|
||
assetTitle: "藏品B",
|
||
want: "用户99、小李 赞了你的《藏品B》",
|
||
},
|
||
}
|
||
|
||
for _, tt := range cases {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := buildAggregatedLikeTitle(tt.actors, tt.total, tt.assetTitle)
|
||
assert.Equal(t, tt.want, got)
|
||
})
|
||
}
|
||
}
|