topfans/backend/gateway/middleware/grpc_status_interceptor_test.go
2026-06-15 16:28:35 +08:00

147 lines
4.1 KiB
Go

package middleware
import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)
// init switches Gin into test mode once for the whole package, before any
// test in this file builds an engine with gin.New().
func init() {
	mode := gin.TestMode
	gin.SetMode(mode)
}
// TestGRPCStatusInterceptor_DubboPath_OK simulates a Dubbo pass-through in
// which gRPC code 7 (PermissionDenied) was turned into HTTP 403 by the Dubbo
// framework. The interceptor must strip that HTTP 403 and rewrite the status
// to HTTP 200, while leaving the JSON body byte-for-byte intact.
func TestGRPCStatusInterceptor_DubboPath_OK(t *testing.T) {
	const wantBody = `{"code":7,"message":"账号已被封禁"}`

	engine := gin.New()
	engine.Use(GRPCStatusInterceptor())
	engine.GET("/api/v1/foo", func(c *gin.Context) {
		c.JSON(http.StatusForbidden, gin.H{"code": 7, "message": "账号已被封禁"})
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/foo", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("status = %d, want %d (200)", rec.Code, http.StatusOK)
	}
	// Only the status is rewritten; the body must pass through unchanged.
	if got := rec.Body.String(); got != wantBody {
		t.Errorf("body = %s, want %s", got, wantBody)
	}
}
// TestGRPCStatusInterceptor_DubboPath_Unauthenticated covers gRPC code 16
// (Unauthenticated), which Dubbo surfaces as HTTP 401. The interceptor is
// expected to rewrite that status to HTTP 200.
func TestGRPCStatusInterceptor_DubboPath_Unauthenticated(t *testing.T) {
	engine := gin.New()
	engine.Use(GRPCStatusInterceptor())
	engine.GET("/api/v1/bar", func(c *gin.Context) {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 16, "message": "token 过期"})
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/bar", nil))

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("status = %d, want %d (200)", got, http.StatusOK)
	}
}
// TestGRPCStatusInterceptor_NonDubboPath_Skip checks that paths outside
// /api/* (health checks, swagger, ...) bypass the interceptor entirely.
//
// The handler deliberately answers with a non-200 status. The interceptor's
// rewrite target is 200, so a handler that already returned 200 would make
// "skipped" and "rewritten" indistinguishable, and the test could never fail.
func TestGRPCStatusInterceptor_NonDubboPath_Skip(t *testing.T) {
	const wantBody = `{"status":"down"}`

	r := gin.New()
	r.Use(GRPCStatusInterceptor())
	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusServiceUnavailable, gin.H{"status": "down"})
	})

	req := httptest.NewRequest(http.MethodGet, "/health", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// A skipped path keeps whatever status the handler wrote.
	if w.Code != http.StatusServiceUnavailable {
		t.Errorf("status = %d, want %d (non-/api path must not be rewritten)", w.Code, http.StatusServiceUnavailable)
	}
	if got := w.Body.String(); got != wantBody {
		t.Errorf("body = %s, want %s", got, wantBody)
	}
}
// TestGRPCStatusInterceptor_BodyPreserved verifies that rewriting the status
// of an HTTP 500 response carrying gRPC code 13 leaves the JSON body intact:
// the status becomes 200, the body is non-empty, and it still contains the
// original code.
func TestGRPCStatusInterceptor_BodyPreserved(t *testing.T) {
	engine := gin.New()
	engine.Use(GRPCStatusInterceptor())
	engine.POST("/api/v1/test", func(c *gin.Context) {
		payload := gin.H{
			"code":    13,
			"message": "internal error",
			"data":    gin.H{"foo": "bar"},
		}
		c.JSON(http.StatusInternalServerError, payload)
	})

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/v1/test", nil))

	if rec.Code != http.StatusOK {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
	}
	if rec.Body.Len() == 0 {
		t.Error("body is empty, expected to be preserved")
	}
	if body := rec.Body.String(); !contains(body, `"code":13`) {
		t.Errorf("body doesn't contain expected code=13: %s", body)
	}
}
// TestGRPCStatusInterceptor_AuthMiddleware401_Kept simulates the 401 that
// AuthMiddleware writes for a transport-level failure (missing access token),
// via AbortWithStatusJSON. The interceptor is expected to leave that HTTP 401
// in place instead of rewriting it to 200.
func TestGRPCStatusInterceptor_AuthMiddleware401_Kept(t *testing.T) {
	engine := gin.New()
	engine.Use(GRPCStatusInterceptor())

	// Stand-in for AuthMiddleware: abort the chain with a 401.
	abortUnauthorized := func(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
			"code":    16,
			"message": "未携带访问令牌",
		})
	}
	engine.GET("/api/v1/protected", abortUnauthorized)

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1/protected", nil))

	// Key check: a transport error's HTTP 401 must not be rewritten by the
	// business-level interceptor.
	if got, want := rec.Code, http.StatusUnauthorized; got != want {
		t.Errorf("status = %d, want %d (AuthMiddleware 401 should be preserved)", got, want)
	}
}
// TestGRPCStatusInterceptor_NotFound_NoInterception requests an /api/* route
// that was never registered. It checks that serving does not panic and that
// some response body is actually produced. The exact status is left to the
// E2E tests, which cover the real 404 transport behavior.
//
// NOTE(review): this test previously claimed Gin skips the middleware chain
// for unmatched routes. In Gin, Engine.Use also rebuilds the no-route handler
// chain, so GRPCStatusInterceptor most likely DOES run here before Gin writes
// its default 404 body. Confirm against the Gin version pinned in go.mod.
func TestGRPCStatusInterceptor_NotFound_NoInterception(t *testing.T) {
	r := gin.New()
	r.Use(GRPCStatusInterceptor())
	// /api/v1/notexist is intentionally not registered.

	req := httptest.NewRequest(http.MethodGet, "/api/v1/notexist", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	// Record what actually reached the client to help when diagnosing failures.
	t.Logf("unregistered route: status=%d body=%q", w.Code, w.Body.String())

	// The response must not be empty; an empty body here would mean the
	// default not-found response was swallowed.
	if w.Body.Len() == 0 {
		t.Errorf("body is empty for unregistered route (status = %d), expected a not-found response body", w.Code)
	}
}
// helpers
// contains reports whether sub occurs anywhere within s. An empty sub is
// always contained.
//
// It delegates to strings.Contains rather than a hand-rolled recursive
// search, which recursed once per byte of s. The wrapper is kept so existing
// call sites in this file stay unchanged.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}