package middleware import ( "context" "fmt" "log" "net/http" "strconv" "time" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" ) // RateLimitFailureMode Redis 故障策略 type RateLimitFailureMode int const ( RateLimitFailOpen RateLimitFailureMode = iota RateLimitFailClose ) // RateLimitOptions 限流可选配置 type RateLimitOptions struct { FailureMode RateLimitFailureMode } var rateLimitScript = redis.NewScript(` local current = redis.call('INCR', KEYS[1]) local ttl = redis.call('PTTL', KEYS[1]) local repaired = 0 if current == 1 then redis.call('PEXPIRE', KEYS[1], ARGV[1]) elseif ttl == -1 then redis.call('PEXPIRE', KEYS[1], ARGV[1]) repaired = 1 end return {current, repaired} `) // rateLimitRun 允许测试覆写脚本执行逻辑 var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice() if err != nil { return 0, false, err } if len(values) < 2 { return 0, false, fmt.Errorf("rate limit script returned %d values", len(values)) } count, err := parseInt64(values[0]) if err != nil { return 0, false, err } repaired, err := parseInt64(values[1]) if err != nil { return 0, false, err } return count, repaired == 1, nil } // RateLimiter Redis 速率限制器 type RateLimiter struct { redis *redis.Client prefix string } // NewRateLimiter 创建速率限制器实例 func NewRateLimiter(redisClient *redis.Client) *RateLimiter { return &RateLimiter{ redis: redisClient, prefix: "rate_limit:", } } // Limit 返回速率限制中间件 // key: 限制类型标识 // limit: 时间窗口内最大请求数 // window: 时间窗口 func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc { return r.LimitWithOptions(key, limit, window, RateLimitOptions{}) } // LimitWithOptions 返回速率限制中间件(带可选配置) func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc { failureMode := opts.FailureMode if failureMode != RateLimitFailClose { failureMode = RateLimitFailOpen } return func(c *gin.Context) { ip := c.ClientIP() redisKey := r.prefix + key + ":" + ip ctx := c.Request.Context() windowMillis := windowTTLMillis(window) // 使用 Lua 脚本原子操作增加计数并设置过期 count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis) if err != nil { log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err) if failureMode == RateLimitFailClose { abortRateLimit(c) return } // Redis 错误时放行,避免影响正常服务 c.Next() return } if repaired { log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis) } // 超过限制 if count > int64(limit) { abortRateLimit(c) return } c.Next() } } func windowTTLMillis(window time.Duration) int64 { ttl := window.Milliseconds() if ttl < 1 { return 1 } return ttl } func abortRateLimit(c *gin.Context) { c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "rate limit exceeded", "message": "Too many requests, please try again later", }) } func failureModeLabel(mode RateLimitFailureMode) string { if mode == RateLimitFailClose { return "fail-close" } return "fail-open" } func parseInt64(value any) (int64, error) { switch v := value.(type) { case int64: return v, nil case int: return int64(v), nil case string: parsed, err := strconv.ParseInt(v, 10, 64) if err != nil { return 0, err } return parsed, nil default: return 0, fmt.Errorf("unexpected value type %T", value) } }