From e385e347eae84e1c20ee1210a4f15ced2618f0cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9C=8D=E9=9B=A8=E4=BD=B3?= Date: Wed, 16 Apr 2025 16:36:07 +0800 Subject: [PATCH] Refactor: Optimize the token bucket algorithm, specifically the New method in common/imiterlimiter.go. Solution: Remove Redis ping. When printing exceptions, use SysLog to print and add additional logging information. --- common/limiter/limiter.go | 13 ++++--------- middleware/model-rate-limit.go | 4 ++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index 564e3fe4..ef5d1935 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -5,6 +5,7 @@ import ( _ "embed" "fmt" "github.com/go-redis/redis/v8" + "one-api/common" "sync" ) @@ -23,19 +24,13 @@ var ( func New(ctx context.Context, r *redis.Client) *RedisLimiter { once.Do(func() { - client := r - _, err := client.Ping(ctx).Result() - if err != nil { - panic(err) // 或者处理连接错误 - } // 预加载脚本 - limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result() + limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() if err != nil { - fmt.Println(err) + common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) } - instance = &RedisLimiter{ - client: client, + client: r, limitScriptSHA: limitSHA, } }) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 565e4c63..581dc451 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -91,9 +91,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)) return } - //检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 + + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) - //allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration) // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow(