fix(rate_limiter): 更新速率限制逻辑,支持返回修复状态
This commit is contained in:
@@ -2,7 +2,10 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,15 +28,34 @@ type RateLimitOptions struct {
|
||||
var rateLimitScript = redis.NewScript(`
|
||||
local current = redis.call('INCR', KEYS[1])
|
||||
local ttl = redis.call('PTTL', KEYS[1])
|
||||
if current == 1 or ttl == -1 then
|
||||
local repaired = 0
|
||||
if current == 1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
elseif ttl == -1 then
|
||||
redis.call('PEXPIRE', KEYS[1], ARGV[1])
|
||||
repaired = 1
|
||||
end
|
||||
return current
|
||||
return {current, repaired}
|
||||
`)
|
||||
|
||||
// rateLimitRun 允许测试覆写脚本执行逻辑
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
return rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Int64()
|
||||
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if len(values) < 2 {
|
||||
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
|
||||
}
|
||||
count, err := parseInt64(values[0])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
repaired, err := parseInt64(values[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return count, repaired == 1, nil
|
||||
}
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
@@ -74,8 +96,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
||||
windowMillis := windowTTLMillis(window)
|
||||
|
||||
// 使用 Lua 脚本原子操作增加计数并设置过期
|
||||
count, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
|
||||
if err != nil {
|
||||
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
|
||||
if failureMode == RateLimitFailClose {
|
||||
abortRateLimit(c)
|
||||
return
|
||||
@@ -84,6 +107,9 @@ func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Durati
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if repaired {
|
||||
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
@@ -109,3 +135,27 @@ func abortRateLimit(c *gin.Context) {
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
}
|
||||
|
||||
func failureModeLabel(mode RateLimitFailureMode) string {
|
||||
if mode == RateLimitFailClose {
|
||||
return "fail-close"
|
||||
}
|
||||
return "fail-open"
|
||||
}
|
||||
|
||||
func parseInt64(value any) (int64, error) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case string:
|
||||
parsed, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return parsed, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unexpected value type %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,13 +66,13 @@ func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
originalRun := rateLimitRun
|
||||
counts := []int64{1, 2}
|
||||
callIndex := 0
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, error) {
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
if callIndex >= len(counts) {
|
||||
return counts[len(counts)-1], nil
|
||||
return counts[len(counts)-1], false, nil
|
||||
}
|
||||
value := counts[callIndex]
|
||||
callIndex++
|
||||
return value, nil
|
||||
return value, false, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
|
||||
Reference in New Issue
Block a user