From eb75ff232ff9a3222f497bd3d6ae12789d3a4d43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9C=8D=E9=9B=A8=E4=BD=B3?=
Date: Wed, 16 Apr 2025 10:33:43 +0800
Subject: [PATCH] Refactor: Optimize the request rate limiting for
 ModelRequestRateLimitCount.

Reason: Steps 1 and 3 of the original redisRateLimitHandler were not atomic,
so the limit was enforced inaccurately under highly concurrent requests. For
example, with the limit set to 60, sending 200 concurrent requests could let
all of them through, whereas roughly 140 of them should have been rejected.

Solution: Steps 1 and 3 were deliberately not merged into one Lua script,
because a single atomic operation that reads, writes, and deletes could
itself become a bottleneck under high concurrency. Instead, the total-request
check is reimplemented as a token bucket, which reduces the atomic section to
a single read and write and significantly decreases the memory footprint.
---
 common/limiter/limiter.go         | 94 +++++++++++++++++++++++++++++++
 common/limiter/lua/rate_limit.lua | 44 +++++++++++++++
 middleware/model-rate-limit.go    | 40 +++++++------
 3 files changed, 162 insertions(+), 16 deletions(-)
 create mode 100644 common/limiter/limiter.go
 create mode 100644 common/limiter/lua/rate_limit.lua

diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go
new file mode 100644
index 00000000..564e3fe4
--- /dev/null
+++ b/common/limiter/limiter.go
@@ -0,0 +1,94 @@
+package limiter
+
+import (
+    "context"
+    _ "embed"
+    "fmt"
+    "sync"
+
+    "github.com/go-redis/redis/v8"
+)
+
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+type RedisLimiter struct {
+    client         *redis.Client
+    limitScriptSHA string
+}
+
+var (
+    instance *RedisLimiter
+    once     sync.Once
+)
+
+// New returns the process-wide RedisLimiter, initializing it on first use.
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+    once.Do(func() {
+        client := r
+        _, err := client.Ping(ctx).Result()
+        if err != nil {
+            panic(err) // or handle the connection error
+        }
+        // Preload the script so Allow can run it via EVALSHA.
+        limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
+        if err != nil {
+            panic(err) // fail fast: without the script SHA the limiter cannot work
+        }
+
+        instance = &RedisLimiter{
+            client:         client,
+            limitScriptSHA: limitSHA,
+        }
+    })
+
+    return instance
+}
+
+// Allow reports whether the request identified by key may proceed,
+// consuming Requested tokens from the bucket on success.
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+    // Default configuration.
+    config := &Config{
+        Capacity:  10,
+        Rate:      1,
+        Requested: 1,
+    }
+
+    // Apply the functional options.
+    for _, opt := range opts {
+        opt(config)
+    }
+
+    // Run the token bucket script.
+    result, err := rl.client.EvalSha(
+        ctx,
+        rl.limitScriptSHA,
+        []string{key},
+        config.Requested,
+        config.Rate,
+        config.Capacity,
+    ).Int()
+
+    if err != nil {
+        return false, fmt.Errorf("rate limit failed: %w", err)
+    }
+    return result == 1, nil
+}
+
+// Config carries the limiter settings applied through functional options.
+type Config struct {
+    Capacity  int64
+    Rate      int64
+    Requested int64
+}
+
+type Option func(*Config)
+
+func WithCapacity(c int64) Option {
+    return func(cfg *Config) { cfg.Capacity = c }
+}
+
+func WithRate(r int64) Option {
+    return func(cfg *Config) { cfg.Rate = r }
+}
+
+func WithRequested(n int64) Option {
+    return func(cfg *Config) { cfg.Requested = n }
+}
diff --git a/common/limiter/lua/rate_limit.lua b/common/limiter/lua/rate_limit.lua
new file mode 100644
index 00000000..c07fd3a8
--- /dev/null
+++ b/common/limiter/lua/rate_limit.lua
@@ -0,0 +1,44 @@
+-- Token bucket rate limiter
+-- KEYS[1]: unique identifier of this limiter
+-- ARGV[1]: number of tokens requested (usually 1)
+-- ARGV[2]: token refill rate (tokens per second)
+-- ARGV[3]: bucket capacity
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- Current time, taken from the Redis server.
+local now = redis.call('TIME')
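+-- Note: only the seconds component of TIME is used below, so tokens are
+-- refilled in whole-second steps of `rate`. As an illustration of how the
+-- middleware sizes the bucket (capacity = totalMaxCount*duration,
+-- rate = totalMaxCount, requested = duration, assuming duration is in
+-- seconds): a limit of 60 requests per 60 s gives capacity 3600, rate 60
+-- and a cost of 60 tokens per request, so a full bucket admits 60 requests
+-- and then roughly one more request per second.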
+local nowInSeconds = tonumber(now[1])
+
+-- Fetch the bucket state.
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- Initialize the bucket (first request, or the key has expired).
+if not tokens or not last_time then
+    tokens = capacity
+    last_time = nowInSeconds
+else
+    -- Refill tokens for the time elapsed since the last update.
+    local elapsed = nowInSeconds - last_time
+    local add_tokens = elapsed * rate
+    tokens = math.min(capacity, tokens + add_tokens)
+    last_time = nowInSeconds
+end
+
+-- Decide whether the request is allowed.
+local allowed = false
+if tokens >= requested then
+    tokens = tokens - requested
+    allowed = true
+end
+
+-- Persist the bucket state and set an expiry time.
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- keep the key a little longer than one full refill
+
+return allowed and 1 or 0
\ No newline at end of file
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index bd5f9d25..565e4c63 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/common/limiter"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 		ctx := context.Background()
 		rdb := common.RDB
 
-		// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
-		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
-		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
-		if err != nil {
-			fmt.Println("检查总请求数限制失败:", err.Error())
-			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
-			return
-		}
-		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
-		}
-
-		// 2. 检查成功请求数限制
+		// 1. 检查成功请求数限制
 		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
-		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
 			fmt.Println("检查成功请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
@@ -102,9 +91,28 @@
 			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
 			return
 		}
+		// Check the total-request limit and record this request in one step
+		// using the token bucket limiter (skipped when totalMaxCount is 0,
+		// i.e. no total limit is configured).
+		if totalMaxCount > 0 {
+			totalKey := fmt.Sprintf("rateLimit:%s", userId)
+			tb := limiter.New(ctx, rdb)
+			// Everything is scaled by duration so the integer-only bucket can
+			// express "totalMaxCount requests per duration seconds": each
+			// request costs duration tokens, refilled at totalMaxCount per second.
+			allowed, err = tb.Allow(
+				ctx,
+				totalKey,
+				limiter.WithCapacity(int64(totalMaxCount)*duration),
+				limiter.WithRate(int64(totalMaxCount)),
+				limiter.WithRequested(duration),
+			)
-
-		// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
-		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
+			if err != nil {
+				fmt.Println("检查总请求数限制失败:", err.Error())
+				abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+				return
+			}
+			if !allowed {
+				abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+				return
+			}
+		}
 
 		// 4. 处理请求
 		c.Next()
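
Note: Allow runs the script via EVALSHA with a SHA cached once at startup, so
a Redis restart or SCRIPT FLUSH would leave the limiter returning NOSCRIPT
errors until the service is restarted. If that matters, go-redis's
redis.NewScript helper can run the same script and fall back from EVALSHA to
EVAL automatically; a minimal sketch, where the allowWithScript name is only
illustrative and not part of this patch:

    // Script keeps the source and its SHA together; Run uses EVALSHA and
    // transparently retries with EVAL if the script cache was flushed.
    var rateLimitLua = redis.NewScript(rateLimitScript)

    func (rl *RedisLimiter) allowWithScript(ctx context.Context, key string, cfg *Config) (bool, error) {
        result, err := rateLimitLua.Run(ctx, rl.client,
            []string{key}, cfg.Requested, cfg.Rate, cfg.Capacity).Int()
        if err != nil {
            return false, fmt.Errorf("rate limit failed: %w", err)
        }
        return result == 1, nil
    }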