fix(限流): 原子化 Redis 限流并支持故障策略
使用 Lua 脚本原子设置计数与过期,修复 TTL 缺失\n支持 fail-open/fail-close 并对优惠码验证启用 fail-close\n新增单元与集成测试覆盖关键分支\n\n测试:go test ./...
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -8,6 +9,33 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimitFailureMode Redis 故障策略
|
||||
type RateLimitFailureMode int
|
||||
|
||||
const (
|
||||
RateLimitFailOpen RateLimitFailureMode = iota
|
||||
RateLimitFailClose
|
||||
)
|
||||
|
||||
// RateLimitOptions 限流可选配置
|
||||
type RateLimitOptions struct {
|
||||
FailureMode RateLimitFailureMode
|
||||
}
|
||||
|
||||
var rateLimitScript = redis.NewScript(`
|
||||
local current = redis.call('INCR', KEYS[1])
|
||||
local ttl = redis.call('PTTL', KEYS[1])
|
||||
if current == 1 or ttl == -1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
end
|
||||
return current
|
||||
`)
|
||||
|
||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
||||
}
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
@@ -27,34 +55,57 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
|
||||
}
|
||||
|
||||
// LimitWithOptions 返回速率限制中间件(带可选配置)
|
||||
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
|
||||
failureMode := opts.FailureMode
|
||||
if failureMode != RateLimitFailClose {
|
||||
failureMode = RateLimitFailOpen
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
windowMillis := windowTTLMillis(window)
|
||||
|
||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
if err != nil {
|
||||
if failureMode == RateLimitFailClose {
|
||||
abortRateLimit(c)
|
||||
return
|
||||
}
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
abortRateLimit(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func windowTTLMillis(window time.Duration) int64 {
|
||||
ttl := window.Milliseconds()
|
||||
if ttl < 1 {
|
||||
return 1
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
func abortRateLimit(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
}
|
||||
|
||||
114
backend/internal/middleware/rate_limiter_integration_test.go
Normal file
114
backend/internal/middleware/rate_limiter_integration_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
//go:build integration
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const redisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx := context.Background()
|
||||
rdb := startRedis(t, ctx)
|
||||
limiter := NewRateLimiter(rdb)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("ttl-test", 10, 2*time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
recorder := performRequest(router)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
redisKey := limiter.prefix + "ttl-test:127.0.0.1"
|
||||
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, ttlBefore, time.Duration(0))
|
||||
require.LessOrEqual(t, ttlBefore, 2*time.Second)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
recorder = performRequest(router)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
|
||||
require.NoError(t, err)
|
||||
require.Less(t, ttlAfter, ttlBefore)
|
||||
}
|
||||
|
||||
func TestRateLimiterFixesMissingTTL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx := context.Background()
|
||||
rdb := startRedis(t, ctx)
|
||||
limiter := NewRateLimiter(rdb)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
redisKey := limiter.prefix + "ttl-missing:127.0.0.1"
|
||||
require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err())
|
||||
|
||||
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
|
||||
require.NoError(t, err)
|
||||
require.Less(t, ttlBefore, time.Duration(0))
|
||||
|
||||
recorder := performRequest(router)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, ttlAfter, time.Duration(0))
|
||||
}
|
||||
|
||||
func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, redisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
return rdb
|
||||
}
|
||||
100
backend/internal/middleware/rate_limiter_test.go
Normal file
100
backend/internal/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWindowTTLMillis(t *testing.T) {
|
||||
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
|
||||
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
|
||||
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
|
||||
}
|
||||
|
||||
func TestRateLimiterFailureModes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(rdb)
|
||||
|
||||
failOpenRouter := gin.New()
|
||||
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
|
||||
failOpenRouter.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder := httptest.NewRecorder()
|
||||
failOpenRouter.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
failCloseRouter := gin.New()
|
||||
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
|
||||
FailureMode: RateLimitFailClose,
|
||||
}))
|
||||
failCloseRouter.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder = httptest.NewRecorder()
|
||||
failCloseRouter.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
|
||||
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
originalRun := rateLimitRun
|
||||
counts := []int64{1, 2}
|
||||
callIndex := 0
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
if callIndex >= len(counts) {
|
||||
return counts[len(counts)-1], nil
|
||||
}
|
||||
value := counts[callIndex]
|
||||
callIndex++
|
||||
return value, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("test", 1, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
recorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
@@ -27,8 +27,10 @@ func RegisterAuthRoutes(
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
|
||||
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.ValidatePromoCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
37
openspec/changes/update-rate-limit-ttl-atomic/design.md
Normal file
37
openspec/changes/update-rate-limit-ttl-atomic/design.md
Normal file
@@ -0,0 +1,37 @@
|
||||
## Context
|
||||
限流中间件当前采用 `INCR` 后 `EXPIRE` 的两步操作,且未处理 `EXPIRE` 失败,导致计数 key 可能没有过期时间。该情况一旦发生,计数会持续累加,触发长期限流并造成 Redis key 膨胀。
|
||||
|
||||
## Goals / Non-Goals
|
||||
- Goals:
|
||||
- 原子化 Redis 计数与过期设置
|
||||
- 修复 TTL 缺失的历史 key
|
||||
- 支持按接口配置 Redis 故障策略(fail-open/fail-close)
|
||||
- 为需要强制保护的接口启用 fail-close
|
||||
- Non-Goals:
|
||||
- 改变现有固定窗口限流算法
|
||||
- 调整限流 key 格式或前缀
|
||||
- 引入新的外部依赖
|
||||
|
||||
## Decisions
|
||||
- 使用 Lua 脚本在 Redis 内部原子执行 `INCR`、`TTL` 与 `PEXPIRE`
|
||||
- 过期时间统一采用毫秒精度窗口(`window.Milliseconds()` 向下取整)以保持精度一致
|
||||
- 当毫秒窗口小于 1 时,按 1ms 设置过期,避免 0 导致立即过期
|
||||
- 当 `count == 1` 或 `TTL == -1` 时设置过期,避免刷新已有 TTL
|
||||
- 新增 `RateLimitOptions` 并提供 `LimitWithOptions`,由调用方显式配置故障策略
|
||||
- `Limit` 默认使用 fail-open 以保持兼容
|
||||
- 当 fail-close 生效时,Redis 执行失败直接返回 429
|
||||
|
||||
## Alternatives considered
|
||||
- 使用 `MULTI/EXEC` 事务封装 `INCR` + `EXPIRE`:原子性可保证,但无法在同一事务内便捷修复 `TTL == -1`,且仍需额外判断逻辑
|
||||
- 使用 `SET` + `EX`/`NX` 组合:无法保留计数累加语义
|
||||
|
||||
## Risks / Trade-offs
|
||||
- Lua 脚本会带来轻微 CPU 开销,但可接受
|
||||
- TTL 修复会在首次访问时设定过期,可能缩短历史脏 key 的“无限期”状态,这是期望的修复效果
|
||||
|
||||
## Migration Plan
|
||||
- 上线后脚本在请求路径上自动修复 TTL 缺失的 key
|
||||
- 如需回滚,恢复原有两步命令即可
|
||||
|
||||
## Open Questions
|
||||
- 无
|
||||
14
openspec/changes/update-rate-limit-ttl-atomic/proposal.md
Normal file
14
openspec/changes/update-rate-limit-ttl-atomic/proposal.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Change: 原子化 Redis 限流 TTL 设置
|
||||
|
||||
## Why
|
||||
当前限流逻辑使用 `INCR` 后再 `EXPIRE`,非原子且未处理 `EXPIRE` 失败,会导致 key 可能永久存在,引发长期限流或 Redis 内存增长。
|
||||
|
||||
## What Changes
|
||||
- 使用 Lua 脚本原子化 `INCR` + 过期设置
|
||||
- 当检测到 TTL 缺失时补设过期,修复历史脏数据
|
||||
- 支持 Redis 故障策略配置(默认放行,特定接口可 fail-close)
|
||||
- 新增 `limit-requests` capability,用于描述限流行为与故障策略
|
||||
|
||||
## Impact
|
||||
- Affected specs: 新增 `specs/limit-requests/spec.md`
|
||||
- Affected code: `backend/internal/middleware/rate_limiter.go`
|
||||
@@ -0,0 +1,41 @@
|
||||
## ADDED Requirements
|
||||
### Requirement: 原子化限流计数与过期
|
||||
限流中间件 SHALL 在单个原子操作中完成 Redis 计数增量与过期设置,并且仅在首次创建或 TTL 缺失时设置过期,避免刷新已有 TTL;过期时间以毫秒为单位向下取整,最小为 1ms。
|
||||
|
||||
#### Scenario: 首次请求创建计数器
|
||||
- **WHEN** 第一次请求命中该限流 key
|
||||
- **THEN** 计数增量为 1 且 key 过期时间设置为窗口值
|
||||
|
||||
#### Scenario: 窗口小于 1ms
|
||||
- **WHEN** 限流窗口小于 1ms
|
||||
- **THEN** 过期时间按 1ms 设置
|
||||
|
||||
#### Scenario: 窗口包含非整数毫秒
|
||||
- **WHEN** 限流窗口包含非整数毫秒
|
||||
- **THEN** 过期时间按毫秒向下取整
|
||||
|
||||
#### Scenario: 已有 TTL 的计数器继续计数
|
||||
- **WHEN** 计数器已存在且 TTL 正常
|
||||
- **THEN** 计数递增且 TTL 不被刷新
|
||||
|
||||
#### Scenario: 计数器缺失 TTL
|
||||
- **WHEN** 计数器存在但 TTL 为 -1
|
||||
- **THEN** 系统为该 key 补设窗口过期时间
|
||||
|
||||
### Requirement: Redis 故障策略可配置
|
||||
限流中间件 SHALL 支持为每个限流 key 配置 Redis 故障策略,支持 fail-open 与 fail-close,默认 fail-open,配置由调用方在注册限流时提供。
|
||||
|
||||
#### Scenario: fail-open 策略
|
||||
- **WHEN** 配置为 fail-open 且 Redis 脚本执行返回错误或连接不可用
|
||||
- **THEN** 请求继续处理且不执行限流阻断
|
||||
|
||||
#### Scenario: fail-close 策略
|
||||
- **WHEN** 配置为 fail-close 且 Redis 脚本执行返回错误或连接不可用
|
||||
- **THEN** 请求被限流阻断并返回 429
|
||||
|
||||
### Requirement: 优惠码验证接口 fail-close
|
||||
系统 SHALL 对 `/auth/validate-promo-code` 的限流在 Redis 故障时采用 fail-close。
|
||||
|
||||
#### Scenario: 验证优惠码时 Redis 不可用
|
||||
- **WHEN** 请求 `/auth/validate-promo-code` 且 Redis 不可用
|
||||
- **THEN** 请求返回 429
|
||||
6
openspec/changes/update-rate-limit-ttl-atomic/tasks.md
Normal file
6
openspec/changes/update-rate-limit-ttl-atomic/tasks.md
Normal file
@@ -0,0 +1,6 @@
|
||||
## 1. Implementation
|
||||
- [x] 1.1 在限流中间件中引入 Lua 脚本原子化计数与过期设置(使用 PEXPIRE 毫秒窗口)
|
||||
- [x] 1.2 脚本内检测 `TTL == -1` 时补设过期,修复历史脏 key
|
||||
- [x] 1.3 引入 `RateLimitOptions` 与 `LimitWithOptions`,`Limit` 保持默认 fail-open
|
||||
- [x] 1.4 为 `/auth/validate-promo-code` 配置 fail-close 策略
|
||||
- [x] 1.5 添加测试覆盖首次请求、已有 TTL、TTL 缺失、非整数毫秒窗口与故障策略(使用 Redis 集成测试/testcontainers 方案)
|
||||
Reference in New Issue
Block a user