From 1513ed78477044999e066d5eb3b1fc1762dce531 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 19:32:22 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E8=B0=83=E6=95=B4=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E7=AC=A6=E5=90=88=E9=A1=B9=E7=9B=AE=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 54 +++++++++++++++++--------- model/option.go | 34 +++++----------- setting/rate_limit.go | 37 ++++++++---------- web/src/components/RateLimitSetting.js | 6 ++- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index b0047b70..1ca5ace6 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/limiter" + "one-api/constant" "one-api/setting" "strconv" "time" @@ -19,20 +20,25 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) +// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { + // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } + // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } + // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } + // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } + // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } +// 记录Redis请求 func 
recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { + // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } +// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB + // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } + // 4. 处理请求 c.Next() + // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } +// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId + // 1. 
检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } + // 2. 检查成功请求数限制 + // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } + // 3. 处理请求 c.Next() + // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } +// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { + // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } + // 计算限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 获取分组 group := c.GetString("token_group") if group == "" { - group = c.GetString("group") - } - if group == "" { - group = "default" + group = c.GetString(constant.ContextKeyUserGroup) } - finalTotalCount := setting.ModelRequestRateLimitCount - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount - foundGroupLimit := false - + //获取分组的限流配置 groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) if found { - finalTotalCount = groupTotalCount - finalSuccessCount = groupSuccessCount - foundGroupLimit = true - common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } - - if !foundGroupLimit { - common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + totalMaxCount = groupTotalCount + successMaxCount = 
groupSuccessCount } + // 根据存储类型选择并执行限流处理器 if common.RedisEnabled { - redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } else { - memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 79556737..e9c129e1 100644 --- a/model/option.go +++ b/model/option.go @@ -1,8 +1,6 @@ package model import ( - "encoding/json" - "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -94,8 +92,7 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - jsonBytes, _ := json.Marshal(map[string][2]int{}) - common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) + common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() @@ -154,31 +151,18 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value - - if key == "ModelRequestRateLimitGroup" { - var cfg map[string][2]int - err := json.Unmarshal([]byte(originalValue), &cfg) - if err != nil { - return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) - } - - formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") - if marshalErr != nil { - return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) - } - value = 
string(formattedValueBytes) - } - + // Save to database first option := Option{ Key: key, } + // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) - } - + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap return updateOptionMap(key, value) } @@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) { case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "ModelRequestRateLimitGroup": - err = setting.UpdateModelRequestRateLimitGroup(value) + err = setting.UpdateModelRequestRateLimitGroupByJSONString(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 5be75cc1..aab030cd 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,7 +2,6 @@ package setting import ( "encoding/json" - "fmt" "one-api/common" "sync" ) @@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 -var ModelRequestRateLimitGroup map[string][2]int +var ModelRequestRateLimitGroup = map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex -var ModelRequestRateLimitGroupMutex sync.RWMutex +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() -func UpdateModelRequestRateLimitGroup(jsonStr string) error { - ModelRequestRateLimitGroupMutex.Lock() - defer ModelRequestRateLimitGroupMutex.Unlock() - - var newConfig map[string][2]int - if jsonStr == "" || jsonStr == "{}" { - 
ModelRequestRateLimitGroup = make(map[string][2]int) - common.SysLog("Model request rate limit group config cleared") - return nil - } - - err := json.Unmarshal([]byte(jsonStr), &newConfig) + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + common.SysError("error marshalling model request rate limit group: " + err.Error()) } + return string(jsonBytes) +} - ModelRequestRateLimitGroup = newConfig - return nil +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + ModelRequestRateLimitMutex.Lock() + defer ModelRequestRateLimitMutex.Unlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) } func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { - ModelRequestRateLimitGroupMutex.RLock() - defer ModelRequestRateLimitGroupMutex.RUnlock() + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() if ModelRequestRateLimitGroup == nil { return 0, 0, false diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 7e206672..309b94de 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -24,7 +24,6 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - // 检查 key 是否在初始 inputs 中定义 if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { if (item.key.endsWith('Enabled')) { newInputs[item.key] = item.value === 'true'; @@ -33,6 +32,7 @@ const RateLimitSetting = () => { } } }); + setInputs(newInputs); } else { showError(message); @@ -42,6 +42,7 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { @@ -56,6 +57,7 @@ const RateLimitSetting = () => { return ( <> + {/* AI请求速率限制 */} @@ -64,4 +66,4 @@ const 
RateLimitSetting = () => { ); }; - export default RateLimitSetting; + export default RateLimitSetting; \ No newline at end of file