topfans/backend/services/notificationService/provider/notification_provider.go
2026-06-16 21:30:58 +08:00

188 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package provider 实现 notification.proto 生成的 NotificationServiceHandler 接口。
//
// 设计要点:
//   - 从 gRPC metadata(或 Dubbo attachments, gateway 兼容模式)提取 user_id / star_id。
//   - 仅做参数透传(本文件目前不做日志记录),业务逻辑全部委托给 service.NotificationService。
//   - service 返回的 err 原样透传给 gRPC 层,不吞错;身份提取失败时返回包装后的错误。
package provider
import (
"context"
"fmt"
"strconv"
"time"
"dubbo.apache.org/dubbo-go/v3/common/constant"
pbCommon "github.com/topfans/backend/pkg/proto/common"
notifPb "github.com/topfans/backend/pkg/proto/notification"
"github.com/topfans/backend/services/notificationService/service"
"google.golang.org/grpc/metadata"
)
// NotificationProvider is the RPC-facing adapter for the notification
// service. It keeps no state of its own beyond the business-layer service
// that every RPC method delegates to.
type NotificationProvider struct {
	// svc carries out all notification business logic for the RPC methods.
	svc *service.NotificationService
}

// NewNotificationProvider wraps svc in a NotificationProvider. svc is stored
// as-is; no nil check is performed here.
func NewNotificationProvider(svc *service.NotificationService) *NotificationProvider {
	p := new(NotificationProvider)
	p.svc = svc
	return p
}

