package service

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	notifPb "github.com/topfans/backend/pkg/proto/notification"
	"github.com/topfans/backend/pkg/push"
	"google.golang.org/protobuf/types/known/structpb"
)

// Keeps the "strings" import referenced: none of the tests in this file use it
// any more. Remove this line together with the import once nothing else in the
// file needs strings.
var _ = strings.TrimSpace

// fakePusher is an in-memory stand-in for the push client. Every Send call is
// appended to calls (guarded by mu), optionally after sleeping for delay, and
// returns err. Tests use it to check whether CreateNotification triggered a push.
type fakePusher struct {
	mu    sync.Mutex
	calls []push.Payload // every payload passed to Send, in call order
	err   error          // returned from every Send call
	delay time.Duration  // artificial latency applied before a call is recorded
}

// Send records p and returns f.err. The context is ignored, so the delay
// sleep is not interrupted by cancellation.
func (f *fakePusher) Send(_ context.Context, p push.Payload) error {
	if f.delay > 0 {
		time.Sleep(f.delay)
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	f.calls = append(f.calls, p)
	return f.err
}

// last returns the most recently recorded payload, or the zero Payload if
// Send has never been called.
func (f *fakePusher) last() push.Payload {
	f.mu.Lock()
	defer f.mu.Unlock()
	if len(f.calls) == 0 {
		return push.Payload{}
	}
	return f.calls[len(f.calls)-1]
}

// count returns how many times Send has been called.
func (f *fakePusher) count() int {
	f.mu.Lock()
	defer f.mu.Unlock()
	return len(f.calls)
}

// TestTriggerPush_NoPusherNoDevice verifies that CreateNotification still
// succeeds when neither a pusher nor a device service is injected: the push
// side path must not affect the main write path.
// Requires a real database; skipped when setupTestDB reports none.
func TestTriggerPush_NoPusherNoDevice(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	userID := int64(880500)
	// Registered before inserting so rows are removed even if an assertion
	// below aborts the test. Best-effort: cleanup errors are ignored.
	t.Cleanup(func() {
		ctx := context.Background()
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, userID).Error
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, userID).Error
	})

	svc := NewNotificationService(db, nil, nil)
	data, err := structpb.NewStruct(map[string]interface{}{"target_id": int64(1)})
	require.NoError(t, err)

	resp, err := svc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId: userID,
		StarId: 1,
		Type:   "system",
		Title:  "triggerPush-disabled",
		Data:   data,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)
	assert.Greater(t, resp.Id, int64(0))
}

// TestTriggerPush_HTTPEndToEnd verifies that when a pusher is injected but the
// device service is nil, CreateNotification succeeds and the pusher is never
// called (the service is expected to skip the push because it cannot look up
// any CIDs).
//
// Despite the name, no HTTP traffic is involved here; the real outbound HTTP
// payload is exercised by TestTriggerPush_FullFlow.
func TestTriggerPush_HTTPEndToEnd(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	userID := int64(880501)
	// Registered before inserting so rows are removed even if an assertion
	// below aborts the test. Best-effort: cleanup errors are ignored.
	t.Cleanup(func() {
		ctx := context.Background()
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, userID).Error
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, userID).Error
	})

	fp := &fakePusher{}
	notifSvc := NewNotificationService(db, nil, fp)

	data, err := structpb.NewStruct(map[string]interface{}{
		"target_id": int64(1234),
		"actor_id":  int64(99),
	})
	require.NoError(t, err)

	resp, err := notifSvc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId:  userID,
		StarId:  1,
		Type:    "like",
		Title:   "张三 赞了你的《藏品A》",
		Content: "查看详情",
		Data:    data,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)
	assert.Greater(t, resp.Id, int64(0))

	// triggerPush may dispatch asynchronously, so an immediate count()==0 check
	// would pass even if a push were still on its way. Watch a short window.
	assert.Never(t, func() bool { return fp.count() > 0 },
		200*time.Millisecond, 20*time.Millisecond, "deviceSvc=nil 时 triggerPush 直接跳过")
}

