fix(限流): 原子化 Redis 限流并支持故障策略

使用 Lua 脚本原子设置计数与过期,修复 TTL 缺失\n支持 fail-open/fail-close 并对优惠码验证启用 fail-close\n新增单元与集成测试覆盖关键分支\n\n测试:go test ./...
This commit is contained in:
yangjianbo
2026-01-11 22:21:05 +08:00
parent 5f80760a8c
commit 18b8bd43ad
8 changed files with 378 additions and 13 deletions

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"net/http"
"time"
@@ -8,6 +9,33 @@ import (
"github.com/redis/go-redis/v9"
)
// RateLimitFailureMode Redis 故障策略
type RateLimitFailureMode int
const (
RateLimitFailOpen RateLimitFailureMode = iota
RateLimitFailClose
)
// RateLimitOptions 限流可选配置
type RateLimitOptions struct {
FailureMode RateLimitFailureMode
}
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
if current == 1 or ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
end
return current
`)
// rateLimitRun 允许测试覆写脚本执行逻辑
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
}
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
@@ -27,34 +55,57 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
}
// LimitWithOptions 返回速率限制中间件(带可选配置)
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
failureMode := opts.FailureMode
if failureMode != RateLimitFailClose {
failureMode = RateLimitFailOpen
}
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
// 使用 INCR 原子操作增加计数
count, err := r.redis.Incr(ctx, redisKey).Result()
windowMillis := windowTTLMillis(window)
// 使用 Lua 脚本原子操作增加计数并设置过期
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
if err != nil {
if failureMode == RateLimitFailClose {
abortRateLimit(c)
return
}
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
// 首次访问时设置过期时间
if count == 1 {
r.redis.Expire(ctx, redisKey, window)
}
// 超过限制
if count > int64(limit) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
abortRateLimit(c)
return
}
c.Next()
}
}
func windowTTLMillis(window time.Duration) int64 {
ttl := window.Milliseconds()
if ttl < 1 {
return 1
}
return ttl
}
func abortRateLimit(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
}

View File

@@ -0,0 +1,114 @@
//go:build integration
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const redisImageTag = "redis:8.4-alpine"
func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-test", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
redisKey := limiter.prefix + "ttl-test:127.0.0.1"
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlBefore, time.Duration(0))
require.LessOrEqual(t, ttlBefore, 2*time.Second)
time.Sleep(50 * time.Millisecond)
recorder = performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlAfter, ttlBefore)
}
func TestRateLimiterFixesMissingTTL(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
redisKey := limiter.prefix + "ttl-missing:127.0.0.1"
require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err())
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlBefore, time.Duration(0))
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlAfter, time.Duration(0))
}
func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
return recorder
}
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
redisContainer, err := tcredis.Run(ctx, redisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
redisHost, err := redisContainer.Host(ctx)
require.NoError(t, err)
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}

View File

@@ -0,0 +1,100 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestWindowTTLMillis(t *testing.T) {
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
}
func TestRateLimiterFailureModes(t *testing.T) {
gin.SetMode(gin.TestMode)
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
limiter := NewRateLimiter(rdb)
failOpenRouter := gin.New()
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
failOpenRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
failOpenRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
failCloseRouter := gin.New()
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
FailureMode: RateLimitFailClose,
}))
failCloseRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
failCloseRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}
func TestRateLimiterSuccessAndLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
originalRun := rateLimitRun
counts := []int64{1, 2}
callIndex := 0
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
if callIndex >= len(counts) {
return counts[len(counts)-1], nil
}
value := counts[callIndex]
callIndex++
return value, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun
})
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
router := gin.New()
router.Use(limiter.Limit("test", 1, time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}

View File

@@ -27,8 +27,10 @@ func RegisterAuthRoutes(
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}

View File

@@ -0,0 +1,37 @@
## Context
限流中间件当前采用 `INCR``EXPIRE` 的两步操作,且未处理 `EXPIRE` 失败,导致计数 key 可能没有过期时间。该情况一旦发生,计数会持续累加,触发长期限流并造成 Redis key 膨胀。
## Goals / Non-Goals
- Goals:
- 原子化 Redis 计数与过期设置
- 修复 TTL 缺失的历史 key
- 支持按接口配置 Redis 故障策略fail-open/fail-close
- 为需要强制保护的接口启用 fail-close
- Non-Goals:
- 改变现有固定窗口限流算法
- 调整限流 key 格式或前缀
- 引入新的外部依赖
## Decisions
- 使用 Lua 脚本在 Redis 内部原子执行 `INCR``TTL``PEXPIRE`
- 过期时间统一采用毫秒精度窗口(`window.Milliseconds()` 向下取整)以保持精度一致
- 当毫秒窗口小于 1 时,按 1ms 设置过期,避免 0 导致立即过期
-`count == 1``TTL == -1` 时设置过期,避免刷新已有 TTL
- 新增 `RateLimitOptions` 并提供 `LimitWithOptions`,由调用方显式配置故障策略
- `Limit` 默认使用 fail-open 以保持兼容
- 当 fail-close 生效时Redis 执行失败直接返回 429
## Alternatives considered
- 使用 `MULTI/EXEC` 事务封装 `INCR` + `EXPIRE`:原子性可保证,但无法在同一事务内便捷修复 `TTL == -1`,且仍需额外判断逻辑
- 使用 `SET` + `EX`/`NX` 组合:无法保留计数累加语义
## Risks / Trade-offs
- Lua 脚本会带来轻微 CPU 开销,但可接受
- TTL 修复会在首次访问时设定过期,可能缩短历史脏 key 的“无限期”状态,这是期望的修复效果
## Migration Plan
- 上线后脚本在请求路径上自动修复 TTL 缺失的 key
- 如需回滚,恢复原有两步命令即可
## Open Questions
-

View File

@@ -0,0 +1,14 @@
# Change: 原子化 Redis 限流 TTL 设置
## Why
当前限流逻辑使用 `INCR` 后再 `EXPIRE`,非原子且未处理 `EXPIRE` 失败,会导致 key 可能永久存在,引发长期限流或 Redis 内存增长。
## What Changes
- 使用 Lua 脚本原子化 `INCR` + 过期设置
- 当检测到 TTL 缺失时补设过期,修复历史脏数据
- 支持 Redis 故障策略配置(默认放行,特定接口可 fail-close
- 新增 `limit-requests` capability用于描述限流行为与故障策略
## Impact
- Affected specs: 新增 `specs/limit-requests/spec.md`
- Affected code: `backend/internal/middleware/rate_limiter.go`

View File

@@ -0,0 +1,41 @@
## ADDED Requirements
### Requirement: 原子化限流计数与过期
限流中间件 SHALL 在单个原子操作中完成 Redis 计数增量与过期设置,并且仅在首次创建或 TTL 缺失时设置过期,避免刷新已有 TTL过期时间以毫秒为单位向下取整最小为 1ms。
#### Scenario: 首次请求创建计数器
- **WHEN** 第一次请求命中该限流 key
- **THEN** 计数增量为 1 且 key 过期时间设置为窗口值
#### Scenario: 窗口小于 1ms
- **WHEN** 限流窗口小于 1ms
- **THEN** 过期时间按 1ms 设置
#### Scenario: 窗口包含非整数毫秒
- **WHEN** 限流窗口包含非整数毫秒
- **THEN** 过期时间按毫秒向下取整
#### Scenario: 已有 TTL 的计数器继续计数
- **WHEN** 计数器已存在且 TTL 正常
- **THEN** 计数递增且 TTL 不被刷新
#### Scenario: 计数器缺失 TTL
- **WHEN** 计数器存在但 TTL 为 -1
- **THEN** 系统为该 key 补设窗口过期时间
### Requirement: Redis 故障策略可配置
限流中间件 SHALL 支持为每个限流 key 配置 Redis 故障策略,支持 fail-open 与 fail-close默认 fail-open配置由调用方在注册限流时提供。
#### Scenario: fail-open 策略
- **WHEN** 配置为 fail-open 且 Redis 脚本执行返回错误或连接不可用
- **THEN** 请求继续处理且不执行限流阻断
#### Scenario: fail-close 策略
- **WHEN** 配置为 fail-close 且 Redis 脚本执行返回错误或连接不可用
- **THEN** 请求被限流阻断并返回 429
### Requirement: 优惠码验证接口 fail-close
系统 SHALL 对 `/auth/validate-promo-code` 的限流在 Redis 故障时采用 fail-close。
#### Scenario: 验证优惠码时 Redis 不可用
- **WHEN** 请求 `/auth/validate-promo-code` 且 Redis 不可用
- **THEN** 请求返回 429

View File

@@ -0,0 +1,6 @@
## 1. Implementation
- [x] 1.1 在限流中间件中引入 Lua 脚本原子化计数与过期设置(使用 PEXPIRE 毫秒窗口)
- [x] 1.2 脚本内检测 `TTL == -1` 时补设过期,修复历史脏 key
- [x] 1.3 引入 `RateLimitOptions``LimitWithOptions``Limit` 保持默认 fail-open
- [x] 1.4 为 `/auth/validate-promo-code` 配置 fail-close 策略
- [x] 1.5 添加测试覆盖首次请求、已有 TTL、TTL 缺失、非整数毫秒窗口与故障策略(使用 Redis 集成测试/testcontainers 方案)