172 lines
3.9 KiB
Go
172 lines
3.9 KiB
Go
package test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"yinli-api/src/middleware"
|
|
"yinli-api/src/pkg/auth"
|
|
"yinli-api/src/pkg/config"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// setupMiddlewareTestRouter 设置中间件测试路由
|
|
func setupMiddlewareTestRouter() *gin.Engine {
|
|
gin.SetMode(gin.TestMode)
|
|
config.LoadConfig("dev")
|
|
|
|
r := gin.New()
|
|
return r
|
|
}
|
|
|
|
// TestAuthMiddleware 测试JWT认证中间件
|
|
func TestAuthMiddleware(t *testing.T) {
|
|
router := setupMiddlewareTestRouter()
|
|
|
|
// 生成测试token
|
|
token, _ := auth.GenerateToken(1, "testuser")
|
|
|
|
router.GET("/protected", middleware.AuthMiddleware(), func(c *gin.Context) {
|
|
userID, _ := middleware.GetUserID(c)
|
|
username, _ := middleware.GetUsername(c)
|
|
c.JSON(200, gin.H{
|
|
"user_id": userID,
|
|
"username": username,
|
|
})
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
authHeader string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "有效token",
|
|
authHeader: "Bearer " + token,
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "无授权头",
|
|
authHeader: "",
|
|
expectedStatus: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "无效token格式",
|
|
authHeader: "InvalidFormat",
|
|
expectedStatus: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "无效token",
|
|
authHeader: "Bearer invalid-token",
|
|
expectedStatus: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/protected", nil)
|
|
if tt.authHeader != "" {
|
|
req.Header.Set("Authorization", tt.authHeader)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCORSMiddleware 测试CORS中间件
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
router := setupMiddlewareTestRouter()
|
|
router.Use(middleware.CORSMiddleware())
|
|
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(200, gin.H{"message": "ok"})
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Origin", "http://localhost:3000")
|
|
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
// TestRequestIDMiddleware 测试请求ID中间件
|
|
func TestRequestIDMiddleware(t *testing.T) {
|
|
router := setupMiddlewareTestRouter()
|
|
router.Use(middleware.RequestIDMiddleware())
|
|
|
|
router.GET("/test", func(c *gin.Context) {
|
|
requestID := middleware.GetRequestID(c)
|
|
c.JSON(200, gin.H{"request_id": requestID})
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
|
|
}
|
|
|
|
// TestOptionalAuthMiddleware 测试可选认证中间件
|
|
func TestOptionalAuthMiddleware(t *testing.T) {
|
|
router := setupMiddlewareTestRouter()
|
|
|
|
token, _ := auth.GenerateToken(1, "testuser")
|
|
|
|
router.GET("/optional", middleware.OptionalAuthMiddleware(), func(c *gin.Context) {
|
|
userID, exists := middleware.GetUserID(c)
|
|
if exists {
|
|
c.JSON(200, gin.H{"authenticated": true, "user_id": userID})
|
|
} else {
|
|
c.JSON(200, gin.H{"authenticated": false})
|
|
}
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
authHeader string
|
|
hasUserID bool
|
|
}{
|
|
{
|
|
name: "有效token",
|
|
authHeader: "Bearer " + token,
|
|
hasUserID: true,
|
|
},
|
|
{
|
|
name: "无token",
|
|
authHeader: "",
|
|
hasUserID: false,
|
|
},
|
|
{
|
|
name: "无效token",
|
|
authHeader: "Bearer invalid-token",
|
|
hasUserID: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", "/optional", nil)
|
|
if tt.authHeader != "" {
|
|
req.Header.Set("Authorization", tt.authHeader)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
router.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
// 这里可以进一步验证响应内容
|
|
})
|
|
}
|
|
}
|