// TestTriggerPush_FullFlow exercises the whole push chain:
//   - an httptest server stands in for the uniCloud cloud function;
//   - a real UserDeviceService (sharing db) and a UniPushClient pointing at the
//     httptest URL are wired into the notification service;
//   - a CID is registered via RegisterDevice, then CreateNotification is called,
//     and the payload received by the mock must carry the CID, title and data.
func TestTriggerPush_FullFlow(t *testing.T) {
	db, ok := setupTestDB(t)
	if !ok {
		t.Skip("skipping: test DB not available")
	}

	var (
		mu     sync.Mutex
		got    push.Payload
		hitCnt int
	)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var p push.Payload
		if err := decodeJSON(r, &p); err != nil {
			// t.Errorf (not require/t.Fatal) because this runs on a server
			// goroutine; srv.Close waits for in-flight handlers before the test ends.
			t.Errorf("mock cloud function: decode push payload: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		mu.Lock()
		got = p
		hitCnt++
		mu.Unlock()
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"errcode":0}`))
	}))
	defer srv.Close()

	// snapshot reads the mock's state under the lock.
	snapshot := func() (int, push.Payload) {
		mu.Lock()
		defer mu.Unlock()
		return hitCnt, got
	}

	deviceSvc := NewUserDeviceService(db)
	pushCli := push.NewUniPushClient(srv.URL, 2*time.Second, nil)
	notifSvc := NewNotificationService(db, deviceSvc, pushCli)

	userID := int64(880502)
	cid := "test-cid-trigger-" + time.Now().Format("150405.000")

	// Best-effort cleanup; errors are ignored.
	t.Cleanup(func() {
		ctx := context.Background()
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.user_devices WHERE user_id=$1`, userID).Error
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, userID).Error
		_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, userID).Error
	})

	// 1) Register the CID.
	regResp, err := deviceSvc.RegisterDevice(context.Background(), userID, &notifPb.RegisterDeviceRequest{
		Cid:         cid,
		Platform:    "ios",
		AppVersion:  "1.0.0",
		DeviceModel: "iPhone",
	})
	require.NoError(t, err)
	require.NotNil(t, regResp)
	assert.Greater(t, regResp.Id, int64(0))

	// 2) CreateNotification should trigger the push.
	data, err := structpb.NewStruct(map[string]interface{}{
		"target_id": int64(7777),
		"actor_id":  int64(99),
	})
	require.NoError(t, err)

	_, err = notifSvc.CreateNotification(context.Background(), &notifPb.CreateNotificationRequest{
		UserId:  userID,
		StarId:  1,
		Type:    "like",
		Title:   "触发推送测试",
		Content: "你有一条新消息",
		Data:    data,
	})
	require.NoError(t, err)

	// 3) The push is sent asynchronously; wait up to 3s for the mock to be hit.
	assert.Eventually(t, func() bool {
		n, _ := snapshot()
		return n >= 1
	}, 3*time.Second, 50*time.Millisecond)

	finalCnt, finalPayload := snapshot()
	assert.Equal(t, 1, finalCnt, "应触发 1 次推送")
	if finalCnt >= 1 {
		assert.Equal(t, "触发推送测试", finalPayload.Title)
		assert.Equal(t, "你有一条新消息", finalPayload.Content)
		assert.Contains(t, finalPayload.CIDs, cid, "payload.cids 应包含已注册的 cid")
		assert.NotEmpty(t, finalPayload.RequestID)
		assert.EqualValues(t, 7777, finalPayload.Data["target_id"])
		assert.EqualValues(t, int64(99), finalPayload.Data["actor_id"])
	}
}

// decodeJSON reads r's body, decodes it into dst and closes the body.
// newJSONDecoder is defined elsewhere in this package (not visible here).
func decodeJSON(r *http.Request, dst interface{}) error {
	defer r.Body.Close()
	dec := newJSONDecoder(r.Body)
	return dec.Decode(dst)
}