201 lines
5.6 KiB
Go
201 lines
5.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"dubbo.apache.org/dubbo-go/v3/common/constant"
|
||
"github.com/topfans/backend/pkg/jwt"
|
||
"github.com/topfans/backend/pkg/logger"
|
||
"go.uber.org/zap"
|
||
"google.golang.org/grpc/metadata"
|
||
)
|
||
|
||
// contextKey is the unexported key type for values this package stores in a
// context.Context. Using a distinct named type keeps these keys from colliding
// with plain-string keys set by other packages.
type contextKey string

// Context keys under which ValidateTokenAndExtractClaims stores the caller's
// identity taken from the token claims.
const (
	// UserIDKey holds the authenticated user's ID as an int64.
	UserIDKey contextKey = "user_id"
	// StarIDKey holds the star ID from the token claims as an int64.
	StarIDKey contextKey = "star_id"
)
|
||
|
||
// min returns the smaller of a and b.
//
// On Go 1.21+ this package-level declaration shadows the built-in min for
// every file in the package.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
||
|
||
// extractTokenFromMetadata 从 gRPC metadata 或 Dubbo attachments 中提取 Token
|
||
// 支持两种格式:
|
||
// 1. Authorization: Bearer <token>
|
||
// 2. authorization: <token>
|
||
func extractTokenFromMetadata(ctx context.Context) (string, error) {
|
||
// 方案1: 尝试从 Dubbo attachments 中获取
|
||
if attachments := ctx.Value(constant.AttachmentKey); attachments != nil {
|
||
if attMap, ok := attachments.(map[string]interface{}); ok {
|
||
logger.Logger.Debug("Found Dubbo attachments",
|
||
zap.Any("attachments", attMap),
|
||
)
|
||
|
||
// 尝试获取 authorization(可能是字符串或字符串数组)
|
||
if authValue, ok := attMap["authorization"]; ok && authValue != nil {
|
||
var token string
|
||
|
||
// 处理字符串类型
|
||
if tokenStr, ok := authValue.(string); ok && tokenStr != "" {
|
||
token = tokenStr
|
||
} else if tokenArr, ok := authValue.([]string); ok && len(tokenArr) > 0 {
|
||
// 处理字符串数组类型
|
||
token = tokenArr[0]
|
||
} else if tokenArrInterface, ok := authValue.([]interface{}); ok && len(tokenArrInterface) > 0 {
|
||
// 处理 []interface{} 类型
|
||
if tokenStr, ok := tokenArrInterface[0].(string); ok {
|
||
token = tokenStr
|
||
}
|
||
}
|
||
|
||
if token != "" {
|
||
logger.Logger.Debug("Found token in Dubbo attachments (authorization)",
|
||
zap.String("token_prefix", token[:min(20, len(token))]+"..."),
|
||
)
|
||
// 支持 "Bearer <token>" 格式
|
||
if strings.HasPrefix(token, "Bearer ") {
|
||
return strings.TrimPrefix(token, "Bearer "), nil
|
||
}
|
||
return token, nil
|
||
}
|
||
}
|
||
|
||
// 尝试获取 Authorization(大写)
|
||
if authValue, ok := attMap["Authorization"]; ok && authValue != nil {
|
||
var token string
|
||
|
||
if tokenStr, ok := authValue.(string); ok && tokenStr != "" {
|
||
token = tokenStr
|
||
} else if tokenArr, ok := authValue.([]string); ok && len(tokenArr) > 0 {
|
||
token = tokenArr[0]
|
||
} else if tokenArrInterface, ok := authValue.([]interface{}); ok && len(tokenArrInterface) > 0 {
|
||
if tokenStr, ok := tokenArrInterface[0].(string); ok {
|
||
token = tokenStr
|
||
}
|
||
}
|
||
|
||
if token != "" {
|
||
logger.Logger.Debug("Found token in Dubbo attachments (Authorization)",
|
||
zap.String("token_prefix", token[:min(20, len(token))]+"..."),
|
||
)
|
||
if strings.HasPrefix(token, "Bearer ") {
|
||
return strings.TrimPrefix(token, "Bearer "), nil
|
||
}
|
||
return token, nil
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 方案2: 尝试从 gRPC metadata 中获取
|
||
md, ok := metadata.FromIncomingContext(ctx)
|
||
if ok {
|
||
logger.Logger.Debug("Found gRPC metadata",
|
||
zap.Any("keys", md),
|
||
)
|
||
|
||
// 尝试从 "authorization" 字段获取
|
||
authHeaders := md.Get("authorization")
|
||
if len(authHeaders) == 0 {
|
||
// 尝试从 "Authorization" 字段获取(大写)
|
||
authHeaders = md.Get("Authorization")
|
||
}
|
||
|
||
if len(authHeaders) > 0 {
|
||
authHeader := authHeaders[0]
|
||
logger.Logger.Debug("Found token in gRPC metadata",
|
||
zap.String("token_prefix", authHeader[:20]+"..."),
|
||
)
|
||
|
||
// 支持 "Bearer <token>" 格式
|
||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||
return strings.TrimPrefix(authHeader, "Bearer "), nil
|
||
}
|
||
|
||
// 直接返回 token
|
||
return authHeader, nil
|
||
}
|
||
}
|
||
|
||
logger.Logger.Warn("Token not found in context",
|
||
zap.Bool("has_metadata", ok),
|
||
zap.Any("attachments", ctx.Value(constant.AttachmentKey)),
|
||
)
|
||
return "", fmt.Errorf("authorization token not found in context")
|
||
}
|
||
|
||
// ExtractUserIDFromContext 从 Context 中提取 user_id
|
||
func ExtractUserIDFromContext(ctx context.Context) (int64, error) {
|
||
userID, ok := ctx.Value(UserIDKey).(int64)
|
||
if !ok {
|
||
return 0, fmt.Errorf("user_id not found in context")
|
||
}
|
||
return userID, nil
|
||
}
|
||
|
||
// ExtractStarIDFromContext 从 Context 中提取 star_id
|
||
func ExtractStarIDFromContext(ctx context.Context) (int64, error) {
|
||
starID, ok := ctx.Value(StarIDKey).(int64)
|
||
if !ok {
|
||
return 0, fmt.Errorf("star_id not found in context")
|
||
}
|
||
return starID, nil
|
||
}
|
||
|
||
// ExtractUserInfoFromContext 从 Context 中提取 user_id 和 star_id
|
||
func ExtractUserInfoFromContext(ctx context.Context) (int64, int64, error) {
|
||
userID, err := ExtractUserIDFromContext(ctx)
|
||
if err != nil {
|
||
return 0, 0, err
|
||
}
|
||
|
||
starID, err := ExtractStarIDFromContext(ctx)
|
||
if err != nil {
|
||
return 0, 0, err
|
||
}
|
||
|
||
return userID, starID, nil
|
||
}
|
||
|
||
// ValidateTokenAndExtractClaims 验证 Token 并提取用户信息到 Context
|
||
// 这是一个辅助函数,用于在 Provider 方法中手动验证 Token
|
||
func ValidateTokenAndExtractClaims(ctx context.Context) (context.Context, error) {
|
||
// 从 metadata 中提取 Token
|
||
token, err := extractTokenFromMetadata(ctx)
|
||
if err != nil {
|
||
logger.Logger.Warn("Failed to extract token from metadata",
|
||
zap.Error(err),
|
||
)
|
||
return ctx, fmt.Errorf("missing or invalid authorization token")
|
||
}
|
||
|
||
// 验证 Token 并提取用户信息
|
||
claims, err := jwt.ParseToken(token)
|
||
if err != nil {
|
||
logger.Logger.Warn("Failed to parse token",
|
||
zap.Error(err),
|
||
)
|
||
return ctx, fmt.Errorf("invalid token")
|
||
}
|
||
|
||
// 将用户信息注入到 Context
|
||
ctx = context.WithValue(ctx, UserIDKey, claims.UserID)
|
||
ctx = context.WithValue(ctx, StarIDKey, claims.StarID)
|
||
|
||
logger.Logger.Debug("Token validated successfully",
|
||
zap.Int64("user_id", claims.UserID),
|
||
zap.Int64("star_id", claims.StarID),
|
||
)
|
||
|
||
return ctx, nil
|
||
}
|