package middleware

import (
	"fmt"
	"net/http"
	"strconv"
	"sync"
	"time"

	"yinli-api/pkg/cache"
	"yinli-api/pkg/config"

	"github.com/gin-gonic/gin"
	"golang.org/x/time/rate"
)

// RateLimitMiddleware is the Redis-backed rate limiter: it allows each client
// IP cfg.RateLimit.Requests requests per cfg.RateLimit.Window seconds, using a
// fixed-window counter stored under "rate_limit:<ip>". It is a no-op when
// config.AppConfig is nil or rate limiting is disabled.
//
// Passed requests carry X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (Unix seconds) headers. Over-limit requests are aborted
// with 429 and a retry_after field; cache failures abort with 500.
func RateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		key := fmt.Sprintf("rate_limit:%s", c.ClientIP())
		window := time.Duration(cfg.RateLimit.Window) * time.Second
		applyRedisRateLimit(c, key, cfg.RateLimit.Requests, window)
	}
}

// InMemoryRateLimitMiddleware is a process-local fallback for
// RateLimitMiddleware. It keeps one token-bucket limiter per client IP that
// admits cfg.RateLimit.Requests requests per cfg.RateLimit.Window seconds with
// a burst of cfg.RateLimit.Requests. Limits are not shared across processes.
// Over-limit requests are aborted with 429.
//
// gin runs handlers concurrently, so the per-IP table is guarded by a mutex;
// idle entries are swept periodically so the table does not grow without bound
// as new client IPs arrive.
func InMemoryRateLimitMiddleware() gin.HandlerFunc {
	type ipEntry struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}

	var (
		mu        sync.Mutex
		entries   = make(map[string]*ipEntry)
		lastSweep = time.Now()
	)

	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		window := time.Duration(cfg.RateLimit.Window) * time.Second
		clientIP := c.ClientIP()
		now := time.Now()

		mu.Lock()
		// A limiter left idle for at least one window (at the current setting)
		// has refilled to its full burst, so it behaves exactly like a freshly
		// created one and can be dropped. Sweep at most once per idleTTL.
		idleTTL := window
		if idleTTL < time.Minute {
			idleTTL = time.Minute
		}
		if now.Sub(lastSweep) >= idleTTL {
			for ip, ent := range entries {
				if now.Sub(ent.lastSeen) >= idleTTL {
					delete(entries, ip)
				}
			}
			lastSweep = now
		}

		entry, ok := entries[clientIP]
		if !ok {
			entry = &ipEntry{limiter: newInMemoryLimiter(window, cfg.RateLimit.Requests)}
			entries[clientIP] = entry
		}
		entry.lastSeen = now
		limiter := entry.limiter
		mu.Unlock()

		// rate.Limiter is safe for concurrent use, so Allow runs outside the lock.
		if !limiter.Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":    429,
				"message": "请求过于频繁,请稍后再试",
			})
			return
		}
		c.Next()
	}
}

// UserRateLimitMiddleware rate-limits authenticated users by user ID under
// "user_rate_limit:<id>", with twice the per-IP budget (2 *
// cfg.RateLimit.Requests per cfg.RateLimit.Window seconds). Requests without a
// user ID fall back to RateLimitMiddleware's per-IP limit. Responses carry the
// same X-RateLimit-* headers and 429/500 behavior as RateLimitMiddleware.
func UserRateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.AppConfig
		if cfg == nil || !cfg.RateLimit.Enabled {
			c.Next()
			return
		}

		userID, exists := GetUserID(c)
		if !exists {
			// Anonymous request: apply the per-IP limit instead.
			RateLimitMiddleware()(c)
			return
		}

		key := fmt.Sprintf("user_rate_limit:%d", userID)
		window := time.Duration(cfg.RateLimit.Window) * time.Second
		// Authenticated users get double the anonymous limit.
		applyRedisRateLimit(c, key, cfg.RateLimit.Requests*2, window)
	}
}

