diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index 581dc451..d4199ece 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) {
return
}
- // 计算限流参数
+ // 计算通用限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
- totalMaxCount := setting.ModelRequestRateLimitCount
- successMaxCount := setting.ModelRequestRateLimitSuccessCount
- // 根据存储类型选择并执行限流处理器
- if common.RedisEnabled {
- redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+ // 获取用户组
+ group := c.GetString("token_group")
+ if group == "" {
+ group = c.GetString("group")
+ }
+ if group == "" {
+ group = "default" // 默认组
+ }
+
+ // 尝试获取用户组特定的限制
+ groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+
+ // 确定最终的限制值
+ finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
+ finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
+
+ if found {
+ // 如果找到用户组特定限制,则使用它们
+ finalTotalCount = groupTotalCount
+ finalSuccessCount = groupSuccessCount
+ common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
} else {
- memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+ common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
+ }
+
+ // 根据存储类型选择并执行限流处理器,传入最终确定的限制值
+ if common.RedisEnabled {
+ redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
+ } else {
+ memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
}
}
}
diff --git a/model/option.go b/model/option.go
index d575742f..1f5fb3aa 100644
--- a/model/option.go
+++ b/model/option.go
@@ -1,6 +1,8 @@
package model
import (
+ "encoding/json"
+ "fmt"
"one-api/common"
"one-api/setting"
"one-api/setting/config"
@@ -96,6 +98,7 @@ func InitOptionMap() {
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+ common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -150,7 +153,32 @@ func SyncOptions(frequency int) {
}
func UpdateOption(key string, value string) error {
- // Save to database first
+ originalValue := value // 保存原始值以备后用
+
+ // Validate and format specific keys before saving
+ if key == setting.ModelRequestRateLimitGroupKey {
+ var cfg map[string][2]int
+ // Validate the JSON structure first using the original value
+ err := json.Unmarshal([]byte(originalValue), &cfg)
+ if err != nil {
+ // 提供更具体的错误信息
+ return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
+ }
+ // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。
+ // if !isValidModelRequestRateLimitGroupConfig(cfg) {
+ // return fmt.Errorf("无效的配置值 for %s", key)
+ // }
+
+ // If valid, format the JSON before saving
+ formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
+ if marshalErr != nil {
+ // This should ideally not happen if validation passed, but handle defensively
+ return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
+ }
+ value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
+ }
+
+ // Save to database
option := Option{
Key: key,
}
@@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error {
// Save is a combination function.
// If save value does not contain primary key, it will execute Create,
// otherwise it will execute Update (with all fields).
- DB.Save(&option)
- // Update OptionMap
+ if err := DB.Save(&option).Error; err != nil {
+ return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
+ }
+
+ // Update OptionMap in memory using the potentially formatted value
+ // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
return updateOptionMap(key, value)
}
@@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+ case setting.ModelRequestRateLimitGroupKey:
+ // Use the (potentially formatted) value passed from UpdateOption
+ // to update the actual configuration in memory.
+ // This is the single point where the memory state for this specific setting is updated.
+ err = setting.UpdateModelRequestRateLimitGroupConfig(value)
+ if err != nil {
+ // 添加错误上下文
+ err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
+ }
}
return err
}
diff --git a/setting/rate_limit.go b/setting/rate_limit.go
index 4b216948..c83885a6 100644
--- a/setting/rate_limit.go
+++ b/setting/rate_limit.go
@@ -1,6 +1,76 @@
package setting
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/common"
+ "sync"
+)
+
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+
+// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
+const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
+
+// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
+// map[groupName][2]int{totalCount, successCount}
+var ModelRequestRateLimitGroupConfig map[string][2]int
+var ModelRequestRateLimitGroupMutex sync.RWMutex
+
+// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
+func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
+ ModelRequestRateLimitGroupMutex.Lock()
+ defer ModelRequestRateLimitGroupMutex.Unlock()
+
+ var newConfig map[string][2]int
+ if jsonStr == "" || jsonStr == "{}" {
+ // 如果配置为空或空JSON对象,则清空内存配置
+ ModelRequestRateLimitGroupConfig = make(map[string][2]int)
+ common.SysLog("Model request rate limit group config cleared")
+ return nil
+ }
+
+ err := json.Unmarshal([]byte(jsonStr), &newConfig)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
+ }
+
+ // 校验配置值
+ for group, limits := range newConfig {
+ if len(limits) != 2 {
+ return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
+ }
+ if limits[1] <= 0 { // successCount must be greater than 0
+ return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
+ }
+ if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
+ return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
+ }
+ if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
+ return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
+ }
+ }
+
+ ModelRequestRateLimitGroupConfig = newConfig
+ common.SysLog("Model request rate limit group config updated")
+ return nil
+}
+
+// GetGroupRateLimit 安全地获取指定用户组的速率限制值
+func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
+ ModelRequestRateLimitGroupMutex.RLock()
+ defer ModelRequestRateLimitGroupMutex.RUnlock()
+
+ if ModelRequestRateLimitGroupConfig == nil {
+ return 0, 0, false // 配置尚未初始化
+ }
+
+ limits, found := ModelRequestRateLimitGroupConfig[group]
+ if !found {
+ return 0, 0, false
+ }
+ return limits[0], limits[1], true
+}
diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js
index e06038d6..ad6b53da 100644
--- a/web/src/components/RateLimitSetting.js
+++ b/web/src/components/RateLimitSetting.js
@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: {},
});
let [loading, setLoading] = useState(false);
diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
index 800e9636..ec1c2158 100644
--- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
+++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
@@ -18,6 +18,7 @@ export default function RequestRateLimit(props) {
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -32,25 +33,49 @@ export default function RequestRateLimit(props) {
} else {
value = inputs[item.key];
}
+ // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
+ if (item.key === 'ModelRequestRateLimitGroup') {
+ try {
+ JSON.parse(value);
+ } catch (e) {
+ showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
+ // 阻止请求发送
+ return Promise.reject('Invalid JSON format');
+ }
+ }
return API.put('/api/option/', {
key: item.key,
value,
});
});
+
+ // 过滤掉无效的请求(例如,无效的 JSON)
+ const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
+
+ if (validRequests.length === 0 && requestQueue.length > 0) {
+ // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
+ return;
+ }
+
setLoading(true);
- Promise.all(requestQueue)
+ Promise.all(validRequests)
.then((res) => {
- if (requestQueue.length === 1) {
+ if (validRequests.length === 1) {
if (res.includes(undefined)) return;
- } else if (requestQueue.length > 1) {
+ } else if (validRequests.length > 1) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
+ // 更新 inputsRow 以反映保存后的状态
+ setInputsRow(structuredClone(inputs));
})
- .catch(() => {
- showError(t('保存失败,请重试'));
+ .catch((error) => {
+ // 检查是否是由于无效 JSON 导致的错误
+ if (error !== 'Invalid JSON format') {
+ showError(t('保存失败,请重试'));
+ }
})
.finally(() => {
setLoading(false);
@@ -66,8 +91,11 @@ export default function RequestRateLimit(props) {
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
- refForm.current.setValues(currentInputs);
- }, [props.options]);
+ // 检查 refForm.current 是否存在
+ if (refForm.current) {
+ refForm.current.setValues(currentInputs);
+ }
+ }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
return (
<>
@@ -147,7 +175,41 @@ export default function RequestRateLimit(props) {
/>
+ {/* 用户组速率限制配置项 */}
+
+
+ {t('说明:')}
+
+ - {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
+ - {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
+ - {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
+ - {t('此配置将优先于上方的全局限制设置。')}
+ - {t('未在此处配置的用户组将使用全局限制。')}
+ - {t('限制周期统一使用上方配置的“限制周期”值。')}
+ - {t('输入无效的 JSON 将无法保存。')}
+
+
+ }
+ autosize={{ minRows: 5, maxRows: 15 }}
+ style={{ fontFamily: 'monospace' }}
+ onChange={(value) => {
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitGroup: value, // 直接更新字符串值
+ });
+ }}
+ />
+
+
+