diff --git a/backend/internal/middleware/rate_limiter.go b/backend/internal/middleware/rate_limiter.go index 9526f071..13b71683 100644 --- a/backend/internal/middleware/rate_limiter.go +++ b/backend/internal/middleware/rate_limiter.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "time" @@ -8,6 +9,33 @@ import ( "github.com/redis/go-redis/v9" ) +// RateLimitFailureMode selects what the limiter does when the Redis call fails. +type RateLimitFailureMode int + +const ( + RateLimitFailOpen RateLimitFailureMode = iota + RateLimitFailClose +) + +// RateLimitOptions holds optional settings for LimitWithOptions. +type RateLimitOptions struct { + FailureMode RateLimitFailureMode +} + +var rateLimitScript = redis.NewScript(` +local current = redis.call('INCR', KEYS[1]) +local ttl = redis.call('PTTL', KEYS[1]) +if current == 1 or ttl == -1 then + redis.call('PEXPIRE', KEYS[1], ARGV[1]) +end +return current +`) + +// rateLimitRun runs rateLimitScript; it is a variable so tests can override script execution. +var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { + return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64() +} + // RateLimiter Redis 速率限制器 type RateLimiter struct { redis *redis.Client @@ -27,34 +55,57 @@ func NewRateLimiter(redisClient *redis.Client) *RateLimiter { // limit: 时间窗口内最大请求数 // window: 时间窗口 func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc { + return r.LimitWithOptions(key, limit, window, RateLimitOptions{}) +} + +// LimitWithOptions returns the rate-limiting middleware configured by opts; any FailureMode other than RateLimitFailClose is treated as RateLimitFailOpen. +func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc { + failureMode := opts.FailureMode + if failureMode != RateLimitFailClose { + failureMode = RateLimitFailOpen + } + return func(c *gin.Context) { ip := c.ClientIP() redisKey := r.prefix + key + ":" + ip ctx := c.Request.Context() - // 使用 INCR 原子操作增加计数 - count, err := r.redis.Incr(ctx, redisKey).Result() + windowMillis := windowTTLMillis(window) + + // Increment the counter and set its expiry atomically via the Lua script + count, err := 
rateLimitRun(ctx, r.redis, redisKey, windowMillis) if err != nil { + if failureMode == RateLimitFailClose { + abortRateLimit(c) + return + } // Redis 错误时放行,避免影响正常服务 c.Next() return } - // 首次访问时设置过期时间 - if count == 1 { - r.redis.Expire(ctx, redisKey, window) - } - // 超过限制 if count > int64(limit) { - c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ - "error": "rate limit exceeded", - "message": "Too many requests, please try again later", - }) + abortRateLimit(c) return } c.Next() } } + +func windowTTLMillis(window time.Duration) int64 { + ttl := window.Milliseconds() + if ttl < 1 { + return 1 + } + return ttl +} + +func abortRateLimit(c *gin.Context) { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "rate limit exceeded", + "message": "Too many requests, please try again later", + }) +} diff --git a/backend/internal/middleware/rate_limiter_integration_test.go b/backend/internal/middleware/rate_limiter_integration_test.go new file mode 100644 index 00000000..4759a988 --- /dev/null +++ b/backend/internal/middleware/rate_limiter_integration_test.go @@ -0,0 +1,114 @@ +//go:build integration + +package middleware + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const redisImageTag = "redis:8.4-alpine" + +func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) { + gin.SetMode(gin.TestMode) + + ctx := context.Background() + rdb := startRedis(t, ctx) + limiter := NewRateLimiter(rdb) + + router := gin.New() + router.Use(limiter.Limit("ttl-test", 10, 2*time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + recorder := performRequest(router) + require.Equal(t, http.StatusOK, recorder.Code) + + redisKey := limiter.prefix + "ttl-test:127.0.0.1" + ttlBefore, err := 
rdb.PTTL(ctx, redisKey).Result() + require.NoError(t, err) + require.Greater(t, ttlBefore, time.Duration(0)) + require.LessOrEqual(t, ttlBefore, 2*time.Second) + + time.Sleep(50 * time.Millisecond) + + recorder = performRequest(router) + require.Equal(t, http.StatusOK, recorder.Code) + + ttlAfter, err := rdb.PTTL(ctx, redisKey).Result() + require.NoError(t, err) + require.Less(t, ttlAfter, ttlBefore) +} + +func TestRateLimiterFixesMissingTTL(t *testing.T) { + gin.SetMode(gin.TestMode) + + ctx := context.Background() + rdb := startRedis(t, ctx) + limiter := NewRateLimiter(rdb) + + router := gin.New() + router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + redisKey := limiter.prefix + "ttl-missing:127.0.0.1" + require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err()) + + ttlBefore, err := rdb.PTTL(ctx, redisKey).Result() + require.NoError(t, err) + require.Less(t, ttlBefore, time.Duration(0)) + + recorder := performRequest(router) + require.Equal(t, http.StatusOK, recorder.Code) + + ttlAfter, err := rdb.PTTL(ctx, redisKey).Result() + require.NoError(t, err) + require.Greater(t, ttlAfter, time.Duration(0)) +} + +func performRequest(router *gin.Engine) *httptest.ResponseRecorder { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + return recorder +} + +func startRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + + redisContainer, err := tcredis.Run(ctx, redisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + 
require.NoError(t, rdb.Ping(ctx).Err()) + + t.Cleanup(func() { + _ = rdb.Close() + }) + + return rdb +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go new file mode 100644 index 00000000..7c72e5be --- /dev/null +++ b/backend/internal/middleware/rate_limiter_test.go @@ -0,0 +1,100 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestWindowTTLMillis(t *testing.T) { + require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond)) + require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond)) + require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond)) +} + +func TestRateLimiterFailureModes(t *testing.T) { + gin.SetMode(gin.TestMode) + + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + limiter := NewRateLimiter(rdb) + + failOpenRouter := gin.New() + failOpenRouter.Use(limiter.Limit("test", 1, time.Second)) + failOpenRouter.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" + recorder := httptest.NewRecorder() + failOpenRouter.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + failCloseRouter := gin.New() + failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{ + FailureMode: RateLimitFailClose, + })) + failCloseRouter.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" + recorder = httptest.NewRecorder() + 
failCloseRouter.ServeHTTP(recorder, req) + require.Equal(t, http.StatusTooManyRequests, recorder.Code) +} + +func TestRateLimiterSuccessAndLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + originalRun := rateLimitRun + counts := []int64{1, 2} + callIndex := 0 + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { + if callIndex >= len(counts) { + return counts[len(counts)-1], nil + } + value := counts[callIndex] + callIndex++ + return value, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("test", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "127.0.0.1:1234" + recorder = httptest.NewRecorder() + router.ServeHTTP(recorder, req) + require.Equal(t, http.StatusTooManyRequests, recorder.Code) +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 7d8a79e9..aa691eba 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -27,8 +27,10 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) - // 优惠码验证接口添加速率限制:每分钟最多 10 次 - auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode) + // Rate-limit promo code validation to 10 requests per minute; fail-close when Redis is unavailable + auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ + 
FailureMode: middleware.RateLimitFailClose, + }), h.Auth.ValidatePromoCode) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } diff --git a/openspec/changes/update-rate-limit-ttl-atomic/design.md b/openspec/changes/update-rate-limit-ttl-atomic/design.md new file mode 100644 index 00000000..55be897a --- /dev/null +++ b/openspec/changes/update-rate-limit-ttl-atomic/design.md @@ -0,0 +1,37 @@ +## Context +限流中间件当前采用 `INCR` 后 `EXPIRE` 的两步操作,且未处理 `EXPIRE` 失败,导致计数 key 可能没有过期时间。该情况一旦发生,计数会持续累加,触发长期限流并造成 Redis key 膨胀。 + +## Goals / Non-Goals +- Goals: + - 原子化 Redis 计数与过期设置 + - 修复 TTL 缺失的历史 key + - 支持按接口配置 Redis 故障策略(fail-open/fail-close) + - 为需要强制保护的接口启用 fail-close +- Non-Goals: + - 改变现有固定窗口限流算法 + - 调整限流 key 格式或前缀 + - 引入新的外部依赖 + +## Decisions +- 使用 Lua 脚本在 Redis 内部原子执行 `INCR`、`PTTL` 与 `PEXPIRE` +- 过期时间统一采用毫秒精度窗口(`window.Milliseconds()` 向下取整)以保持精度一致 +- 当毫秒窗口小于 1 时,按 1ms 设置过期,避免 0 导致立即过期 +- 当 `count == 1` 或 `TTL == -1` 时设置过期,避免刷新已有 TTL +- 新增 `RateLimitOptions` 并提供 `LimitWithOptions`,由调用方显式配置故障策略 +- `Limit` 默认使用 fail-open 以保持兼容 +- 当 fail-close 生效时,Redis 执行失败直接返回 429 + +## Alternatives considered +- 使用 `MULTI/EXEC` 事务封装 `INCR` + `EXPIRE`:原子性可保证,但无法在同一事务内便捷修复 `TTL == -1`,且仍需额外判断逻辑 +- 使用 `SET` + `EX`/`NX` 组合:无法保留计数累加语义 + +## Risks / Trade-offs +- Lua 脚本会带来轻微 CPU 开销,但可接受 +- TTL 修复会在首次访问时设定过期,可能缩短历史脏 key 的“无限期”状态,这是期望的修复效果 + +## Migration Plan +- 上线后脚本在请求路径上自动修复 TTL 缺失的 key +- 如需回滚,恢复原有两步命令即可 + +## Open Questions +- 无 diff --git a/openspec/changes/update-rate-limit-ttl-atomic/proposal.md b/openspec/changes/update-rate-limit-ttl-atomic/proposal.md new file mode 100644 index 00000000..ba19528d --- /dev/null +++ b/openspec/changes/update-rate-limit-ttl-atomic/proposal.md @@ -0,0 +1,14 @@ +# Change: 原子化 Redis 限流 TTL 设置 + +## Why +当前限流逻辑使用 `INCR` 后再 `EXPIRE`,非原子且未处理 `EXPIRE` 失败,会导致 key 可能永久存在,引发长期限流或 Redis 内存增长。 + +## What Changes +- 使用 Lua 脚本原子化 `INCR` + 过期设置 +- 当检测到 TTL 缺失时补设过期,修复历史脏数据 +- 支持 Redis 
故障策略配置(默认放行,特定接口可 fail-close) +- 新增 `limit-requests` capability,用于描述限流行为与故障策略 + +## Impact +- Affected specs: 新增 `specs/limit-requests/spec.md` +- Affected code: `backend/internal/middleware/rate_limiter.go`、`backend/internal/server/routes/auth.go` diff --git a/openspec/changes/update-rate-limit-ttl-atomic/specs/limit-requests/spec.md b/openspec/changes/update-rate-limit-ttl-atomic/specs/limit-requests/spec.md new file mode 100644 index 00000000..e6f9b73c --- /dev/null +++ b/openspec/changes/update-rate-limit-ttl-atomic/specs/limit-requests/spec.md @@ -0,0 +1,41 @@ +## ADDED Requirements +### Requirement: 原子化限流计数与过期 +限流中间件 SHALL 在单个原子操作中完成 Redis 计数增量与过期设置,并且仅在首次创建或 TTL 缺失时设置过期,避免刷新已有 TTL;过期时间以毫秒为单位向下取整,最小为 1ms。 + +#### Scenario: 首次请求创建计数器 +- **WHEN** 第一次请求命中该限流 key +- **THEN** 计数增量为 1 且 key 过期时间设置为窗口值 + +#### Scenario: 窗口小于 1ms +- **WHEN** 限流窗口小于 1ms +- **THEN** 过期时间按 1ms 设置 + +#### Scenario: 窗口包含非整数毫秒 +- **WHEN** 限流窗口包含非整数毫秒 +- **THEN** 过期时间按毫秒向下取整 + +#### Scenario: 已有 TTL 的计数器继续计数 +- **WHEN** 计数器已存在且 TTL 正常 +- **THEN** 计数递增且 TTL 不被刷新 + +#### Scenario: 计数器缺失 TTL +- **WHEN** 计数器存在但 TTL 为 -1 +- **THEN** 系统为该 key 补设窗口过期时间 + +### Requirement: Redis 故障策略可配置 +限流中间件 SHALL 支持为每个限流 key 配置 Redis 故障策略,支持 fail-open 与 fail-close,默认 fail-open,配置由调用方在注册限流时提供。 + +#### Scenario: fail-open 策略 +- **WHEN** 配置为 fail-open 且 Redis 脚本执行返回错误或连接不可用 +- **THEN** 请求继续处理且不执行限流阻断 + +#### Scenario: fail-close 策略 +- **WHEN** 配置为 fail-close 且 Redis 脚本执行返回错误或连接不可用 +- **THEN** 请求被限流阻断并返回 429 + +### Requirement: 优惠码验证接口 fail-close +系统 SHALL 对 `/auth/validate-promo-code` 的限流在 Redis 故障时采用 fail-close。 + +#### Scenario: 验证优惠码时 Redis 不可用 +- **WHEN** 请求 `/auth/validate-promo-code` 且 Redis 不可用 +- **THEN** 请求返回 429 diff --git a/openspec/changes/update-rate-limit-ttl-atomic/tasks.md b/openspec/changes/update-rate-limit-ttl-atomic/tasks.md new file mode 100644 index 00000000..c033ca95 --- /dev/null +++ b/openspec/changes/update-rate-limit-ttl-atomic/tasks.md @@ -0,0 +1,6 @@ +## 1. 
Implementation +- [x] 1.1 在限流中间件中引入 Lua 脚本原子化计数与过期设置(使用 PEXPIRE 毫秒窗口) +- [x] 1.2 脚本内检测 `TTL == -1` 时补设过期,修复历史脏 key +- [x] 1.3 引入 `RateLimitOptions` 与 `LimitWithOptions`,`Limit` 保持默认 fail-open +- [x] 1.4 为 `/auth/validate-promo-code` 配置 fail-close 策略 +- [x] 1.5 添加测试覆盖首次请求、已有 TTL、TTL 缺失、非整数毫秒窗口与故障策略(使用 Redis 集成测试/testcontainers 方案)