Refactor: Optimize the request rate limiting for ModelRequestRateLimitCount.
Reason: The original steps 1 and 3 in the redisRateLimitHandler method were not atomic, leading to poor precision under high concurrent requests. For example, with a rate limit set to 60, sending 200 concurrent requests would result in none being blocked, whereas theoretically around 140 should be intercepted. Solution: I chose not to merge steps 1 and 3 into a single Lua script because a single atomic operation involving read, write, and delete operations could suffer from performance issues under high concurrency. Instead, I implemented a token bucket algorithm to optimize this, reducing the atomic operation to just read and write steps while significantly decreasing the memory footprint.
This commit is contained in:
94
common/limiter/limiter.go
Normal file
94
common/limiter/limiter.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package limiter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed lua/rate_limit.lua
|
||||||
|
var rateLimitScript string
|
||||||
|
|
||||||
|
type RedisLimiter struct {
|
||||||
|
client *redis.Client
|
||||||
|
limitScriptSHA string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
instance *RedisLimiter
|
||||||
|
once sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
|
||||||
|
once.Do(func() {
|
||||||
|
client := r
|
||||||
|
_, err := client.Ping(ctx).Result()
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // 或者处理连接错误
|
||||||
|
}
|
||||||
|
// 预加载脚本
|
||||||
|
limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
instance = &RedisLimiter{
|
||||||
|
client: client,
|
||||||
|
limitScriptSHA: limitSHA,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
|
||||||
|
// 默认配置
|
||||||
|
config := &Config{
|
||||||
|
Capacity: 10,
|
||||||
|
Rate: 1,
|
||||||
|
Requested: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用选项模式
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行限流
|
||||||
|
result, err := rl.client.EvalSha(
|
||||||
|
ctx,
|
||||||
|
rl.limitScriptSHA,
|
||||||
|
[]string{key},
|
||||||
|
config.Requested,
|
||||||
|
config.Rate,
|
||||||
|
config.Capacity,
|
||||||
|
).Int()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("rate limit failed: %w", err)
|
||||||
|
}
|
||||||
|
return result == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config 配置选项模式
|
||||||
|
type Config struct {
|
||||||
|
Capacity int64
|
||||||
|
Rate int64
|
||||||
|
Requested int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*Config)
|
||||||
|
|
||||||
|
func WithCapacity(c int64) Option {
|
||||||
|
return func(cfg *Config) { cfg.Capacity = c }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRate(r int64) Option {
|
||||||
|
return func(cfg *Config) { cfg.Rate = r }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRequested(n int64) Option {
|
||||||
|
return func(cfg *Config) { cfg.Requested = n }
|
||||||
|
}
|
||||||
44
common/limiter/lua/rate_limit.lua
Normal file
44
common/limiter/lua/rate_limit.lua
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
-- 令牌桶限流器
|
||||||
|
-- KEYS[1]: 限流器唯一标识
|
||||||
|
-- ARGV[1]: 请求令牌数 (通常为1)
|
||||||
|
-- ARGV[2]: 令牌生成速率 (每秒)
|
||||||
|
-- ARGV[3]: 桶容量
|
||||||
|
|
||||||
|
local key = KEYS[1]
|
||||||
|
local requested = tonumber(ARGV[1])
|
||||||
|
local rate = tonumber(ARGV[2])
|
||||||
|
local capacity = tonumber(ARGV[3])
|
||||||
|
|
||||||
|
-- 获取当前时间(Redis服务器时间)
|
||||||
|
local now = redis.call('TIME')
|
||||||
|
local nowInSeconds = tonumber(now[1])
|
||||||
|
|
||||||
|
-- 获取桶状态
|
||||||
|
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
|
||||||
|
local tokens = tonumber(bucket[1])
|
||||||
|
local last_time = tonumber(bucket[2])
|
||||||
|
|
||||||
|
-- 初始化桶(首次请求或过期)
|
||||||
|
if not tokens or not last_time then
|
||||||
|
tokens = capacity
|
||||||
|
last_time = nowInSeconds
|
||||||
|
else
|
||||||
|
-- 计算新增令牌
|
||||||
|
local elapsed = nowInSeconds - last_time
|
||||||
|
local add_tokens = elapsed * rate
|
||||||
|
tokens = math.min(capacity, tokens + add_tokens)
|
||||||
|
last_time = nowInSeconds
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 判断是否允许请求
|
||||||
|
local allowed = false
|
||||||
|
if tokens >= requested then
|
||||||
|
tokens = tokens - requested
|
||||||
|
allowed = true
|
||||||
|
end
|
||||||
|
|
||||||
|
---- 更新桶状态并设置过期时间
|
||||||
|
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
|
||||||
|
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
|
||||||
|
|
||||||
|
return allowed and 1 or 0
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/limiter"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
rdb := common.RDB
|
rdb := common.RDB
|
||||||
|
|
||||||
// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
|
// 1. 检查成功请求数限制
|
||||||
totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
|
|
||||||
allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("检查总请求数限制失败:", err.Error())
|
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 检查成功请求数限制
|
|
||||||
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
||||||
allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("检查成功请求数限制失败:", err.Error())
|
fmt.Println("检查成功请求数限制失败:", err.Error())
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
||||||
@@ -102,9 +91,28 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
|||||||
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
|
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
//检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
|
||||||
|
totalKey := fmt.Sprintf("rateLimit:%s", userId)
|
||||||
|
//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
|
||||||
|
// 初始化
|
||||||
|
tb := limiter.New(ctx, rdb)
|
||||||
|
allowed, err = tb.Allow(
|
||||||
|
ctx,
|
||||||
|
totalKey,
|
||||||
|
limiter.WithCapacity(int64(totalMaxCount)*duration),
|
||||||
|
limiter.WithRate(int64(totalMaxCount)),
|
||||||
|
limiter.WithRequested(duration),
|
||||||
|
)
|
||||||
|
|
||||||
// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
|
if err != nil {
|
||||||
recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
|
fmt.Println("检查总请求数限制失败:", err.Error())
|
||||||
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
||||||
|
}
|
||||||
|
|
||||||
// 4. 处理请求
|
// 4. 处理请求
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
Reference in New Issue
Block a user