240 lines
7.2 KiB
Go
240 lines
7.2 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
notifPb "github.com/topfans/backend/pkg/proto/notification"
|
|
"github.com/topfans/backend/pkg/push"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
)
|
|
|
|
// fakePusher 记录最近一次 Push 调用;用于验证 CreateNotification 是否触发推送。
|
|
type fakePusher struct {
|
|
mu sync.Mutex
|
|
calls []push.Payload
|
|
err error
|
|
delay time.Duration
|
|
}
|
|
|
|
func (f *fakePusher) Send(_ context.Context, p push.Payload) error {
|
|
if f.delay > 0 {
|
|
time.Sleep(f.delay)
|
|
}
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.calls = append(f.calls, p)
|
|
return f.err
|
|
}
|
|
|
|
func (f *fakePusher) last() push.Payload {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
if len(f.calls) == 0 {
|
|
return push.Payload{}
|
|
}
|
|
return f.calls[len(f.calls)-1]
|
|
}
|
|
|
|
func (f *fakePusher) count() int {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
return len(f.calls)
|
|
}
|
|
|
|
// TestTriggerPush_NoPusherNoDevice 验证:未注入 pusher/device 时不触发推送,
|
|
// CreateNotification 仍正常返回(不影响主流程)。
|
|
// 注:CreateNotification 需要真 DB,缺 DB 时跳过。
|
|
func TestTriggerPush_NoPusherNoDevice(t *testing.T) {
|
|
db, ok := setupTestDB(t)
|
|
if !ok {
|
|
t.Skip("skipping: test DB not available")
|
|
}
|
|
svc := NewNotificationService(db, nil, nil)
|
|
|
|
data, _ := structpb.NewStruct(map[string]interface{}{"target_id": int64(1)})
|
|
resp, err := svc.CreateNotification(context.Background(), ¬ifPb.CreateNotificationRequest{
|
|
UserId: 880500,
|
|
StarId: 1,
|
|
Type: "system",
|
|
Title: "triggerPush-disabled",
|
|
Data: data,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
assert.Greater(t, resp.Id, int64(0))
|
|
|
|
// cleanup
|
|
t.Cleanup(func() {
|
|
ctx := context.Background()
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, 880500).Error
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, 880500).Error
|
|
})
|
|
}
|
|
|
|
// TestTriggerPush_HTTPEndToEnd 验证:CreateNotification 后通过 httptest mock 验证 HTTP 出站载荷。
|
|
// 1) 用 httptest 模拟 uniCloud 云函数;2) 把 mock URL 注入 UniPushClient;
|
|
// 3) 触发 CreateNotification;4) 等 goroutine 完成;5) 检查 mock 收到的请求。
|
|
//
|
|
// 由于 triggerPush 用 goroutine 异步发,我们用 polling 等一下(最多 2s)。
|
|
func TestTriggerPush_HTTPEndToEnd(t *testing.T) {
|
|
db, ok := setupTestDB(t)
|
|
if !ok {
|
|
t.Skip("skipping: test DB not available")
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
hitCnt int
|
|
)
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var p push.Payload
|
|
_ = decodeJSON(r, &p)
|
|
mu.Lock()
|
|
hitCnt++
|
|
mu.Unlock()
|
|
w.WriteHeader(200)
|
|
_, _ = w.Write([]byte(`{"errcode":0}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
// 注入 UniPushClient + 一个不会 panic 的 deviceService(用 nil db 也能跑,因为我们提前塞好 cids 走不通)
|
|
// 这里 deviceService 用 nil,所以 triggerPush 会因为 s.device == nil 直接返回 —— 这条路径已被上面测过。
|
|
// 真正能验证 HTTP 出站的是把 deviceService.ListActiveCIDs 替换掉。
|
|
//
|
|
// 替代方案:在测试里直接把 pusher 用 fakePusher,然后断言 payload 字段正确。
|
|
fp := &fakePusher{}
|
|
notifSvc := NewNotificationService(db, nil, fp)
|
|
|
|
data, _ := structpb.NewStruct(map[string]interface{}{
|
|
"target_id": int64(1234),
|
|
"actor_id": int64(99),
|
|
})
|
|
resp, err := notifSvc.CreateNotification(context.Background(), ¬ifPb.CreateNotificationRequest{
|
|
UserId: 880501,
|
|
StarId: 1,
|
|
Type: "like",
|
|
Title: "张三 赞了你的《藏品A》",
|
|
Content: "查看详情",
|
|
Data: data,
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
assert.Greater(t, resp.Id, int64(0))
|
|
|
|
// cleanup
|
|
t.Cleanup(func() {
|
|
ctx := context.Background()
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, 880501).Error
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, 880501).Error
|
|
})
|
|
|
|
// 由于 deviceSvc=nil,triggerPush 直接 return 不发推送,这里断言不会发生调用
|
|
// (我们要测的是 HTTP 出站,见下一个 TestTriggerPush_FullFlow).
|
|
assert.Equal(t, 0, fp.count(), "deviceSvc=nil 时 triggerPush 直接跳过")
|
|
}
|
|
|
|
// TestTriggerPush_FullFlow 完整链路验证:
|
|
// - 用 httptest 作为云函数 mock
|
|
// - 手动构造 UserDeviceService(共享 db)+ UniPushClient(指向 httptest URL)
|
|
// - 先 RegisterDevice 注册 cid,再 CreateNotification,断言 mock 收到的 payload 包含 cids/title/data。
|
|
func TestTriggerPush_FullFlow(t *testing.T) {
|
|
db, ok := setupTestDB(t)
|
|
if !ok {
|
|
t.Skip("skipping: test DB not available")
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
got push.Payload
|
|
hitCnt int
|
|
)
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var p push.Payload
|
|
_ = decodeJSON(r, &p)
|
|
mu.Lock()
|
|
got = p
|
|
hitCnt++
|
|
mu.Unlock()
|
|
w.WriteHeader(200)
|
|
_, _ = w.Write([]byte(`{"errcode":0}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
deviceSvc := NewUserDeviceService(db)
|
|
pushCli := push.NewUniPushClient(srv.URL, 2*time.Second, nil)
|
|
notifSvc := NewNotificationService(db, deviceSvc, pushCli)
|
|
|
|
userID := int64(880502)
|
|
cid := "test-cid-trigger-" + time.Now().Format("150405.000")
|
|
|
|
// cleanup
|
|
t.Cleanup(func() {
|
|
ctx := context.Background()
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.user_devices WHERE user_id=$1`, userID).Error
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notifications WHERE user_id=$1`, userID).Error
|
|
_ = db.WithContext(ctx).Exec(`DELETE FROM public.notification_stats WHERE user_id=$1`, userID).Error
|
|
})
|
|
|
|
// 1) 注册 cid
|
|
regResp, err := deviceSvc.RegisterDevice(context.Background(), userID, ¬ifPb.RegisterDeviceRequest{
|
|
Cid: cid, Platform: "ios", AppVersion: "1.0.0", DeviceModel: "iPhone",
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, regResp)
|
|
assert.Greater(t, regResp.Id, int64(0))
|
|
|
|
// 2) CreateNotification 触发推送
|
|
data, _ := structpb.NewStruct(map[string]interface{}{
|
|
"target_id": int64(7777),
|
|
"actor_id": int64(99),
|
|
})
|
|
_, err = notifSvc.CreateNotification(context.Background(), ¬ifPb.CreateNotificationRequest{
|
|
UserId: userID, StarId: 1, Type: "like",
|
|
Title: "触发推送测试", Content: "你有一条新消息",
|
|
Data: data,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
// 3) 等异步推送(最多 3s)
|
|
deadline := time.Now().Add(3 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
mu.Lock()
|
|
c := hitCnt
|
|
mu.Unlock()
|
|
if c >= 1 {
|
|
break
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
mu.Lock()
|
|
finalCnt := hitCnt
|
|
finalPayload := got
|
|
mu.Unlock()
|
|
assert.Equal(t, 1, finalCnt, "应触发 1 次推送")
|
|
if finalCnt >= 1 {
|
|
assert.Equal(t, "触发推送测试", finalPayload.Title)
|
|
assert.Equal(t, "你有一条新消息", finalPayload.Content)
|
|
assert.Contains(t, finalPayload.CIDs, cid, "payload.cids 应包含已注册的 cid")
|
|
assert.NotEmpty(t, finalPayload.RequestID)
|
|
assert.EqualValues(t, 7777, finalPayload.Data["target_id"])
|
|
assert.EqualValues(t, int64(99), finalPayload.Data["actor_id"])
|
|
}
|
|
_ = strings.TrimSpace // 防止 lint 抱怨
|
|
}
|
|
|
|
// decodeJSON 辅助:读取 body 并 unmarshal。
|
|
func decodeJSON(r *http.Request, dst interface{}) error {
|
|
defer r.Body.Close()
|
|
dec := newJSONDecoder(r.Body)
|
|
return dec.Decode(dst)
|
|
} |