208 lines
4.7 KiB
Go
208 lines
4.7 KiB
Go
package middleware
|
||
|
||
import (
|
||
"fmt"
|
||
"net/http"
|
||
"strconv"
|
||
"time"
|
||
|
||
"yinli-api/pkg/cache"
|
||
"yinli-api/pkg/config"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"golang.org/x/time/rate"
|
||
)
|
||
|
||
// RateLimitMiddleware Redis频率限制中间件
|
||
func RateLimitMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
cfg := config.AppConfig
|
||
if cfg == nil || !cfg.RateLimit.Enabled {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 获取客户端IP
|
||
clientIP := c.ClientIP()
|
||
key := fmt.Sprintf("rate_limit:%s", clientIP)
|
||
|
||
// 获取当前请求数
|
||
currentRequests, err := cache.GetString(key)
|
||
if err != nil {
|
||
// 键不存在,设置为1
|
||
err = cache.SetString(key, "1", time.Duration(cfg.RateLimit.Window)*time.Second)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 转换为整数
|
||
requests, err := strconv.Atoi(currentRequests)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查是否超过限制
|
||
if requests >= cfg.RateLimit.Requests {
|
||
// 获取剩余时间
|
||
ttl, _ := cache.TTL(key)
|
||
c.Header("X-RateLimit-Limit", strconv.Itoa(cfg.RateLimit.Requests))
|
||
c.Header("X-RateLimit-Remaining", "0")
|
||
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))
|
||
|
||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
"code": 429,
|
||
"message": "请求过于频繁,请稍后再试",
|
||
"retry_after": int(ttl.Seconds()),
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 增加请求计数
|
||
newCount, err := cache.IncrBy(key, 1)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 设置响应头
|
||
remaining := cfg.RateLimit.Requests - int(newCount)
|
||
if remaining < 0 {
|
||
remaining = 0
|
||
}
|
||
|
||
c.Header("X-RateLimit-Limit", strconv.Itoa(cfg.RateLimit.Requests))
|
||
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
||
|
||
ttl, _ := cache.TTL(key)
|
||
c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// InMemoryRateLimitMiddleware 内存频率限制中间件(备用方案)
|
||
func InMemoryRateLimitMiddleware() gin.HandlerFunc {
|
||
limiters := make(map[string]*rate.Limiter)
|
||
|
||
return func(c *gin.Context) {
|
||
cfg := config.AppConfig
|
||
if cfg == nil || !cfg.RateLimit.Enabled {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
clientIP := c.ClientIP()
|
||
|
||
limiter, exists := limiters[clientIP]
|
||
if !exists {
|
||
// 创建新的限制器
|
||
limiter = rate.NewLimiter(rate.Every(time.Duration(cfg.RateLimit.Window)*time.Second/time.Duration(cfg.RateLimit.Requests)), cfg.RateLimit.Requests)
|
||
limiters[clientIP] = limiter
|
||
}
|
||
|
||
if !limiter.Allow() {
|
||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
"code": 429,
|
||
"message": "请求过于频繁,请稍后再试",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// UserRateLimitMiddleware 基于用户的频率限制中间件
|
||
func UserRateLimitMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
cfg := config.AppConfig
|
||
if cfg == nil || !cfg.RateLimit.Enabled {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 获取用户ID
|
||
userID, exists := GetUserID(c)
|
||
if !exists {
|
||
// 如果没有用户ID,使用IP限制
|
||
RateLimitMiddleware()(c)
|
||
return
|
||
}
|
||
|
||
key := fmt.Sprintf("user_rate_limit:%d", userID)
|
||
|
||
// 获取当前请求数
|
||
currentRequests, err := cache.GetString(key)
|
||
if err != nil {
|
||
// 键不存在,设置为1
|
||
err = cache.SetString(key, "1", time.Duration(cfg.RateLimit.Window)*time.Second)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 转换为整数
|
||
requests, err := strconv.Atoi(currentRequests)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 检查是否超过限制(认证用户可以有更高的限制)
|
||
userLimit := cfg.RateLimit.Requests * 2 // 认证用户限制翻倍
|
||
if requests >= userLimit {
|
||
ttl, _ := cache.TTL(key)
|
||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
"code": 429,
|
||
"message": "请求过于频繁,请稍后再试",
|
||
"retry_after": int(ttl.Seconds()),
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 增加请求计数
|
||
_, err = cache.IncrBy(key, 1)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"code": 500,
|
||
"message": "内部服务器错误",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|