125 lines
3.4 KiB
Go
125 lines
3.4 KiB
Go
package middleware
|
||
|
||
import (
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/topfans/backend/gateway/pkg/response"
	"github.com/topfans/backend/pkg/database"
	"github.com/topfans/backend/pkg/jwt"
	"github.com/topfans/backend/pkg/logger"
	"go.uber.org/zap"
)
|
||
|
||
// AuthMiddleware JWT 认证中间件
|
||
func AuthMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 1. 从 HTTP Header 提取 Token
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
logger.Logger.Warn("Missing authorization token",
|
||
zap.String("path", c.Request.URL.Path),
|
||
zap.String("method", c.Request.Method),
|
||
)
|
||
response.Unauthorized(c, "未携带访问令牌")
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 2. 去除 "Bearer " 前缀
|
||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||
if token == authHeader {
|
||
// 没有 Bearer 前缀,也尝试解析(兼容性)
|
||
token = authHeader
|
||
}
|
||
|
||
// 3. 验证 Token
|
||
claims, err := jwt.ParseToken(token)
|
||
if err != nil {
|
||
logger.Logger.Warn("Invalid token",
|
||
zap.String("path", c.Request.URL.Path),
|
||
zap.String("method", c.Request.Method),
|
||
zap.Error(err),
|
||
)
|
||
response.Unauthorized(c, "访问令牌无效或已过期")
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 4. 检查 Token 是否在黑名单
|
||
isBlacklisted, bannedUserID, banReason, err := database.IsBlacklisted(c.Request.Context(), token)
|
||
if err != nil {
|
||
// Redis 错误时 fail-closed(安全策略),拒绝请求
|
||
logger.Logger.Error("Failed to check blacklist, rejecting request for security",
|
||
zap.String("path", c.Request.URL.Path),
|
||
zap.Error(err),
|
||
)
|
||
response.Unauthorized(c, "认证服务异常,请稍后重试")
|
||
c.Abort()
|
||
return
|
||
}
|
||
if isBlacklisted {
|
||
logger.Logger.Warn("Token is blacklisted",
|
||
zap.Int64("banned_user_id", bannedUserID),
|
||
zap.String("ban_reason", banReason),
|
||
zap.String("path", c.Request.URL.Path),
|
||
)
|
||
response.Unauthorized(c, "账号已被封禁")
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 5. 将用户信息存入 gin.Context
|
||
c.Set("user_id", claims.UserID)
|
||
c.Set("star_id", claims.StarID)
|
||
c.Set("token_updated_at", claims.UpdatedAt)
|
||
|
||
logger.Logger.Debug("User authenticated",
|
||
zap.Int64("user_id", claims.UserID),
|
||
zap.Int64("star_id", claims.StarID),
|
||
zap.String("path", c.Request.URL.Path),
|
||
)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// CORSMiddleware CORS 中间件
|
||
func CORSMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||
|
||
if c.Request.Method == "OPTIONS" {
|
||
c.AbortWithStatus(http.StatusNoContent)
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// LoggerMiddleware 日志中间件
|
||
func LoggerMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 记录请求开始
|
||
logger.Logger.Info("Request received",
|
||
zap.String("method", c.Request.Method),
|
||
zap.String("path", c.Request.URL.Path),
|
||
zap.String("ip", c.ClientIP()),
|
||
)
|
||
|
||
c.Next()
|
||
|
||
// 记录请求结束
|
||
logger.Logger.Info("Request completed",
|
||
zap.String("method", c.Request.Method),
|
||
zap.String("path", c.Request.URL.Path),
|
||
zap.Int("status", c.Writer.Status()),
|
||
)
|
||
}
|
||
}
|