topfans/backend/services/userService/middleware/auth_interceptor.go
2026-04-07 22:29:48 +08:00

201 lines
5.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"fmt"
"strings"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"github.com/topfans/backend/pkg/jwt"
"github.com/topfans/backend/pkg/logger"
"go.uber.org/zap"
"google.golang.org/grpc/metadata"
)
// contextKey is the unexported type used for values this package stores in a
// context.Context. Using a distinct named type keeps these keys from colliding
// with plain-string keys set by other packages.
type contextKey string

// Context keys under which ValidateTokenAndExtractClaims stores the
// authenticated identity; both values are stored as int64.
const (
	// StarIDKey holds the caller's star ID.
	StarIDKey contextKey = "star_id"
	// UserIDKey holds the caller's user ID.
	UserIDKey contextKey = "user_id"
)
// min returns the smaller of a and b.
//
// NOTE: this package-level declaration shadows the Go 1.21+ builtin min for
// the whole package; it is kept so the package also builds on older toolchains.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// extractTokenFromMetadata extracts the caller's token from Dubbo attachments
// or, failing that, from incoming gRPC metadata.
//
// Lookup order:
//  1. Dubbo attachments (ctx.Value(constant.AttachmentKey) as a
//     map[string]interface{}): key "authorization", then "Authorization".
//     A value may be a string, a []string or a []interface{}; for slices the
//     first element is used.
//  2. gRPC incoming metadata, key "authorization". metadata.MD.Get lower-cases
//     its argument, so this lookup also covers "Authorization".
//
// Accepted formats:
//  1. Authorization: Bearer <token>  (scheme name matched case-insensitively)
//  2. authorization: <token>
//
// A value that is empty, or empty once the "Bearer " prefix is removed, is
// treated as absent and the search continues. An error is returned when no
// usable token is found.
//
// Only attachment/metadata keys and a short token prefix are logged; full
// values are never written to the log because they contain the credential.
func extractTokenFromMetadata(ctx context.Context) (string, error) {
	// Source 1: Dubbo attachments.
	var attachmentKeys []string
	if attMap, ok := ctx.Value(constant.AttachmentKey).(map[string]interface{}); ok {
		attachmentKeys = make([]string, 0, len(attMap))
		for k := range attMap {
			attachmentKeys = append(attachmentKeys, k)
		}
		logger.Logger.Debug("Found Dubbo attachments",
			zap.Strings("attachment_keys", attachmentKeys),
		)

		// Try the lower-case key first, then the canonical-case one.
		for _, key := range []string{"authorization", "Authorization"} {
			raw := authAttachmentString(attMap[key])
			if token := trimBearerPrefix(raw); token != "" {
				logger.Logger.Debug("Found token in Dubbo attachments ("+key+")",
					zap.String("token_prefix", tokenLogPrefix(raw)),
				)
				return token, nil
			}
		}
	}

	// Source 2: gRPC incoming metadata.
	md, hasMetadata := metadata.FromIncomingContext(ctx)
	if hasMetadata {
		metadataKeys := make([]string, 0, len(md))
		for k := range md {
			metadataKeys = append(metadataKeys, k)
		}
		logger.Logger.Debug("Found gRPC metadata",
			zap.Strings("keys", metadataKeys),
		)

		// md.Get lower-cases the key, so a separate "Authorization" lookup
		// would be redundant.
		if values := md.Get("authorization"); len(values) > 0 {
			raw := values[0]
			if token := trimBearerPrefix(raw); token != "" {
				logger.Logger.Debug("Found token in gRPC metadata",
					zap.String("token_prefix", tokenLogPrefix(raw)),
				)
				return token, nil
			}
		}
	}

	logger.Logger.Warn("Token not found in context",
		zap.Bool("has_metadata", hasMetadata),
		zap.Strings("attachment_keys", attachmentKeys),
	)
	return "", fmt.Errorf("authorization token not found in context")
}

