package test

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"yinli-api/internal/middleware"
	"yinli-api/pkg/auth"
	"yinli-api/pkg/config"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// setupMiddlewareTestRouter switches gin to test mode, loads the "dev"
// configuration, and returns an engine created with gin.New() so each test
// can attach only the middleware it wants to exercise.
func setupMiddlewareTestRouter() *gin.Engine {
	gin.SetMode(gin.TestMode)
	config.LoadConfig("dev")
	r := gin.New()
	return r
}

// TestAuthMiddleware checks the JWT auth middleware: a request carrying a
// freshly generated bearer token must reach the handler (200), while a missing
// header, a malformed header, and an invalid token must all be rejected (401).
func TestAuthMiddleware(t *testing.T) {
	router := setupMiddlewareTestRouter()

	// Fail fast if the token cannot be generated; otherwise the "valid token"
	// case would fail with a misleading 401.
	token, err := auth.GenerateToken(1, "testuser")
	if err != nil {
		t.Fatalf("generating test token: %v", err)
	}

	router.GET("/protected", middleware.AuthMiddleware(), func(c *gin.Context) {
		userID, _ := middleware.GetUserID(c)
		username, _ := middleware.GetUsername(c)
		c.JSON(http.StatusOK, gin.H{
			"user_id":  userID,
			"username": username,
		})
	})

	tests := []struct {
		name           string
		authHeader     string
		expectedStatus int
	}{
		{
			name:           "有效token",
			authHeader:     "Bearer " + token,
			expectedStatus: http.StatusOK,
		},
		{
			name:           "无授权头",
			authHeader:     "",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "无效token格式",
			authHeader:     "InvalidFormat",
			expectedStatus: http.StatusUnauthorized,
		},
		{
			name:           "无效token",
			authHeader:     "Bearer invalid-token",
			expectedStatus: http.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodGet, "/protected", nil)
			if err != nil {
				t.Fatalf("building request: %v", err)
			}
			// An empty authHeader means "send no Authorization header at all".
			if tt.authHeader != "" {
				req.Header.Set("Authorization", tt.authHeader)
			}

			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			assert.Equal(t, tt.expectedStatus, w.Code)
		})
	}
}

// TestCORSMiddleware checks that a request with an Origin header passes
// through the CORS middleware with status 200 and a non-empty
// Access-Control-Allow-Origin response header.
func TestCORSMiddleware(t *testing.T) {
	router := setupMiddlewareTestRouter()
	router.Use(middleware.CORSMiddleware())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "ok"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	req.Header.Set("Origin", "http://localhost:3000")

	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Origin"))
}

// TestRequestIDMiddleware
测试请求ID中间件 func TestRequestIDMiddleware(t *testing.T) { router := setupMiddlewareTestRouter() router.Use(middleware.RequestIDMiddleware()) router.GET("/test", func(c *gin.Context) { requestID := middleware.GetRequestID(c) c.JSON(200, gin.H{"request_id": requestID}) }) req, _ := http.NewRequest("GET", "/test", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.NotEmpty(t, w.Header().Get("X-Request-ID")) } // TestOptionalAuthMiddleware 测试可选认证中间件 func TestOptionalAuthMiddleware(t *testing.T) { router := setupMiddlewareTestRouter() token, _ := auth.GenerateToken(1, "testuser") router.GET("/optional", middleware.OptionalAuthMiddleware(), func(c *gin.Context) { userID, exists := middleware.GetUserID(c) if exists { c.JSON(200, gin.H{"authenticated": true, "user_id": userID}) } else { c.JSON(200, gin.H{"authenticated": false}) } }) tests := []struct { name string authHeader string hasUserID bool }{ { name: "有效token", authHeader: "Bearer " + token, hasUserID: true, }, { name: "无token", authHeader: "", hasUserID: false, }, { name: "无效token", authHeader: "Bearer invalid-token", hasUserID: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", "/optional", nil) if tt.authHeader != "" { req.Header.Set("Authorization", tt.authHeader) } w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) // 这里可以进一步验证响应内容 }) } }