feat: 增加分组速率功能
This commit is contained in:
@@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 计算限流参数
|
||||
// 计算通用限流参数
|
||||
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 根据存储类型选择并执行限流处理器
|
||||
if common.RedisEnabled {
|
||||
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
// 获取用户组
|
||||
group := c.GetString("token_group")
|
||||
if group == "" {
|
||||
group = c.GetString("group")
|
||||
}
|
||||
if group == "" {
|
||||
group = "default" // 默认组
|
||||
}
|
||||
|
||||
// 尝试获取用户组特定的限制
|
||||
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
|
||||
|
||||
// 确定最终的限制值
|
||||
finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
|
||||
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
|
||||
|
||||
if found {
|
||||
// 如果找到用户组特定限制,则使用它们
|
||||
finalTotalCount = groupTotalCount
|
||||
finalSuccessCount = groupSuccessCount
|
||||
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||
} else {
|
||||
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||
}
|
||||
|
||||
// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
|
||||
if common.RedisEnabled {
|
||||
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||
} else {
|
||||
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
@@ -96,6 +98,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
@@ -150,7 +153,32 @@ func SyncOptions(frequency int) {
|
||||
}
|
||||
|
||||
func UpdateOption(key string, value string) error {
|
||||
// Save to database first
|
||||
originalValue := value // 保存原始值以备后用
|
||||
|
||||
// Validate and format specific keys before saving
|
||||
if key == setting.ModelRequestRateLimitGroupKey {
|
||||
var cfg map[string][2]int
|
||||
// Validate the JSON structure first using the original value
|
||||
err := json.Unmarshal([]byte(originalValue), &cfg)
|
||||
if err != nil {
|
||||
// 提供更具体的错误信息
|
||||
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
|
||||
}
|
||||
// TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。
|
||||
// if !isValidModelRequestRateLimitGroupConfig(cfg) {
|
||||
// return fmt.Errorf("无效的配置值 for %s", key)
|
||||
// }
|
||||
|
||||
// If valid, format the JSON before saving
|
||||
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
|
||||
if marshalErr != nil {
|
||||
// This should ideally not happen if validation passed, but handle defensively
|
||||
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
|
||||
}
|
||||
value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
|
||||
}
|
||||
|
||||
// Save to database
|
||||
option := Option{
|
||||
Key: key,
|
||||
}
|
||||
@@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error {
|
||||
// Save is a combination function.
|
||||
// If save value does not contain primary key, it will execute Create,
|
||||
// otherwise it will execute Update (with all fields).
|
||||
DB.Save(&option)
|
||||
// Update OptionMap
|
||||
if err := DB.Save(&option).Error; err != nil {
|
||||
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
|
||||
}
|
||||
|
||||
// Update OptionMap in memory using the potentially formatted value
|
||||
// updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
|
||||
return updateOptionMap(key, value)
|
||||
}
|
||||
|
||||
@@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case setting.ModelRequestRateLimitGroupKey:
|
||||
// Use the (potentially formatted) value passed from UpdateOption
|
||||
// to update the actual configuration in memory.
|
||||
// This is the single point where the memory state for this specific setting is updated.
|
||||
err = setting.UpdateModelRequestRateLimitGroupConfig(value)
|
||||
if err != nil {
|
||||
// 添加错误上下文
|
||||
err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,76 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ModelRequestRateLimitEnabled = false
|
||||
var ModelRequestRateLimitDurationMinutes = 1
|
||||
var ModelRequestRateLimitCount = 0
|
||||
var ModelRequestRateLimitSuccessCount = 1000
|
||||
|
||||
// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
|
||||
const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
|
||||
|
||||
// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
|
||||
// map[groupName][2]int{totalCount, successCount}
|
||||
var ModelRequestRateLimitGroupConfig map[string][2]int
|
||||
var ModelRequestRateLimitGroupMutex sync.RWMutex
|
||||
|
||||
// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
|
||||
func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
|
||||
ModelRequestRateLimitGroupMutex.Lock()
|
||||
defer ModelRequestRateLimitGroupMutex.Unlock()
|
||||
|
||||
var newConfig map[string][2]int
|
||||
if jsonStr == "" || jsonStr == "{}" {
|
||||
// 如果配置为空或空JSON对象,则清空内存配置
|
||||
ModelRequestRateLimitGroupConfig = make(map[string][2]int)
|
||||
common.SysLog("Model request rate limit group config cleared")
|
||||
return nil
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(jsonStr), &newConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
|
||||
}
|
||||
|
||||
// 校验配置值
|
||||
for group, limits := range newConfig {
|
||||
if len(limits) != 2 {
|
||||
return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
|
||||
}
|
||||
if limits[1] <= 0 { // successCount must be greater than 0
|
||||
return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
|
||||
}
|
||||
if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
|
||||
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
|
||||
}
|
||||
if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
|
||||
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
|
||||
}
|
||||
}
|
||||
|
||||
ModelRequestRateLimitGroupConfig = newConfig
|
||||
common.SysLog("Model request rate limit group config updated")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGroupRateLimit 安全地获取指定用户组的速率限制值
|
||||
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
|
||||
ModelRequestRateLimitGroupMutex.RLock()
|
||||
defer ModelRequestRateLimitGroupMutex.RUnlock()
|
||||
|
||||
if ModelRequestRateLimitGroupConfig == nil {
|
||||
return 0, 0, false // 配置尚未初始化
|
||||
}
|
||||
|
||||
limits, found := ModelRequestRateLimitGroupConfig[group]
|
||||
if !found {
|
||||
return 0, 0, false
|
||||
}
|
||||
return limits[0], limits[1], true
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
|
||||
ModelRequestRateLimitCount: 0,
|
||||
ModelRequestRateLimitSuccessCount: 1000,
|
||||
ModelRequestRateLimitDurationMinutes: 1,
|
||||
ModelRequestRateLimitGroup: {},
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
@@ -18,6 +18,7 @@ export default function RequestRateLimit(props) {
|
||||
ModelRequestRateLimitCount: -1,
|
||||
ModelRequestRateLimitSuccessCount: 1000,
|
||||
ModelRequestRateLimitDurationMinutes: 1,
|
||||
ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
@@ -32,25 +33,49 @@ export default function RequestRateLimit(props) {
|
||||
} else {
|
||||
value = inputs[item.key];
|
||||
}
|
||||
// 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
|
||||
if (item.key === 'ModelRequestRateLimitGroup') {
|
||||
try {
|
||||
JSON.parse(value);
|
||||
} catch (e) {
|
||||
showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
|
||||
// 阻止请求发送
|
||||
return Promise.reject('Invalid JSON format');
|
||||
}
|
||||
}
|
||||
return API.put('/api/option/', {
|
||||
key: item.key,
|
||||
value,
|
||||
});
|
||||
});
|
||||
|
||||
// 过滤掉无效的请求(例如,无效的 JSON)
|
||||
const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
|
||||
|
||||
if (validRequests.length === 0 && requestQueue.length > 0) {
|
||||
// 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
Promise.all(requestQueue)
|
||||
Promise.all(validRequests)
|
||||
.then((res) => {
|
||||
if (requestQueue.length === 1) {
|
||||
if (validRequests.length === 1) {
|
||||
if (res.includes(undefined)) return;
|
||||
} else if (requestQueue.length > 1) {
|
||||
} else if (validRequests.length > 1) {
|
||||
if (res.includes(undefined))
|
||||
return showError(t('部分保存失败,请重试'));
|
||||
}
|
||||
showSuccess(t('保存成功'));
|
||||
props.refresh();
|
||||
// 更新 inputsRow 以反映保存后的状态
|
||||
setInputsRow(structuredClone(inputs));
|
||||
})
|
||||
.catch(() => {
|
||||
showError(t('保存失败,请重试'));
|
||||
.catch((error) => {
|
||||
// 检查是否是由于无效 JSON 导致的错误
|
||||
if (error !== 'Invalid JSON format') {
|
||||
showError(t('保存失败,请重试'));
|
||||
}
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
@@ -66,8 +91,11 @@ export default function RequestRateLimit(props) {
|
||||
}
|
||||
setInputs(currentInputs);
|
||||
setInputsRow(structuredClone(currentInputs));
|
||||
refForm.current.setValues(currentInputs);
|
||||
}, [props.options]);
|
||||
// 检查 refForm.current 是否存在
|
||||
if (refForm.current) {
|
||||
refForm.current.setValues(currentInputs);
|
||||
}
|
||||
}, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -147,7 +175,41 @@ export default function RequestRateLimit(props) {
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
{/* 用户组速率限制配置项 */}
|
||||
<Row>
|
||||
<Col span={24}>
|
||||
<Form.TextArea
|
||||
label={t('用户组速率限制 (JSON)')}
|
||||
field={'ModelRequestRateLimitGroup'}
|
||||
placeholder={t( // 更新 placeholder
|
||||
'请输入 JSON 格式的用户组限制,例如:\n{\n "default": [200, 100],\n "vip": [1000, 500]\n}',
|
||||
)}
|
||||
extraText={ // 更新 extraText
|
||||
<div>
|
||||
<p>{t('说明:')}</p>
|
||||
<ul>
|
||||
<li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
|
||||
<li>{t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}</li>
|
||||
<li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
|
||||
<li>{t('此配置将优先于上方的全局限制设置。')}</li>
|
||||
<li>{t('未在此处配置的用户组将使用全局限制。')}</li>
|
||||
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
|
||||
<li>{t('输入无效的 JSON 将无法保存。')}</li>
|
||||
</ul>
|
||||
</div>
|
||||
}
|
||||
autosize={{ minRows: 5, maxRows: 15 }}
|
||||
style={{ fontFamily: 'monospace' }}
|
||||
onChange={(value) => {
|
||||
setInputs({
|
||||
...inputs,
|
||||
ModelRequestRateLimitGroup: value, // 直接更新字符串值
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row style={{ marginTop: 15 }}>
|
||||
<Button size='default' onClick={onSubmit}>
|
||||
{t('保存模型速率限制')}
|
||||
</Button>
|
||||
|
||||
Reference in New Issue
Block a user