// Package provider 实现 notification.proto 生成的 NotificationServiceHandler 接口。 // // 设计要点: // - 从 gRPC metadata(或 Dubbo attachments,gateway 兼容模式)提取 user_id / star_id。 // - 仅做参数透传和日志记录,业务逻辑全部委托给 service.NotificationService。 // - 任何错误按 service 返回的 err 透传给 gRPC 层;不吞错。 package provider import ( "context" "fmt" "strconv" "time" "dubbo.apache.org/dubbo-go/v3/common/constant" pbCommon "github.com/topfans/backend/pkg/proto/common" notifPb "github.com/topfans/backend/pkg/proto/notification" "github.com/topfans/backend/services/notificationService/service" "google.golang.org/grpc/metadata" ) // NotificationProvider notification 服务的 RPC Provider。 type NotificationProvider struct { svc *service.NotificationService } // 编译期断言:NotificationProvider 实现了 notifPb.NotificationServiceHandler 接口(triple 生成)。 var _ notifPb.NotificationServiceHandler = (*NotificationProvider)(nil) // NewNotificationProvider 创建 NotificationProvider。 func NewNotificationProvider(svc *service.NotificationService) *NotificationProvider { return &NotificationProvider{svc: svc} } // ========== 8 个 RPC 方法 ========== // CreateNotification 创建通知(无 user_id/star_id 从 metadata 取,由 service 校验)。 func (p *NotificationProvider) CreateNotification(ctx context.Context, req *notifPb.CreateNotificationRequest) (*notifPb.CreateNotificationResponse, error) { // 注意:CreateNotification 通常由 RPC 内部触发(social service 调用),不强制从 metadata 取 user_id。 // 但仍优先用 metadata(如有)覆盖 req 中的字段,保证网关侧传过来的身份与请求一致。 if uid, sid, err := extractUserInfo(ctx); err == nil && uid > 0 { if req.UserId <= 0 { req.UserId = uid } if req.StarId <= 0 { req.StarId = sid } } return p.svc.CreateNotification(ctx, req) } // GetNotifications 拉取通知列表(type=like 时聚合)。 func (p *NotificationProvider) GetNotifications(ctx context.Context, req *notifPb.GetNotificationsRequest) (*notifPb.GetNotificationsResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.svc.GetNotifications(ctx, userID, starID, req.Type, req.Tab, req.Page, req.PageSize) } // GetUnreadCount 获取未读计数。 func (p *NotificationProvider) GetUnreadCount(ctx context.Context, req *notifPb.GetUnreadCountRequest) (*notifPb.GetUnreadCountResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } return p.svc.GetUnreadCount(ctx, userID, starID) } // MarkAsRead 单条标已读。 func (p *NotificationProvider) MarkAsRead(ctx context.Context, req *notifPb.MarkAsReadRequest) (*notifPb.MarkAsReadResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.svc.MarkAsRead(ctx, userID, starID, req.Id, now) } // MarkAsReadByTarget 将某个 target 下所有 like 标已读。 func (p *NotificationProvider) MarkAsReadByTarget(ctx context.Context, req *notifPb.MarkAsReadByTargetRequest) (*notifPb.MarkAsReadByTargetResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.svc.MarkAsReadByTarget(ctx, userID, starID, req.TargetId, now) } // MarkAllAsRead 全部已读(按 type 过滤)。 func (p *NotificationProvider) MarkAllAsRead(ctx context.Context, req *notifPb.MarkAllAsReadRequest) (*notifPb.MarkAllAsReadResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.svc.MarkAllAsRead(ctx, userID, starID, req.Type, now) } // DeleteNotification 软删单条。 func (p *NotificationProvider) DeleteNotification(ctx context.Context, req *notifPb.DeleteNotificationRequest) (*notifPb.DeleteNotificationResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.svc.DeleteNotification(ctx, userID, starID, req.Id, now) } // DeleteByTarget 软删某个 target 下所有 like。 func (p *NotificationProvider) DeleteByTarget(ctx context.Context, req *notifPb.DeleteByTargetRequest) (*notifPb.DeleteByTargetResponse, error) { userID, starID, err := extractUserInfo(ctx) if err != nil { return nil, fmt.Errorf("extract user info: %w", err) } now := time.Now().UnixMilli() return p.svc.DeleteByTarget(ctx, userID, starID, req.TargetId, now) } // ========== 辅助方法 ========== // extractUserInfo 从 gRPC metadata 提取 user_id 和 star_id,fallback 到 Dubbo attachments。 // // gateway 在调用 notification service 时,会通过 gRPC metadata(HTTP 层是 HTTP header)传递 // x-user-id / x-star-id。Dubbo Triple 协议会把 metadata 放进 metadata.MD; // 同时也兼容 Dubbo attachments(Dubbo 老链路)。 func extractUserInfo(ctx context.Context) (int64, int64, error) { // 优先从 gRPC metadata 取(Tripe 协议会把 HTTP header 转成 metadata.MD) if md, ok := metadata.FromIncomingContext(ctx); ok { if uid, ok := readInt64FromMD(md, "x-user-id"); ok && uid > 0 { sid, _ := readInt64FromMD(md, "x-star-id") return uid, sid, nil } } // fallback:Dubbo attachments(constant.AttachmentKey) if attachments := ctx.Value(constant.AttachmentKey); attachments != nil { if attMap, ok := attachments.(map[string]interface{}); ok { uid := parseIntValue(attMap["user_id"]) sid := parseIntValue(attMap["star_id"]) if uid > 0 && sid > 0 { return uid, sid, nil } } } return 0, 0, fmt.Errorf("user info not found: expected x-user-id and x-star-id in metadata") } // readInt64FromMD 从 metadata.MD 中按 key 读取首个 int64 值(gRPC metadata value 均为字符串切片)。 func readInt64FromMD(md metadata.MD, key string) (int64, bool) { vals := md.Get(key) if len(vals) == 0 { return 0, false } n, err := strconv.ParseInt(vals[0], 10, 64) if err != nil { return 0, false } return n, true } // parseIntValue 把任意类型(int / int64 / float64 / string)转 int64。 func parseIntValue(v interface{}) int64 { switch val := v.(type) { case int64: return val case int: return int64(val) case float64: return int64(val) case string: if i, err := strconv.ParseInt(val, 10, 64); err == nil { return i } case []string: if len(val) > 0 { if i, err := strconv.ParseInt(val[0], 10, 64); err == nil { return i } } } return 0 } // avoid unused import warnings (pbCommon may not be referenced directly here but reserved for future use) var _ = pbCommon.BaseResponse{}