188 lines
6.9 KiB
Go
188 lines
6.9 KiB
Go
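// Package provider implements the NotificationServiceHandler interface
// generated from notification.proto.
//
// Design notes:
//   - user_id / star_id are extracted from gRPC metadata, falling back to Dubbo
//     attachments (gateway compatibility mode).
//   - Handlers only pass parameters through; all business logic is delegated to
//     service.NotificationService.
//   - Errors returned by the service are propagated to the gRPC layer and are
//     never swallowed.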
|
||
package provider
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strconv"
|
||
"time"
|
||
|
||
"dubbo.apache.org/dubbo-go/v3/common/constant"
|
||
pbCommon "github.com/topfans/backend/pkg/proto/common"
|
||
notifPb "github.com/topfans/backend/pkg/proto/notification"
|
||
"github.com/topfans/backend/services/notificationService/service"
|
||
"google.golang.org/grpc/metadata"
|
||
)
|
||
|
||
// NotificationProvider notification 服务的 RPC Provider。
|
||
type NotificationProvider struct {
|
||
svc *service.NotificationService
|
||
}
|
||
|
||
// 编译期断言:NotificationProvider 实现了 notifPb.NotificationServiceHandler 接口(triple 生成)。
|
||
var _ notifPb.NotificationServiceHandler = (*NotificationProvider)(nil)
|
||
|
||
// NewNotificationProvider 创建 NotificationProvider。
|
||
func NewNotificationProvider(svc *service.NotificationService) *NotificationProvider {
|
||
return &NotificationProvider{svc: svc}
|
||
}
|
||
|
||
// ========== 8 个 RPC 方法 ==========
|
||
|
||
// CreateNotification 创建通知(无 user_id/star_id 从 metadata 取,由 service 校验)。
|
||
func (p *NotificationProvider) CreateNotification(ctx context.Context, req *notifPb.CreateNotificationRequest) (*notifPb.CreateNotificationResponse, error) {
|
||
// 注意:CreateNotification 通常由 RPC 内部触发(social service 调用),不强制从 metadata 取 user_id。
|
||
// 但仍优先用 metadata(如有)覆盖 req 中的字段,保证网关侧传过来的身份与请求一致。
|
||
if uid, sid, err := extractUserInfo(ctx); err == nil && uid > 0 {
|
||
if req.UserId <= 0 {
|
||
req.UserId = uid
|
||
}
|
||
if req.StarId <= 0 {
|
||
req.StarId = sid
|
||
}
|
||
}
|
||
return p.svc.CreateNotification(ctx, req)
|
||
}
|
||
|
||
// GetNotifications 拉取通知列表(type=like 时聚合)。
|
||
func (p *NotificationProvider) GetNotifications(ctx context.Context, req *notifPb.GetNotificationsRequest) (*notifPb.GetNotificationsResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
return p.svc.GetNotifications(ctx, userID, starID, req.Type, req.Tab, req.Page, req.PageSize)
|
||
}
|
||
|
||
// GetUnreadCount 获取未读计数。
|
||
func (p *NotificationProvider) GetUnreadCount(ctx context.Context, req *notifPb.GetUnreadCountRequest) (*notifPb.GetUnreadCountResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
return p.svc.GetUnreadCount(ctx, userID, starID)
|
||
}
|
||
|
||
// MarkAsRead 单条标已读。
|
||
func (p *NotificationProvider) MarkAsRead(ctx context.Context, req *notifPb.MarkAsReadRequest) (*notifPb.MarkAsReadResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
now := time.Now().UnixMilli()
|
||
return p.svc.MarkAsRead(ctx, userID, starID, req.Id, now)
|
||
}
|
||
|
||
// MarkAsReadByTarget 将某个 target 下所有 like 标已读。
|
||
func (p *NotificationProvider) MarkAsReadByTarget(ctx context.Context, req *notifPb.MarkAsReadByTargetRequest) (*notifPb.MarkAsReadByTargetResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
now := time.Now().UnixMilli()
|
||
return p.svc.MarkAsReadByTarget(ctx, userID, starID, req.TargetId, now)
|
||
}
|
||
|
||
// MarkAllAsRead 全部已读(按 type 过滤)。
|
||
func (p *NotificationProvider) MarkAllAsRead(ctx context.Context, req *notifPb.MarkAllAsReadRequest) (*notifPb.MarkAllAsReadResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
now := time.Now().UnixMilli()
|
||
return p.svc.MarkAllAsRead(ctx, userID, starID, req.Type, now)
|
||
}
|
||
|
||
// DeleteNotification 软删单条。
|
||
func (p *NotificationProvider) DeleteNotification(ctx context.Context, req *notifPb.DeleteNotificationRequest) (*notifPb.DeleteNotificationResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
now := time.Now().UnixMilli()
|
||
return p.svc.DeleteNotification(ctx, userID, starID, req.Id, now)
|
||
}
|
||
|
||
// DeleteByTarget 软删某个 target 下所有 like。
|
||
func (p *NotificationProvider) DeleteByTarget(ctx context.Context, req *notifPb.DeleteByTargetRequest) (*notifPb.DeleteByTargetResponse, error) {
|
||
userID, starID, err := extractUserInfo(ctx)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("extract user info: %w", err)
|
||
}
|
||
now := time.Now().UnixMilli()
|
||
return p.svc.DeleteByTarget(ctx, userID, starID, req.TargetId, now)
|
||
}
|
||
|
||
// ========== 辅助方法 ==========
|
||
|
||
// extractUserInfo 从 gRPC metadata 提取 user_id 和 star_id,fallback 到 Dubbo attachments。
|
||
//
|
||
// gateway 在调用 notification service 时,会通过 gRPC metadata(HTTP 层是 HTTP header)传递
|
||
// x-user-id / x-star-id。Dubbo Triple 协议会把 metadata 放进 metadata.MD;
|
||
// 同时也兼容 Dubbo attachments(Dubbo 老链路)。
|
||
func extractUserInfo(ctx context.Context) (int64, int64, error) {
|
||
// 优先从 gRPC metadata 取(Tripe 协议会把 HTTP header 转成 metadata.MD)
|
||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||
if uid, ok := readInt64FromMD(md, "x-user-id"); ok && uid > 0 {
|
||
sid, _ := readInt64FromMD(md, "x-star-id")
|
||
return uid, sid, nil
|
||
}
|
||
}
|
||
|
||
// fallback:Dubbo attachments(constant.AttachmentKey)
|
||
if attachments := ctx.Value(constant.AttachmentKey); attachments != nil {
|
||
if attMap, ok := attachments.(map[string]interface{}); ok {
|
||
uid := parseIntValue(attMap["user_id"])
|
||
sid := parseIntValue(attMap["star_id"])
|
||
if uid > 0 && sid > 0 {
|
||
return uid, sid, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
return 0, 0, fmt.Errorf("user info not found: expected x-user-id and x-star-id in metadata")
|
||
}
|
||
|
||
// readInt64FromMD 从 metadata.MD 中按 key 读取首个 int64 值(gRPC metadata value 均为字符串切片)。
|
||
func readInt64FromMD(md metadata.MD, key string) (int64, bool) {
|
||
vals := md.Get(key)
|
||
if len(vals) == 0 {
|
||
return 0, false
|
||
}
|
||
n, err := strconv.ParseInt(vals[0], 10, 64)
|
||
if err != nil {
|
||
return 0, false
|
||
}
|
||
return n, true
|
||
}
|
||
|
||
// parseIntValue converts a loosely typed value into an int64.
//
// Supported inputs:
//   - int, int32, int64, uint32: converted directly.
//   - uint64: converted when it fits in int64.
//   - float64: truncated toward zero when it lies within the int64 range.
//   - string: parsed as a base-10 integer.
//   - []string: the first element is parsed as a base-10 integer.
//
// It returns 0 for every other type (including nil) and for any value that
// cannot be represented: unparsable text, an empty slice, NaN, infinities, or
// numbers outside the int64 range.
func parseIntValue(v interface{}) int64 {
	switch val := v.(type) {
	case int64:
		return val
	case int:
		return int64(val)
	case int32:
		return int64(val)
	case uint32:
		return int64(val)
	case uint64:
		// Values above MaxInt64 would wrap to a negative number.
		if val <= 1<<63-1 {
			return int64(val)
		}
	case float64:
		// Converting a float that int64 cannot represent yields an
		// implementation-dependent result, so such values are rejected.
		// NaN fails both comparisons and is rejected as well.
		const limit = float64(1 << 63)
		if val >= -limit && val < limit {
			return int64(val)
		}
	case string:
		if n, err := strconv.ParseInt(val, 10, 64); err == nil {
			return n
		}
	case []string:
		if len(val) > 0 {
			if n, err := strconv.ParseInt(val[0], 10, 64); err == nil {
				return n
			}
		}
	}
	return 0
}
|
||
|
||
// avoid unused import warnings (pbCommon may not be referenced directly here but reserved for future use)
|
||
var _ = pbCommon.BaseResponse{}
|