// applyRedisRateLimit enforces a fixed-window limit of limit requests per
// window on the counter stored under key. It either aborts the request (429
// when over the limit, 500 when the cache cannot be read or updated) or sets
// the X-RateLimit-* headers and runs the remaining handlers.
//
// Requests rejected by the pre-check are not counted, so a blocked client that
// keeps retrying does not push its counter further.
func applyRedisRateLimit(c *gin.Context, key string, limit int, window time.Duration) {
	current, err := cache.GetString(key)
	if err != nil {
		// No readable counter: this request opens a new window.
		// NOTE(review): any GetString error (not only a missing key) lands here.
		if setErr := cache.SetString(key, "1", window); setErr != nil {
			rateLimitAbortInternal(c)
			return
		}
		setRateLimitHeaders(c, limit, limit-1, window)
		c.Next()
		return
	}

	requests, err := strconv.Atoi(current)
	if err != nil {
		rateLimitAbortInternal(c)
		return
	}
	if requests >= limit {
		rateLimitReject(c, limit, rateLimitTTL(key))
		return
	}

	newCount, err := cache.IncrBy(key, 1)
	if err != nil {
		rateLimitAbortInternal(c)
		return
	}
	count := int(newCount)

	if count == 1 {
		// This middleware only stores values >= 1 and the key was present at GET
		// time, so a result of 1 means it expired in between and the increment
		// re-created it (assuming cache.IncrBy maps to Redis INCRBY, which
		// creates missing keys with no expiry). Re-create it with a TTL, or the
		// counter would never expire and the client would end up blocked for good.
		if setErr := cache.SetString(key, "1", window); setErr != nil {
			rateLimitAbortInternal(c)
			return
		}
	}

	// Concurrent requests can all pass the GET pre-check; the atomic increment
	// decides which of them actually fit within the limit.
	if count > limit {
		rateLimitReject(c, limit, rateLimitTTL(key))
		return
	}

	setRateLimitHeaders(c, limit, limit-count, rateLimitTTL(key))
	c.Next()
}

// rateLimitTTL returns how long key has left to live, for the advisory reset
// header and retry_after. Lookup errors and negative values (presumably the
// cache's "no expiry"/"no such key" sentinels) are reported as zero.
func rateLimitTTL(key string) time.Duration {
	ttl, err := cache.TTL(key)
	if err != nil || ttl < 0 {
		return 0
	}
	return ttl
}

// setRateLimitHeaders writes the X-RateLimit-* headers. remaining is clamped at
// zero and the reset time is now+ttl in Unix seconds.
func setRateLimitHeaders(c *gin.Context, limit, remaining int, ttl time.Duration) {
	if remaining < 0 {
		remaining = 0
	}
	c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
	c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
	c.Header("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(ttl).Unix(), 10))
}

// rateLimitReject aborts the request with 429, the X-RateLimit-* headers, and
// retry_after set to the whole seconds left in the current window.
func rateLimitReject(c *gin.Context, limit int, ttl time.Duration) {
	setRateLimitHeaders(c, limit, 0, ttl)
	c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
		"code":        429,
		"message":     "请求过于频繁,请稍后再试",
		"retry_after": int(ttl.Seconds()),
	})
}

// rateLimitAbortInternal aborts the request with the generic 500 payload used
// when the rate-limit counter cannot be read or updated.
func rateLimitAbortInternal(c *gin.Context) {
	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
		"code":    500,
		"message": "内部服务器错误",
	})
}

// newInMemoryLimiter builds a token bucket admitting requests events per
// window with a burst of requests. A non-positive request budget yields a
// limiter that rejects everything, instead of panicking on integer division
// by zero when computing the refill interval.
func newInMemoryLimiter(window time.Duration, requests int) *rate.Limiter {
	if requests <= 0 {
		return rate.NewLimiter(0, 0)
	}
	return rate.NewLimiter(rate.Every(window/time.Duration(requests)), requests)
}