topfans/backend/services/notificationService/service/trigger_push_test.go

240 lines
7.2 KiB
Go

package service
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
notifPb "github.com/topfans/backend/pkg/proto/notification"
"github.com/topfans/backend/pkg/push"
"google.golang.org/protobuf/types/known/structpb"
)
// fakePusher is an in-memory stand-in for the push client. It records every
// payload passed to Send, oldest first, so tests can check whether (and with
// what payload) CreateNotification triggered a push.
//
// The zero value is ready to use. Set err to make every Send fail, and delay
// to simulate a slow upstream. The recorded calls are guarded by mu because
// Send may run on a background goroutine (triggerPush pushes asynchronously).
type fakePusher struct {
	mu    sync.Mutex     // guards calls
	calls []push.Payload // payloads recorded by Send, in call order
	err   error          // returned by every Send (nil means success)
	delay time.Duration  // simulated latency before Send records the payload
}

// Send records p and returns f.err. When delay is positive it first waits
// that long; if ctx is done before the delay elapses it returns ctx.Err()
// without recording p, so tests can model a cancelled or timed-out push.
func (f *fakePusher) Send(ctx context.Context, p push.Payload) error {
	if f.delay > 0 {
		timer := time.NewTimer(f.delay)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	f.calls = append(f.calls, p)
	return f.err
}

// last returns the most recently recorded payload, or the zero Payload when
// nothing has been recorded yet.
func (f *fakePusher) last() push.Payload {
	f.mu.Lock()
	defer f.mu.Unlock()
	if len(f.calls) == 0 {
		return push.Payload{}
	}
	return f.calls[len(f.calls)-1]
}

// count returns how many payloads Send has recorded. Calls that were
// abandoned because their context was cancelled are not counted.
func (f *fakePusher) count() int {
	f.mu.Lock()
	defer f.mu.Unlock()
	return len(f.calls)
}
// TestTriggerPush_NoPusherNoDevice verifies that when neither a pusher nor a
// device service is injected, no push is attempted and CreateNotification
// still returns normally: the push step must not affect the main flow.
//
// CreateNotification needs a real database; the test is skipped without one.
func TestTriggerPush_NoPusherNoDevice(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	const userID = int64(880500)

	// Register cleanup before creating any rows so they are removed even if a
	// later require aborts the test. Cleanup is best-effort: failures are
	// logged rather than failing the test.
	t.Cleanup(func() {
		ctx := context.Background()
		for _, q := range []string{
			`DELETE FROM public.notifications WHERE user_id=$1`,
			`DELETE FROM public.notification_stats WHERE user_id=$1`,
		} {
			if err := db.WithContext(ctx).Exec(q, userID).Error; err != nil {
				t.Logf("cleanup %q: %v", q, err)
			}
		}
	})

	svc := NewNotificationService(db, nil, nil)

	data, err := structpb.NewStruct(map[string]interface{}{"target_id": int64(1)})
	require.NoError(t, err)

	resp, err := svc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId: userID,
		StarId: 1,
		Type:   "system",
		Title:  "triggerPush-disabled",
		Data:   data,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)
	assert.Greater(t, resp.Id, int64(0))
}
// TestTriggerPush_HTTPEndToEnd verifies that injecting a pusher without a
// device service does not send anything: with deviceSvc == nil, triggerPush
// returns early and the pusher must never be called, while
// CreateNotification itself still succeeds. The real outbound HTTP payload is
// checked by TestTriggerPush_FullFlow.
//
// Needs a real database; skipped without one.
func TestTriggerPush_HTTPEndToEnd(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	const userID = int64(880501)

	// Registered up front so rows are removed even if a require below fails.
	// Best-effort: errors are logged, not fatal.
	t.Cleanup(func() {
		ctx := context.Background()
		for _, q := range []string{
			`DELETE FROM public.notifications WHERE user_id=$1`,
			`DELETE FROM public.notification_stats WHERE user_id=$1`,
		} {
			if err := db.WithContext(ctx).Exec(q, userID).Error; err != nil {
				t.Logf("cleanup %q: %v", q, err)
			}
		}
	})

	fp := &fakePusher{}
	notifSvc := NewNotificationService(db, nil, fp)

	data, err := structpb.NewStruct(map[string]interface{}{
		"target_id": int64(1234),
		"actor_id":  int64(99),
	})
	require.NoError(t, err)

	resp, err := notifSvc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId:  userID,
		StarId:  1,
		Type:    "like",
		Title:   "张三 赞了你的《藏品A》",
		Content: "查看详情",
		Data:    data,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)
	assert.Greater(t, resp.Id, int64(0))

	// triggerPush runs asynchronously, so checking the count right away would
	// pass even if a push were about to happen. Watch a short window instead.
	assert.Never(t, func() bool { return fp.count() > 0 },
		300*time.Millisecond, 20*time.Millisecond,
		"deviceSvc=nil 时 triggerPush 直接跳过")
}
// TestTriggerPush_FullFlow exercises the complete push chain:
//   - an httptest server stands in for the uniCloud cloud function;
//   - a UserDeviceService (sharing db) and a UniPushClient pointed at the
//     httptest URL are wired into the notification service;
//   - a cid is registered with RegisterDevice, then CreateNotification is
//     called, and the payload received by the mock is checked for the cid,
//     title, content, request id and data fields.
//
// Needs a real database; skipped without one.
func TestTriggerPush_FullFlow(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	var (
		mu        sync.Mutex
		got       push.Payload // last payload received by the mock
		hitCnt    int          // requests received by the mock
		decodeErr error        // first body-decoding failure, if any
	)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var p push.Payload
		err := decodeJSON(r, &p)
		mu.Lock()
		got = p
		hitCnt++
		if err != nil && decodeErr == nil {
			// The handler runs on a server goroutine where FailNow is not
			// allowed; record the error and assert on the test goroutine.
			decodeErr = err
		}
		mu.Unlock()
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"errcode":0}`))
	}))
	defer srv.Close()

	deviceSvc := NewUserDeviceService(db)
	pushCli := push.NewUniPushClient(srv.URL, 2*time.Second, nil)
	notifSvc := NewNotificationService(db, deviceSvc, pushCli)

	userID := int64(880502)
	cid := "test-cid-trigger-" + time.Now().Format("150405.000")

	// Best-effort cleanup of everything this test may have written; errors are
	// logged rather than failing the test.
	t.Cleanup(func() {
		ctx := context.Background()
		for _, q := range []string{
			`DELETE FROM public.user_devices WHERE user_id=$1`,
			`DELETE FROM public.notifications WHERE user_id=$1`,
			`DELETE FROM public.notification_stats WHERE user_id=$1`,
		} {
			if err := db.WithContext(ctx).Exec(q, userID).Error; err != nil {
				t.Logf("cleanup %q: %v", q, err)
			}
		}
	})

	// 1) Register the cid so there is a device to push to.
	regResp, err := deviceSvc.RegisterDevice(context.Background(), userID, &notifPb.RegisterDeviceRequest{
		Cid: cid, Platform: "ios", AppVersion: "1.0.0", DeviceModel: "iPhone",
	})
	require.NoError(t, err)
	require.NotNil(t, regResp)
	assert.Greater(t, regResp.Id, int64(0))

	// 2) CreateNotification should trigger the push.
	data, err := structpb.NewStruct(map[string]interface{}{
		"target_id": int64(7777),
		"actor_id":  int64(99),
	})
	require.NoError(t, err)
	_, err = notifSvc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId: userID, StarId: 1, Type: "like",
		Title: "触发推送测试", Content: "你有一条新消息",
		Data: data,
	})
	require.NoError(t, err)

	// 3) The push is sent asynchronously; wait up to 3s for the mock to be hit.
	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return hitCnt >= 1
	}, 3*time.Second, 50*time.Millisecond, "应触发 1 次推送")

	mu.Lock()
	finalCnt, finalPayload, finalErr := hitCnt, got, decodeErr
	mu.Unlock()

	require.NoError(t, finalErr, "mock could not decode the push request body")
	assert.Equal(t, 1, finalCnt, "应触发 1 次推送")
	assert.Equal(t, "触发推送测试", finalPayload.Title)
	assert.Equal(t, "你有一条新消息", finalPayload.Content)
	assert.Contains(t, finalPayload.CIDs, cid, "payload.cids 应包含已注册的 cid")
	assert.NotEmpty(t, finalPayload.RequestID)
	assert.EqualValues(t, 7777, finalPayload.Data["target_id"])
	assert.EqualValues(t, int64(99), finalPayload.Data["actor_id"])

	// Keeps the otherwise-unused "strings" import referenced so the file compiles;
	// drop this together with the import once the import can be removed.
	_ = strings.TrimSpace
}
// decodeJSON decodes the JSON request body of r into dst and returns the
// decoder's error. The body is closed before returning, whether or not
// decoding succeeded. Decoding goes through newJSONDecoder, which is
// presumably a package-level helper defined in another file — confirm its
// options (e.g. number handling) there.
func decodeJSON(r *http.Request, dst interface{}) error {
	body := r.Body
	defer body.Close()
	return newJSONDecoder(body).Decode(dst)
}