Refactor: Optimize the request rate limiting for ModelRequestRateLimitCount.

Reason: The original steps 1 and 3 in the redisRateLimitHandler method were not atomic, leading to poor precision under high concurrent requests. For example, with a rate limit set to 60, sending 200 concurrent requests would result in none being blocked, whereas theoretically around 140 should be intercepted.
Solution: I chose not to merge steps 1 and 3 into a single Lua script because a single atomic operation involving read, write, and delete operations could suffer from performance issues under high concurrency. Instead, I implemented a token bucket algorithm to optimize this, reducing the atomic operation to just read and write steps while significantly decreasing the memory footprint.
This commit is contained in:
霍雨佳
2025-04-16 10:33:43 +08:00
parent 214ca4db56
commit eb75ff232f
3 changed files with 162 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/limiter"
"one-api/setting"
"strconv"
"time"
@@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
ctx := context.Background()
rdb := common.RDB
// 1. 检查请求数限制当totalMaxCount为0时会自动跳过
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 2. 检查成功请求数限制
// 1. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil {
fmt.Println("检查成功请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
@@ -102,9 +91,28 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
return
}
//检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过使用令牌桶限流器
totalKey := fmt.Sprintf("rateLimit:%s", userId)
//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
// 初始化
tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow(
ctx,
totalKey,
limiter.WithCapacity(int64(totalMaxCount)*duration),
limiter.WithRate(int64(totalMaxCount)),
limiter.WithRequested(duration),
)
// 3. 记录总请求当totalMaxCount为0时会自动跳过
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return
}
if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
// 4. 处理请求
c.Next()