package middleware

import (
	"context"
	"fmt"
	"strings"

	"dubbo.apache.org/dubbo-go/v3/common/constant"
	"github.com/topfans/backend/pkg/jwt"
	"github.com/topfans/backend/pkg/logger"
	"go.uber.org/zap"
	"google.golang.org/grpc/metadata"
)

// contextKey is the private type used for the context keys defined in this
// file, so they cannot collide with plain string keys set by other packages.
type contextKey string

const (
	// UserIDKey is the context key under which the authenticated user ID
	// (int64) is stored by ValidateTokenAndExtractClaims.
	UserIDKey contextKey = "user_id"
	// StarIDKey is the context key under which the star ID (int64) is stored
	// by ValidateTokenAndExtractClaims.
	StarIDKey contextKey = "star_id"
)

// min returns the smaller of two ints.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// tokenLogPrefix returns at most the first 20 bytes of token followed by
// "...", so a log line never carries the complete credential. The length is
// clamped, so tokens shorter than 20 bytes (including "") are safe.
func tokenLogPrefix(token string) string {
	return token[:min(20, len(token))] + "..."
}

// attachmentToken converts a Dubbo attachment value to a token string.
// It accepts a string, a []string, or a []interface{} whose first element is
// a string; any other shape (or an empty slice) yields "".
func attachmentToken(value interface{}) string {
	switch v := value.(type) {
	case string:
		return v
	case []string:
		if len(v) > 0 {
			return v[0]
		}
	case []interface{}:
		if len(v) > 0 {
			if s, ok := v[0].(string); ok {
				return s
			}
		}
	}
	return ""
}

// extractTokenFromMetadata pulls the authorization token out of ctx.
//
// Lookup order:
//  1. Dubbo attachments (ctx value under constant.AttachmentKey), keys
//     "authorization" then "Authorization".
//  2. Incoming gRPC metadata, key "authorization".
//
// A leading "Bearer " prefix is stripped if present; otherwise the value is
// returned as-is. An error is returned when no non-empty value is found.
func extractTokenFromMetadata(ctx context.Context) (string, error) {
	// Option 1: Dubbo attachments.
	if attMap, ok := ctx.Value(constant.AttachmentKey).(map[string]interface{}); ok {
		// NOTE(review): this dumps the whole attachment map, which includes the
		// raw authorization value. Consider logging only the keys.
		logger.Logger.Debug("Found Dubbo attachments",
			zap.Any("attachments", attMap),
		)

		// Attachment keys are looked up with both spellings.
		for _, key := range []string{"authorization", "Authorization"} {
			token := attachmentToken(attMap[key])
			if token == "" {
				continue
			}
			logger.Logger.Debug("Found token in Dubbo attachments ("+key+")",
				zap.String("token_prefix", tokenLogPrefix(token)),
			)
			return strings.TrimPrefix(token, "Bearer "), nil
		}
	}

	// Option 2: gRPC metadata.
	md, ok := metadata.FromIncomingContext(ctx)
	if ok {
		// NOTE(review): despite the field name "keys", this logs the full
		// metadata including the raw authorization value.
		logger.Logger.Debug("Found gRPC metadata",
			zap.Any("keys", md),
		)

		// metadata.MD.Get lowercases the key before the lookup, so a separate
		// "Authorization" lookup is unnecessary.
		authHeaders := md.Get("authorization")
		if len(authHeaders) > 0 && authHeaders[0] != "" {
			authHeader := authHeaders[0]
			// The prefix is length-clamped: slicing [:20] unconditionally
			// panics on headers shorter than 20 bytes.
			logger.Logger.Debug("Found token in gRPC metadata",
				zap.String("token_prefix", tokenLogPrefix(authHeader)),
			)
			return strings.TrimPrefix(authHeader, "Bearer "), nil
		}
	}

	logger.Logger.Warn("Token not found in context",
		zap.Bool("has_metadata", ok),
		zap.Any("attachments", ctx.Value(constant.AttachmentKey)),
	)
	return "", fmt.Errorf("authorization token not found in context")
}

// ExtractUserIDFromContext returns the int64 stored in ctx under UserIDKey.
// It returns an error when the value is absent or is not an int64.
func ExtractUserIDFromContext(ctx context.Context) (int64, error) {
	userID, ok := ctx.Value(UserIDKey).(int64)
	if !ok {
		return 0, fmt.Errorf("user_id not found in context")
	}
	return userID, nil
}

// ExtractStarIDFromContext returns the int64 stored in ctx under StarIDKey.
// It returns an error when the value is absent or is not an int64.
func ExtractStarIDFromContext(ctx context.Context) (int64, error) {
	starID, ok := ctx.Value(StarIDKey).(int64)
	if !ok {
		return 0, fmt.Errorf("star_id not found in context")
	}
	return starID, nil
}

// ExtractUserInfoFromContext returns the user ID and star ID stored in ctx,
// in that order. It fails with the first lookup error encountered.
func ExtractUserInfoFromContext(ctx context.Context) (int64, int64, error) {
	userID, err := ExtractUserIDFromContext(ctx)
	if err != nil {
		return 0, 0, err
	}
	starID, err := ExtractStarIDFromContext(ctx)
	if err != nil {
		return 0, 0, err
	}
	return userID, starID, nil
}

// ValidateTokenAndExtractClaims extracts the token from ctx, parses it with
// jwt.ParseToken, and returns a derived context carrying the claims' UserID
// and StarID under UserIDKey and StarIDKey.
//
// It is a helper for validating a token manually inside a provider method.
// On failure the original ctx is returned together with a generic error; the
// underlying cause is only logged, not returned to the caller.
func ValidateTokenAndExtractClaims(ctx context.Context) (context.Context, error) {
	token, err := extractTokenFromMetadata(ctx)
	if err != nil {
		logger.Logger.Warn("Failed to extract token from metadata",
			zap.Error(err),
		)
		return ctx, fmt.Errorf("missing or invalid authorization token")
	}

	claims, err := jwt.ParseToken(token)
	if err != nil {
		logger.Logger.Warn("Failed to parse token",
			zap.Error(err),
		)
		return ctx, fmt.Errorf("invalid token")
	}

	// Inject the user information into the context.
	ctx = context.WithValue(ctx, UserIDKey, claims.UserID)
	ctx = context.WithValue(ctx, StarIDKey, claims.StarID)

	logger.Logger.Debug("Token validated successfully",
		zap.Int64("user_id", claims.UserID),
		zap.Int64("star_id", claims.StarID),
	)
	return ctx, nil
}