From 6c3fb7777ec3fe4874b249251120e68b5e22642f Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 07:31:54 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=88=86=E7=BB=84?= =?UTF-8?q?=E9=80=9F=E7=8E=87=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 37 +++++++-- model/option.go | 47 +++++++++++- setting/rate_limit.go | 70 +++++++++++++++++ web/src/components/RateLimitSetting.js | 1 + .../RateLimit/SettingsRequestRateLimit.js | 76 +++++++++++++++++-- 5 files changed, 214 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..d4199ece 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) { return } - // 计算限流参数 + // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - totalMaxCount := setting.ModelRequestRateLimitCount - successMaxCount := setting.ModelRequestRateLimitSuccessCount - // 根据存储类型选择并执行限流处理器 - if common.RedisEnabled { - redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + // 获取用户组 + group := c.GetString("token_group") + if group == "" { + group = c.GetString("group") + } + if group == "" { + group = "default" // 默认组 + } + + // 尝试获取用户组特定的限制 + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) + + // 确定最终的限制值 + finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 + + if found { + // 如果找到用户组特定限制,则使用它们 + finalTotalCount = groupTotalCount + finalSuccessCount = groupSuccessCount + common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } else { - memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + } + + // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 + if common.RedisEnabled { + redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + } else { + memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } } } diff --git a/model/option.go b/model/option.go index d575742f..1f5fb3aa 100644 --- a/model/option.go +++ b/model/option.go @@ -1,6 +1,8 @@ package model import ( + "encoding/json" + "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -96,6 +98,7 @@ func InitOptionMap() { common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -150,7 +153,32 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - // Save to database first + originalValue := value // 保存原始值以备后用 + + // Validate and format specific keys before saving + if key == setting.ModelRequestRateLimitGroupKey { + var cfg map[string][2]int + // Validate the JSON structure first using the original value + err := json.Unmarshal([]byte(originalValue), &cfg) + if err != nil { + // 提供更具体的错误信息 + return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) + } + // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 + // if !isValidModelRequestRateLimitGroupConfig(cfg) { + // return fmt.Errorf("无效的配置值 for %s", key) + // } + + // If valid, format the JSON before saving + formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") + if marshalErr != nil { + // This should ideally not happen if validation passed, but handle defensively + return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) + } + value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + } + + // Save to database option := Option{ Key: key, } @@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error { // Save is a combination function. // If save value does not contain primary key, it will execute Create, // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap + if err := DB.Save(&option).Error; err != nil { + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + } + + // Update OptionMap in memory using the potentially formatted value + // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + case setting.ModelRequestRateLimitGroupKey: + // Use the (potentially formatted) value passed from UpdateOption + // to update the actual configuration in memory. + // This is the single point where the memory state for this specific setting is updated. + err = setting.UpdateModelRequestRateLimitGroupConfig(value) + if err != nil { + // 添加错误上下文 + err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) + } } return err } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..c83885a6 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,76 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 + +// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 +const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" + +// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 +// map[groupName][2]int{totalCount, successCount} +var ModelRequestRateLimitGroupConfig map[string][2]int +var ModelRequestRateLimitGroupMutex sync.RWMutex + +// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 +func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { + ModelRequestRateLimitGroupMutex.Lock() + defer ModelRequestRateLimitGroupMutex.Unlock() + + var newConfig map[string][2]int + if jsonStr == "" || jsonStr == "{}" { + // 如果配置为空或空JSON对象,则清空内存配置 + ModelRequestRateLimitGroupConfig = make(map[string][2]int) + common.SysLog("Model request rate limit group config cleared") + return nil + } + + err := json.Unmarshal([]byte(jsonStr), &newConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + } + + // 校验配置值 + for group, limits := range newConfig { + if len(limits) != 2 { + return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) + } + if limits[1] <= 0 { // successCount must be greater than 0 + return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) + } + if limits[0] < 0 { // totalCount can be 0 (no limit) or positive + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) + } + if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) + } + } + + ModelRequestRateLimitGroupConfig = newConfig + common.SysLog("Model request rate limit group config updated") + return nil +} + +// GetGroupRateLimit 安全地获取指定用户组的速率限制值 +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitGroupMutex.RLock() + defer ModelRequestRateLimitGroupMutex.RUnlock() + + if ModelRequestRateLimitGroupConfig == nil { + return 0, 0, false // 配置尚未初始化 + } + + limits, found := ModelRequestRateLimitGroupConfig[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index e06038d6..ad6b53da 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,6 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: {}, }); let [loading, setLoading] = useState(false); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..ec1c2158 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -18,6 +18,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -32,25 +33,49 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } + // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + // 阻止请求发送 + return Promise.reject('Invalid JSON format'); + } + } return API.put('/api/option/', { key: item.key, value, }); }); + + // 过滤掉无效的请求(例如,无效的 JSON) + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 + return; + } + setLoading(true); - Promise.all(requestQueue) + Promise.all(validRequests) .then((res) => { - if (requestQueue.length === 1) { + if (validRequests.length === 1) { if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { + } else if (validRequests.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } showSuccess(t('保存成功')); props.refresh(); + // 更新 inputsRow 以反映保存后的状态 + setInputsRow(structuredClone(inputs)); }) - .catch(() => { - showError(t('保存失败,请重试')); + .catch((error) => { + // 检查是否是由于无效 JSON 导致的错误 + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } }) .finally(() => { setLoading(false); @@ -66,8 +91,11 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); - }, [props.options]); + // 检查 refForm.current 是否存在 + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 return ( <> @@ -147,7 +175,41 @@ export default function RequestRateLimit(props) { /> + {/* 用户组速率限制配置项 */} + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, // 直接更新字符串值 + }); + }} + /> + +
+