diff --git a/backend/internal/middleware/rate_limiter.go b/backend/internal/middleware/rate_limiter.go index 13b71683..819d74c2 100644 --- a/backend/internal/middleware/rate_limiter.go +++ b/backend/internal/middleware/rate_limiter.go @@ -2,7 +2,10 @@ package middleware import ( "context" + "fmt" + "log" "net/http" + "strconv" "time" "github.com/gin-gonic/gin" @@ -25,15 +28,34 @@ type RateLimitOptions struct { var rateLimitScript = redis.NewScript(` local current = redis.call('INCR', KEYS[1]) local ttl = redis.call('PTTL', KEYS[1]) -if current == 1 or ttl == -1 then +local repaired = 0 +if current == 1 then redis.call('PEXPIRE', KEYS[1], ARGV[1]) +elseif ttl == -1 then + redis.call('PEXPIRE', KEYS[1], ARGV[1]) + repaired = 1 end -return current +return {current, repaired} `) // rateLimitRun 允许测试覆写脚本执行逻辑 -var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { - return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64() +var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice() + if err != nil { + return 0, false, err + } + if len(values) < 2 { + return 0, false, fmt.Errorf("rate limit script returned %d values", len(values)) + } + count, err := parseInt64(values[0]) + if err != nil { + return 0, false, err + } + repaired, err := parseInt64(values[1]) + if err != nil { + return 0, false, err + } + return count, repaired == 1, nil } // RateLimiter Redis 速率限制器 @@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati windowMillis := windowTTLMillis(window) // 使用 Lua 脚本原子操作增加计数并设置过期 - count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis) + count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis) if err != nil { + log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err) if failureMode == RateLimitFailClose { abortRateLimit(c) return @@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati c.Next() return } + if repaired { + log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis) + } // 超过限制 if count > int64(limit) { @@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) { "message": "Too many requests, please try again later", }) } + +func failureModeLabel(mode RateLimitFailureMode) string { + if mode == RateLimitFailClose { + return "fail-close" + } + return "fail-open" +} + +func parseInt64(value any) (int64, error) { + switch v := value.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case string: + parsed, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, err + } + return parsed, nil + default: + return 0, fmt.Errorf("unexpected value type %T", value) + } +} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 7c72e5be..0c379c0f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) { originalRun := rateLimitRun counts := []int64{1, 2} callIndex := 0 - rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) { + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { if callIndex >= len(counts) { - return counts[len(counts)-1], nil + return counts[len(counts)-1], false, nil } value := counts[callIndex] callIndex++ - return value, nil + return value, false, nil } t.Cleanup(func() { rateLimitRun = originalRun