GoTest/internal/middleware/rate_limit.go
2025-11-29 03:27:19 +08:00

208 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"fmt"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"golang.org/x/time/rate"

	"yinli-api/pkg/cache"
	"yinli-api/pkg/config"
)
// RateLimitMiddleware returns a Gin middleware that enforces a fixed-window
// request limit per client IP, using counters stored through the cache
// package (described by the original author as Redis-backed).
//
// Each client IP is counted under the key "rate_limit:<ip>". A window starts
// with the IP's first request and lasts cfg.RateLimit.Window seconds. Within
// it at most cfg.RateLimit.Requests requests are passed on; further requests
// are aborted with HTTP 429 and a "retry_after" field (seconds until the
// window ends). Responses carry X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (Unix seconds) headers.
//
// The middleware is a no-op when config.AppConfig is nil or rate limiting is
// disabled. Cache failures abort the request with HTTP 500 (fail closed).
func RateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		limit := cfg.RateLimit.Requests
		window := time.Duration(cfg.RateLimit.Window) * time.Second
		key := fmt.Sprintf("rate_limit:%s", c.ClientIP())

		// Increment before checking. A single INCRBY is atomic, whereas a
		// read-check-increment sequence lets concurrent requests observe the
		// same count and collectively exceed the limit.
		count, err := cache.IncrBy(key, 1)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
				"code":    500,
				"message": "内部服务器错误",
			})
			return
		}

		// INCRBY on a missing key creates it without an expiry. That is the
		// case on the first request of a window, and also when the key expired
		// between cache operations. Without an expiry the counter would never
		// reset and the client would stay blocked forever.
		// NOTE(review): a non-positive TTL is assumed to mean "no expiry" or
		// "no key" — confirm this matches what cache.TTL reports.
		ttl, err := cache.TTL(key)
		if err != nil || ttl <= 0 {
			// Best effort: an increment from a concurrent request landing
			// between IncrBy and SetString may be overwritten.
			if err = cache.SetString(key, strconv.Itoa(int(count)), window); err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
					"code":    500,
					"message": "内部服务器错误",
				})
				return
			}
			ttl = window
		}

		c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))

		// count includes the current request, so the request is allowed
		// exactly when count <= limit.
		if int(count) > limit {
			c.Header("X-RateLimit-Remaining", "0")
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":        429,
				"message":     "请求过于频繁,请稍后再试",
				"retry_after": int(ttl.Seconds()),
			})
			return
		}

		// Non-negative here because count <= limit.
		c.Header("X-RateLimit-Remaining", strconv.Itoa(limit-int(count)))
		c.Next()
	}
}
// InMemoryRateLimitMiddleware returns a Gin middleware that limits requests
// per client IP with token-bucket limiters kept in process memory. It is
// intended as a fallback for the cache-backed RateLimitMiddleware. Limits are
// per process and are not shared between server instances.
//
// Each IP's limiter holds up to cfg.RateLimit.Requests tokens and refills them
// evenly over cfg.RateLimit.Window seconds. A limiter is built from the config
// in effect when its IP is first seen and is not rebuilt afterwards. A
// non-positive Requests value denies every request; it previously caused an
// integer divide-by-zero panic. Rejected requests are aborted with HTTP 429.
// The middleware is a no-op when config.AppConfig is nil or rate limiting is
// disabled.
//
// NOTE(review): limiters are never evicted, so memory grows with the number of
// distinct client IPs over the life of the process.
func InMemoryRateLimitMiddleware() gin.HandlerFunc {
	// Gin runs handlers concurrently, so the shared map must be guarded;
	// unsynchronized map writes can crash the process.
	var mu sync.Mutex
	limiters := make(map[string]*rate.Limiter)

	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		clientIP := c.ClientIP()

		mu.Lock()
		limiter, exists := limiters[clientIP]
		if !exists {
			// The zero Limit never refills; combined with a non-positive burst
			// it denies every request instead of dividing by zero below.
			var refill rate.Limit
			if cfg.RateLimit.Requests > 0 {
				refill = rate.Every(time.Duration(cfg.RateLimit.Window) * time.Second / time.Duration(cfg.RateLimit.Requests))
			}
			limiter = rate.NewLimiter(refill, cfg.RateLimit.Requests)
			limiters[clientIP] = limiter
		}
		mu.Unlock()

		// rate.Limiter is safe for concurrent use, so Allow needs no lock.
		if !limiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":    429,
				"message": "请求过于频繁,请稍后再试",
			})
			return
		}
		c.Next()
	}
}
// UserRateLimitMiddleware returns a Gin middleware that applies a fixed-window
// rate limit per authenticated user. Authenticated users get twice the per-IP
// quota: 2 × cfg.RateLimit.Requests requests per cfg.RateLimit.Window seconds.
// When GetUserID reports no user, the request is handled by
// RateLimitMiddleware's per-IP limit instead.
//
// Counters are stored through the cache package under the key
// "user_rate_limit:<userID>". Over-limit requests are aborted with HTTP 429
// and a "retry_after" field (seconds until the window ends). Responses carry
// X-RateLimit-Limit, X-RateLimit-Remaining and X-RateLimit-Reset (Unix seconds)
// headers.
//
// The middleware is a no-op when config.AppConfig is nil or rate limiting is
// disabled. Cache failures abort the request with HTTP 500 (fail closed).
func UserRateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		userID, exists := GetUserID(c)
		if !exists {
			// No user ID available: fall back to the per-IP limit.
			RateLimitMiddleware()(c)
			return
		}

		// Authenticated users get double the per-IP quota.
		userLimit := cfg.RateLimit.Requests * 2
		window := time.Duration(cfg.RateLimit.Window) * time.Second
		key := fmt.Sprintf("user_rate_limit:%d", userID)

		// Increment before checking. A single INCRBY is atomic, whereas a
		// read-check-increment sequence lets concurrent requests observe the
		// same count and collectively exceed the limit.
		count, err := cache.IncrBy(key, 1)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
				"code":    500,
				"message": "内部服务器错误",
			})
			return
		}

		// INCRBY on a missing key creates it without an expiry. That is the
		// case on the first request of a window, and also when the key expired
		// between cache operations. Without an expiry the counter would never
		// reset and the user would stay blocked forever.
		// NOTE(review): a non-positive TTL is assumed to mean "no expiry" or
		// "no key" — confirm this matches what cache.TTL reports.
		ttl, err := cache.TTL(key)
		if err != nil || ttl <= 0 {
			// Best effort: an increment from a concurrent request landing
			// between IncrBy and SetString may be overwritten.
			if err = cache.SetString(key, strconv.Itoa(int(count)), window); err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
					"code":    500,
					"message": "内部服务器错误",
				})
				return
			}
			ttl = window
		}

		c.Header("X-RateLimit-Limit", strconv.Itoa(userLimit))
		c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))

		// count includes the current request, so the request is allowed
		// exactly when count <= userLimit.
		if int(count) > userLimit {
			c.Header("X-RateLimit-Remaining", "0")
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":        429,
				"message":     "请求过于频繁,请稍后再试",
				"retry_after": int(ttl.Seconds()),
			})
			return
		}

		// Non-negative here because count <= userLimit.
		c.Header("X-RateLimit-Remaining", strconv.Itoa(userLimit-int(count)))
		c.Next()
	}
}