// Compile-time check that *NotificationProvider satisfies the
// Triple-generated notifPb.NotificationServiceHandler interface.
var _ notifPb.NotificationServiceHandler = (*NotificationProvider)(nil)
// ========== 8 个 RPC 方法 ==========
// CreateNotification creates the notification described by req.
//
// This RPC is normally invoked service-to-service (for example by the social
// service), so a caller identity in ctx is optional. When extractUserInfo
// succeeds, its user/star IDs are used only to back-fill req.UserId and
// req.StarId if the request left them unset (<= 0). IDs already present in
// req are never overwritten. Validating the final values is left to the
// service layer.
//
// NOTE(review): req is mutated in place. If the caller identity in metadata
// is meant to take precedence over req, this back-fill logic must change.
// Confirm the intended semantics.
func (p *NotificationProvider) CreateNotification(ctx context.Context, req *notifPb.CreateNotificationRequest) (*notifPb.CreateNotificationResponse, error) {
	// An extraction error is deliberately ignored: identity is optional here.
	callerUID, callerSID, lookupErr := extractUserInfo(ctx)
	if lookupErr == nil && callerUID > 0 {
		if req.UserId <= 0 {
			req.UserId = callerUID
		}
		if req.StarId <= 0 {
			req.StarId = callerSID
		}
	}
	return p.svc.CreateNotification(ctx, req)
}
// GetNotifications returns a page of notifications for the calling user.
//
// The caller identity is resolved by extractUserInfo. req.Type, req.Tab,
// req.Page and req.PageSize are forwarded unchanged to the service layer,
// which owns the listing logic (presumably including like-aggregation when
// type=like; confirm in service).
func (p *NotificationProvider) GetNotifications(ctx context.Context, req *notifPb.GetNotificationsRequest) (*notifPb.GetNotificationsResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	return p.svc.GetNotifications(ctx, uid, sid, req.Type, req.Tab, req.Page, req.PageSize)
}
// GetUnreadCount returns the unread-notification counts for the calling user.
//
// req carries no fields that are used here. Only the identity extracted from
// ctx is forwarded to the service layer.
func (p *NotificationProvider) GetUnreadCount(ctx context.Context, req *notifPb.GetUnreadCountRequest) (*notifPb.GetUnreadCountResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	return p.svc.GetUnreadCount(ctx, uid, sid)
}
// MarkAsRead marks the single notification req.Id as read for the caller.
//
// The current wall-clock time, in Unix milliseconds, is passed to the service
// (presumably as the read timestamp).
func (p *NotificationProvider) MarkAsRead(ctx context.Context, req *notifPb.MarkAsReadRequest) (*notifPb.MarkAsReadResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	nowMs := time.Now().UnixMilli()
	return p.svc.MarkAsRead(ctx, uid, sid, req.Id, nowMs)
}
// MarkAsReadByTarget marks the caller's notifications attached to
// req.TargetId (likes, per the original design note) as read.
//
// The current wall-clock time, in Unix milliseconds, is passed to the
// service.
func (p *NotificationProvider) MarkAsReadByTarget(ctx context.Context, req *notifPb.MarkAsReadByTargetRequest) (*notifPb.MarkAsReadByTargetResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	nowMs := time.Now().UnixMilli()
	return p.svc.MarkAsReadByTarget(ctx, uid, sid, req.TargetId, nowMs)
}
// MarkAllAsRead marks all of the caller's notifications as read. req.Type is
// forwarded so the service can restrict the operation to one type.
//
// The current wall-clock time, in Unix milliseconds, is passed to the
// service.
func (p *NotificationProvider) MarkAllAsRead(ctx context.Context, req *notifPb.MarkAllAsReadRequest) (*notifPb.MarkAllAsReadResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	nowMs := time.Now().UnixMilli()
	return p.svc.MarkAllAsRead(ctx, uid, sid, req.Type, nowMs)
}
// DeleteNotification deletes the single notification req.Id for the caller
// (a soft delete, per the original design note).
//
// The current wall-clock time, in Unix milliseconds, is passed to the
// service.
func (p *NotificationProvider) DeleteNotification(ctx context.Context, req *notifPb.DeleteNotificationRequest) (*notifPb.DeleteNotificationResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	nowMs := time.Now().UnixMilli()
	return p.svc.DeleteNotification(ctx, uid, sid, req.Id, nowMs)
}
// DeleteByTarget deletes the caller's notifications attached to req.TargetId
// (soft-deleted likes, per the original design note).
//
// The current wall-clock time, in Unix milliseconds, is passed to the
// service.
func (p *NotificationProvider) DeleteByTarget(ctx context.Context, req *notifPb.DeleteByTargetRequest) (*notifPb.DeleteByTargetResponse, error) {
	uid, sid, extractErr := extractUserInfo(ctx)
	if extractErr != nil {
		return nil, fmt.Errorf("extract user info: %w", extractErr)
	}
	nowMs := time.Now().UnixMilli()
	return p.svc.DeleteByTarget(ctx, uid, sid, req.TargetId, nowMs)
}
// ========== 辅助方法 ==========
// extractUserInfo resolves the caller's user ID and star ID from ctx.
//
// Lookup order:
//  1. gRPC incoming metadata. Per the gateway contract, the Triple protocol
//     turns HTTP headers into metadata. "x-user-id" must parse to a positive
//     int64. "x-star-id" is optional and yields 0 when missing or unparsable.
//  2. Dubbo attachments stored under constant.AttachmentKey as a
//     map[string]interface{}. Here "user_id" and "star_id" must BOTH resolve
//     to positive values via parseIntValue.
//
// It returns an error when neither source yields a usable identity.
//
// NOTE(review): the two sources disagree on whether star_id is required
// (optional in metadata, mandatory in attachments). The attachment keys
// also differ from the metadata header names. Confirm both are intentional.
func extractUserInfo(ctx context.Context) (int64, int64, error) {
	if md, hasMD := metadata.FromIncomingContext(ctx); hasMD {
		userID, found := readInt64FromMD(md, "x-user-id")
		if found && userID > 0 {
			starID, _ := readInt64FromMD(md, "x-star-id") // optional
			return userID, starID, nil
		}
	}

	// A missing or non-map value fails the comma-ok assertion, so no nil
	// check is needed first.
	attMap, isMap := ctx.Value(constant.AttachmentKey).(map[string]interface{})
	if isMap {
		userID, starID := parseIntValue(attMap["user_id"]), parseIntValue(attMap["star_id"])
		if userID > 0 && starID > 0 {
			return userID, starID, nil
		}
	}

	return 0, 0, fmt.Errorf("user info not found: expected x-user-id and x-star-id in metadata")
}
// readInt64FromMD parses the first value stored under key in md as a base-10
// int64. gRPC metadata values are string slices, so only index 0 is looked
// at. The boolean is false when the key is absent or the value does not
// parse; in both cases the number is 0.
func readInt64FromMD(md metadata.MD, key string) (int64, bool) {
	if values := md.Get(key); len(values) > 0 {
		if parsed, err := strconv.ParseInt(values[0], 10, 64); err == nil {
			return parsed, true
		}
	}
	return 0, false
}
// parseIntValue converts a loosely typed attachment value into an int64.
//
// Accepted inputs:
//   - any signed integer kind (int, int8, int16, int32, int64);
//   - any unsigned integer kind whose value fits in int64;
//   - float32/float64 holding a whole number inside the int64 range;
//   - a base-10 string, or a []string whose first element is base-10.
//
// Everything else yields 0, which callers treat as "not present". That
// includes nil, other types, unparsable strings, fractional, NaN, infinite
// or out-of-range floats, and unsigned values above MaxInt64.
func parseIntValue(v interface{}) int64 {
	const maxInt64 = 1<<63 - 1

	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		// Dubbo attachments may carry 32-bit ints as int32 (e.g. Hessian2
		// decodes them that way). Without this case such IDs were read as 0.
		return int64(val)
	case int16:
		return int64(val)
	case int8:
		return int64(val)
	case uint64:
		if val <= maxInt64 {
			return int64(val)
		}
	case uint:
		if uint64(val) <= maxInt64 {
			return int64(val)
		}
	case uint32:
		return int64(val)
	case uint16:
		return int64(val)
	case uint8:
		return int64(val)
	case float64:
		return int64FromWholeFloat(val)
	case float32:
		return int64FromWholeFloat(float64(val))
	case string:
		if i, err := strconv.ParseInt(val, 10, 64); err == nil {
			return i
		}
	case []string:
		if len(val) > 0 {
			if i, err := strconv.ParseInt(val[0], 10, 64); err == nil {
				return i
			}
		}
	}
	return 0
}

// int64FromWholeFloat returns f as an int64 when f is a whole number in
// [-2^63, 2^63), and 0 otherwise. Go leaves out-of-range float-to-int
// conversion implementation-dependent. Truncating a fractional value would
// silently address a different ID. Both cases are therefore rejected.
func int64FromWholeFloat(f float64) int64 {
	const lo, hi = -(1 << 63), 1 << 63 // both exactly representable as float64
	if f >= lo && f < hi {             // also false for NaN
		if n := int64(f); float64(n) == f {
			return n
		}
	}
	return 0
}
// Keeps the pbCommon import referenced, since nothing else in this file uses
// it yet (it is reserved for future use). A nil pointer type does this
// without building a message value. Drop this line and the import once
// pbCommon is used for real.
var _ *pbCommon.BaseResponse