// authAttachmentString normalizes a Dubbo attachment value to a string.
// Supported shapes are string, []string, and []interface{} whose first
// element is a string; for slices only the first element is used. Any other
// shape (including nil) yields "".
func authAttachmentString(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case []string:
		if len(val) > 0 {
			return val[0]
		}
	case []interface{}:
		if len(val) > 0 {
			if s, ok := val[0].(string); ok {
				return s
			}
		}
	}
	return ""
}

// trimBearerPrefix removes a leading "Bearer " scheme from v. The scheme name
// is compared case-insensitively (auth schemes are case-insensitive per
// RFC 7235); values without the prefix are returned unchanged so bare tokens
// are accepted as well.
func trimBearerPrefix(v string) string {
	const scheme = "Bearer "
	if len(v) >= len(scheme) && strings.EqualFold(v[:len(scheme)], scheme) {
		return v[len(scheme):]
	}
	return v
}

// tokenLogPrefix returns at most the first 20 bytes of v followed by "...",
// for debug logging. Bounding the slice avoids the out-of-range panic that a
// fixed v[:20] causes on values shorter than 20 bytes.
func tokenLogPrefix(v string) string {
	const maxLen = 20
	if len(v) > maxLen {
		v = v[:maxLen]
	}
	return v + "..."
}
// ExtractUserIDFromContext returns the user ID stored in ctx under UserIDKey
// (as done by ValidateTokenAndExtractClaims). It returns 0 and an error when
// the key is absent or the stored value is not an int64.
func ExtractUserIDFromContext(ctx context.Context) (int64, error) {
	if id, found := ctx.Value(UserIDKey).(int64); found {
		return id, nil
	}
	return 0, fmt.Errorf("user_id not found in context")
}
// ExtractStarIDFromContext returns the star ID stored in ctx under StarIDKey
// (as done by ValidateTokenAndExtractClaims). It returns 0 and an error when
// the key is absent or the stored value is not an int64.
func ExtractStarIDFromContext(ctx context.Context) (int64, error) {
	starID, found := ctx.Value(StarIDKey).(int64)
	if found {
		return starID, nil
	}
	return 0, fmt.Errorf("star_id not found in context")
}
// ExtractUserInfoFromContext returns both the user ID and the star ID stored
// in ctx. The user ID is looked up first; the first lookup that fails
// determines the returned error, and both IDs are 0 in that case.
func ExtractUserInfoFromContext(ctx context.Context) (int64, int64, error) {
	var (
		userID, starID int64
		err            error
	)
	if userID, err = ExtractUserIDFromContext(ctx); err != nil {
		return 0, 0, err
	}
	if starID, err = ExtractStarIDFromContext(ctx); err != nil {
		return 0, 0, err
	}
	return userID, starID, nil
}
// ValidateTokenAndExtractClaims authenticates the request carried by ctx.
//
// It is a helper for Provider methods that validate the token manually: the
// token is read with extractTokenFromMetadata and parsed with jwt.ParseToken.
// The resulting user ID and star ID are then stored under UserIDKey and
// StarIDKey in a derived context.
//
// On success it returns the derived context. On failure it returns the
// unmodified ctx together with a generic error; the underlying cause is only
// logged, not returned to the caller.
func ValidateTokenAndExtractClaims(ctx context.Context) (context.Context, error) {
	token, extractErr := extractTokenFromMetadata(ctx)
	if extractErr != nil {
		logger.Logger.Warn("Failed to extract token from metadata", zap.Error(extractErr))
		return ctx, fmt.Errorf("missing or invalid authorization token")
	}

	claims, parseErr := jwt.ParseToken(token)
	if parseErr != nil {
		logger.Logger.Warn("Failed to parse token", zap.Error(parseErr))
		return ctx, fmt.Errorf("invalid token")
	}

	// Inject the authenticated identity into a derived context.
	userID, starID := claims.UserID, claims.StarID
	authed := context.WithValue(ctx, UserIDKey, userID)
	authed = context.WithValue(authed, StarIDKey, starID)

	logger.Logger.Debug("Token validated successfully",
		zap.Int64("user_id", userID),
		zap.Int64("star_id", starID),
	)
	return authed, nil
}