// Package provider 实现 notification.proto 生成的 NotificationServiceHandler 接口。 // // 设计要点: // - 从 gRPC metadata(或 Dubbo attachments,gateway 兼容模式)提取 user_id / star_id。 // - 仅做参数透传和日志记录,业务逻辑全部委托给 service.NotificationService。 // - 任何错误按 service 返回的 err 透传给 gRPC 层;不吞错。 package provider import ( "context" "fmt" "strconv" "time" "dubbo.apache.org/dubbo-go/v3/common/constant" pbCommon "github.com/topfans/backend/pkg/proto/common" notifPb "github.com/topfans/backend/pkg/proto/notification" "github.com/topfans/backend/services/notificationService/service" "google.golang.org/grpc/metadata" ) // NotificationProvider notification 服务的 RPC Provider。 type NotificationProvider struct { notifSvc *service.NotificationService deviceSvc *service.UserDeviceService } // 编译期断言:NotificationProvider 实现了 notifPb.NotificationServiceHandler 接口(triple 生成)。 var _ notifPb.NotificationServiceHandler = (*NotificationProvider)(nil) // NewNotificationProvider 创建 NotificationProvider。 func NewNotificationProvider(notifSvc *service.NotificationService, deviceSvc *service.UserDeviceService) *NotificationProvider { return &NotificationProvider{notifSvc: notifSvc, deviceSvc: deviceSvc} } // ========== 8 个 RPC 方法 ========== // CreateNotification 创建通知(无 user_id/star_id 从 metadata 取,由 service 校验)。 func (p *NotificationProvider) CreateNotification(ctx context.Context, req *notifPb.CreateNotificationRequest) (*notifPb.CreateNotificationResponse, error) { // 注意:CreateNotification 通常由 RPC 内部触发(social service 调用),不强制从 metadata 取 user_id。 // 但仍优先用 metadata(如有)覆盖 req 中的字段,保证网关侧传过来的身份与请求一致。 if uid, sid, err := extractUserInfo(ctx); err == nil && uid > 0 { if req.UserId <= 0 { req.UserId = uid } if req.StarId <= 0 { req.StarId = sid } } return p.notifSvc.CreateNotification(ctx, req) } // GetNotifications 拉取通知列表(type=like 时聚合)。 func (p *NotificationProvider) GetNotifications(ctx context.Context, req *notifPb.GetNotificationsRequest) (*notifPb.GetNotificationsResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.notifSvc.GetNotifications(ctx, userID, starID, req.Type, req.Tab, req.Page, req.PageSize) } // GetUnreadCount 获取未读计数。 func (p *NotificationProvider) GetUnreadCount(ctx context.Context, req *notifPb.GetUnreadCountRequest) (*notifPb.GetUnreadCountResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.notifSvc.GetUnreadCount(ctx, userID, starID) } // MarkAsRead 单条标已读。 func (p *NotificationProvider) MarkAsRead(ctx context.Context, req *notifPb.MarkAsReadRequest) (*notifPb.MarkAsReadResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.notifSvc.MarkAsRead(ctx, userID, starID, req.Id, now) } // MarkAsReadByTarget 将某个 target 下所有 like 标已读。 func (p *NotificationProvider) MarkAsReadByTarget(ctx context.Context, req *notifPb.MarkAsReadByTargetRequest) (*notifPb.MarkAsReadByTargetResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.notifSvc.MarkAsReadByTarget(ctx, userID, starID, req.TargetId, now) } // MarkAllAsRead 全部已读(按 type 过滤)。 func (p *NotificationProvider) MarkAllAsRead(ctx context.Context, req *notifPb.MarkAllAsReadRequest) (*notifPb.MarkAllAsReadResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.notifSvc.MarkAllAsRead(ctx, userID, starID, req.Type, now) } // DeleteNotification 软删单条。 func (p *NotificationProvider) DeleteNotification(ctx context.Context, req *notifPb.DeleteNotificationRequest) (*notifPb.DeleteNotificationResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.notifSvc.DeleteNotification(ctx, userID, starID, req.Id, now) } // DeleteByTarget 软删某个 target 下所有 like。 func (p *NotificationProvider) DeleteByTarget(ctx context.Context, req *notifPb.DeleteByTargetRequest) (*notifPb.DeleteByTargetResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.notifSvc.DeleteByTarget(ctx, userID, starID, req.TargetId, now) } // ========== 设备注册 RPC ========== // RegisterDevice 注册/更新推送 cid。user_id 从 metadata 取。 func (p *NotificationProvider) RegisterDevice(ctx context.Context, req *notifPb.RegisterDeviceRequest) (*notifPb.RegisterDeviceResponse, error) { userID, _, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.deviceSvc.RegisterDevice(ctx, userID, req) } // UnregisterDevice 注销推送 cid。cid 为空时注销当前用户的所有设备。 func (p *NotificationProvider) UnregisterDevice(ctx context.Context, req *notifPb.UnregisterDeviceRequest) (*notifPb.UnregisterDeviceResponse, error) { userID, _, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.deviceSvc.UnregisterDevice(ctx, userID, req) } // ========== 辅助方法 ========== // extractUserInfo 从 gRPC metadata 提取 user_id 和 star_id,fallback 到 Dubbo attachments。 // // gateway 在调用 notification service 时,会通过 gRPC metadata(HTTP 层是 HTTP header)传递 // x-user-id / x-star-id。Dubbo Triple 协议会把 metadata 放进 metadata.MD; // 同时也兼容 Dubbo attachments(Dubbo 老链路)。 func extractUserInfo(ctx context.Context) (int64, int64, error) { // 优先从 gRPC metadata 取(Tripe 协议会把 HTTP header 转成 metadata.MD) if md, ok := metadata.FromIncomingContext(ctx); ok { if uid, ok := readInt64FromMD(md, "x-user-id"); ok && uid > 0 { sid, _ := readInt64FromMD(md, "x-star-id") return uid, sid, nil } } // fallback:Dubbo attachments(constant.AttachmentKey) if attachments := ctx.Value(constant.AttachmentKey); attachments != nil { if attMap, ok := attachments.(map[string]interface{}); ok { uid := parseIntValue(attMap["user_id"]) sid := parseIntValue(attMap["star_id"]) if uid > 0 && sid > 0 { return uid, sid, nil } } } return 0, 0, fmt.Errorf("user info not found: expected x-user-id and x-star-id in metadata") } // readInt64FromMD 从 metadata.MD 中按 key 读取首个 int64 值(gRPC metadata value 均为字符串切片)。 func readInt64FromMD(md metadata.MD, key string) (int64, bool) { vals := md.Get(key) if len(vals) == 0 { return 0, false } n, err := strconv.ParseInt(vals[0], 10, 64) if err != nil { return 0, false } return n, true } // parseIntValue 把任意类型(int / int64 / float64 / string)转 int64。 func parseIntValue(v interface{}) int64 { switch val := v.(type) { case int64: return val case int: return int64(val) case float64: return int64(val) case string: if i, err := strconv.ParseInt(val, 10, 64); err == nil { return i } case []string: if len(val) > 0 { if i, err := strconv.ParseInt(val[0], 10, 64); err == nil { return i } } } return 0 } // avoid unused import warnings (pbCommon may not be referenced directly here but reserved for future use) var _ = pbCommon.BaseResponse{}