Files
sub2api/backend/internal/middleware/rate_limiter.go
yangjianbo 18b8bd43ad fix(限流): 原子化 Redis 限流并支持故障策略
使用 Lua 脚本原子设置计数与过期,修复 TTL 缺失
支持 fail-open/fail-close 并对优惠码验证启用 fail-close
新增单元与集成测试覆盖关键分支

测试:go test ./...
2026-01-11 22:21:05 +08:00

112 lines
2.6 KiB
Go

package middleware
import (
"context"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimitFailureMode selects what the rate limiting middleware does when
// the Redis call backing the counter returns an error.
type RateLimitFailureMode int

const (
	// RateLimitFailOpen lets the request continue when Redis fails.
	// It is the zero value, so it applies when no mode is set explicitly.
	RateLimitFailOpen = RateLimitFailureMode(iota)
	// RateLimitFailClose rejects the request when Redis fails.
	RateLimitFailClose
)
// RateLimitOptions holds the optional settings accepted by
// RateLimiter.LimitWithOptions.
type RateLimitOptions struct {
// FailureMode chooses between letting requests through and rejecting them
// when the Redis call fails. The zero value is RateLimitFailOpen.
FailureMode RateLimitFailureMode
}
// rateLimitScript is the Lua script that counts one request against KEYS[1].
// It increments the counter with INCR and then reads the key's remaining TTL
// with PTTL. When the counter was just created (value 1) or PTTL reports -1,
// it applies an expiry of ARGV[1] milliseconds with PEXPIRE. The script
// returns the counter value after the increment.
// Doing the increment and the expiry inside one script keeps the two steps
// together, so a counter is not left behind without a TTL.
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
end
return current
`)
// rateLimitRun executes rateLimitScript with key as its only key and
// windowMillis as its only argument, returning the script result as an int64.
// It is a package-level variable so that tests can replace the execution logic.
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
	keys := []string{key}
	result := rateLimitScript.Run(ctx, client, keys, windowMillis)
	return result.Int64()
}
// RateLimiter is a Redis-backed rate limiter that produces gin middleware.
type RateLimiter struct {
// redis is the client handed to the counting script on every request.
redis *redis.Client
// prefix is prepended to every counter key built by this limiter.
prefix string
}
// NewRateLimiter creates a RateLimiter that uses redisClient and stores its
// counters under the "rate_limit:" key prefix.
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.redis = redisClient
	limiter.prefix = "rate_limit:"
	return limiter
}
// Limit returns rate limiting middleware that uses the default options.
//
// key identifies the kind of limit being applied, limit is the maximum number
// of requests allowed inside one time window, and window is the length of
// that time window.
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
	var defaults RateLimitOptions
	return r.LimitWithOptions(key, limit, window, defaults)
}
// LimitWithOptions returns rate limiting middleware configured by opts.
//
// Each request increments a Redis counter keyed by the limiter prefix, key
// and the client IP. The request is aborted through abortRateLimit when the
// counter exceeds limit. If the Redis call fails, the request is aborted only
// when opts.FailureMode is RateLimitFailClose; any other mode lets it through.
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
	// Anything other than an explicit fail-close setting behaves as fail-open.
	failClosed := opts.FailureMode == RateLimitFailClose

	return func(c *gin.Context) {
		counterKey := r.prefix + key + ":" + c.ClientIP()
		ttlMillis := windowTTLMillis(window)

		// The Lua script increments the counter and sets its expiry atomically.
		hits, err := rateLimitRun(c.Request.Context(), r.redis, counterKey, ttlMillis)

		switch {
		case err != nil && failClosed:
			abortRateLimit(c)
		case err != nil:
			// Let the request through on a Redis error so normal service is not affected.
			c.Next()
		case hits > int64(limit):
			// Over the limit for this window.
			abortRateLimit(c)
		default:
			c.Next()
		}
	}
}
// windowTTLMillis converts window to whole milliseconds for use as a key TTL.
// Windows shorter than one millisecond, including zero and negative
// durations, are raised to 1 so the result is always a positive TTL.
func windowTTLMillis(window time.Duration) int64 {
	if ms := window.Milliseconds(); ms >= 1 {
		return ms
	}
	return 1
}
// abortRateLimit aborts the handler chain and replies with HTTP 429
// (Too Many Requests) and a JSON body describing the rejection.
func abortRateLimit(c *gin.Context) {
	body := gin.H{
		"error":   "rate limit exceeded",
		"message": "Too many requests, please try again later",
	}
	c.AbortWithStatusJSON(http.StatusTooManyRequests, body)
}