feat: 增加分组速率功能
This commit is contained in:
@@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算限流参数
|
// 计算通用限流参数
|
||||||
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
|
||||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
|
||||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
|
||||||
|
|
||||||
// 根据存储类型选择并执行限流处理器
|
// 获取用户组
|
||||||
if common.RedisEnabled {
|
group := c.GetString("token_group")
|
||||||
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
if group == "" {
|
||||||
|
group = c.GetString("group")
|
||||||
|
}
|
||||||
|
if group == "" {
|
||||||
|
group = "default" // 默认组
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试获取用户组特定的限制
|
||||||
|
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
|
||||||
|
|
||||||
|
// 确定最终的限制值
|
||||||
|
finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
|
||||||
|
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
|
||||||
|
|
||||||
|
if found {
|
||||||
|
// 如果找到用户组特定限制,则使用它们
|
||||||
|
finalTotalCount = groupTotalCount
|
||||||
|
finalSuccessCount = groupSuccessCount
|
||||||
|
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||||
} else {
|
} else {
|
||||||
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
|
||||||
|
if common.RedisEnabled {
|
||||||
|
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||||
|
} else {
|
||||||
|
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/config"
|
"one-api/setting/config"
|
||||||
@@ -96,6 +98,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||||
|
common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
|
||||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||||
@@ -150,7 +153,32 @@ func SyncOptions(frequency int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateOption(key string, value string) error {
|
func UpdateOption(key string, value string) error {
|
||||||
// Save to database first
|
originalValue := value // 保存原始值以备后用
|
||||||
|
|
||||||
|
// Validate and format specific keys before saving
|
||||||
|
if key == setting.ModelRequestRateLimitGroupKey {
|
||||||
|
var cfg map[string][2]int
|
||||||
|
// Validate the JSON structure first using the original value
|
||||||
|
err := json.Unmarshal([]byte(originalValue), &cfg)
|
||||||
|
if err != nil {
|
||||||
|
// 提供更具体的错误信息
|
||||||
|
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
|
||||||
|
}
|
||||||
|
// TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。
|
||||||
|
// if !isValidModelRequestRateLimitGroupConfig(cfg) {
|
||||||
|
// return fmt.Errorf("无效的配置值 for %s", key)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// If valid, format the JSON before saving
|
||||||
|
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
|
||||||
|
if marshalErr != nil {
|
||||||
|
// This should ideally not happen if validation passed, but handle defensively
|
||||||
|
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
|
||||||
|
}
|
||||||
|
value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save to database
|
||||||
option := Option{
|
option := Option{
|
||||||
Key: key,
|
Key: key,
|
||||||
}
|
}
|
||||||
@@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error {
|
|||||||
// Save is a combination function.
|
// Save is a combination function.
|
||||||
// If save value does not contain primary key, it will execute Create,
|
// If save value does not contain primary key, it will execute Create,
|
||||||
// otherwise it will execute Update (with all fields).
|
// otherwise it will execute Update (with all fields).
|
||||||
DB.Save(&option)
|
if err := DB.Save(&option).Error; err != nil {
|
||||||
// Update OptionMap
|
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update OptionMap in memory using the potentially formatted value
|
||||||
|
// updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
|
||||||
return updateOptionMap(key, value)
|
return updateOptionMap(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||||
case "StreamCacheQueueLength":
|
case "StreamCacheQueueLength":
|
||||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||||
|
case setting.ModelRequestRateLimitGroupKey:
|
||||||
|
// Use the (potentially formatted) value passed from UpdateOption
|
||||||
|
// to update the actual configuration in memory.
|
||||||
|
// This is the single point where the memory state for this specific setting is updated.
|
||||||
|
err = setting.UpdateModelRequestRateLimitGroupConfig(value)
|
||||||
|
if err != nil {
|
||||||
|
// 添加错误上下文
|
||||||
|
err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,76 @@
|
|||||||
package setting
|
package setting
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
var ModelRequestRateLimitEnabled = false
|
var ModelRequestRateLimitEnabled = false
|
||||||
var ModelRequestRateLimitDurationMinutes = 1
|
var ModelRequestRateLimitDurationMinutes = 1
|
||||||
var ModelRequestRateLimitCount = 0
|
var ModelRequestRateLimitCount = 0
|
||||||
var ModelRequestRateLimitSuccessCount = 1000
|
var ModelRequestRateLimitSuccessCount = 1000
|
||||||
|
|
||||||
|
// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
|
||||||
|
const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
|
||||||
|
|
||||||
|
// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
|
||||||
|
// map[groupName][2]int{totalCount, successCount}
|
||||||
|
var ModelRequestRateLimitGroupConfig map[string][2]int
|
||||||
|
var ModelRequestRateLimitGroupMutex sync.RWMutex
|
||||||
|
|
||||||
|
// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
|
||||||
|
func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
|
||||||
|
ModelRequestRateLimitGroupMutex.Lock()
|
||||||
|
defer ModelRequestRateLimitGroupMutex.Unlock()
|
||||||
|
|
||||||
|
var newConfig map[string][2]int
|
||||||
|
if jsonStr == "" || jsonStr == "{}" {
|
||||||
|
// 如果配置为空或空JSON对象,则清空内存配置
|
||||||
|
ModelRequestRateLimitGroupConfig = make(map[string][2]int)
|
||||||
|
common.SysLog("Model request rate limit group config cleared")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &newConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 校验配置值
|
||||||
|
for group, limits := range newConfig {
|
||||||
|
if len(limits) != 2 {
|
||||||
|
return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
|
||||||
|
}
|
||||||
|
if limits[1] <= 0 { // successCount must be greater than 0
|
||||||
|
return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
|
||||||
|
}
|
||||||
|
if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
|
||||||
|
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
|
||||||
|
}
|
||||||
|
if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
|
||||||
|
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelRequestRateLimitGroupConfig = newConfig
|
||||||
|
common.SysLog("Model request rate limit group config updated")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupRateLimit 安全地获取指定用户组的速率限制值
|
||||||
|
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
|
||||||
|
ModelRequestRateLimitGroupMutex.RLock()
|
||||||
|
defer ModelRequestRateLimitGroupMutex.RUnlock()
|
||||||
|
|
||||||
|
if ModelRequestRateLimitGroupConfig == nil {
|
||||||
|
return 0, 0, false // 配置尚未初始化
|
||||||
|
}
|
||||||
|
|
||||||
|
limits, found := ModelRequestRateLimitGroupConfig[group]
|
||||||
|
if !found {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
return limits[0], limits[1], true
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
|
|||||||
ModelRequestRateLimitCount: 0,
|
ModelRequestRateLimitCount: 0,
|
||||||
ModelRequestRateLimitSuccessCount: 1000,
|
ModelRequestRateLimitSuccessCount: 1000,
|
||||||
ModelRequestRateLimitDurationMinutes: 1,
|
ModelRequestRateLimitDurationMinutes: 1,
|
||||||
|
ModelRequestRateLimitGroup: {},
|
||||||
});
|
});
|
||||||
|
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ export default function RequestRateLimit(props) {
|
|||||||
ModelRequestRateLimitCount: -1,
|
ModelRequestRateLimitCount: -1,
|
||||||
ModelRequestRateLimitSuccessCount: 1000,
|
ModelRequestRateLimitSuccessCount: 1000,
|
||||||
ModelRequestRateLimitDurationMinutes: 1,
|
ModelRequestRateLimitDurationMinutes: 1,
|
||||||
|
ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
|
||||||
});
|
});
|
||||||
const refForm = useRef();
|
const refForm = useRef();
|
||||||
const [inputsRow, setInputsRow] = useState(inputs);
|
const [inputsRow, setInputsRow] = useState(inputs);
|
||||||
@@ -32,25 +33,49 @@ export default function RequestRateLimit(props) {
|
|||||||
} else {
|
} else {
|
||||||
value = inputs[item.key];
|
value = inputs[item.key];
|
||||||
}
|
}
|
||||||
|
// 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
|
||||||
|
if (item.key === 'ModelRequestRateLimitGroup') {
|
||||||
|
try {
|
||||||
|
JSON.parse(value);
|
||||||
|
} catch (e) {
|
||||||
|
showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
|
||||||
|
// 阻止请求发送
|
||||||
|
return Promise.reject('Invalid JSON format');
|
||||||
|
}
|
||||||
|
}
|
||||||
return API.put('/api/option/', {
|
return API.put('/api/option/', {
|
||||||
key: item.key,
|
key: item.key,
|
||||||
value,
|
value,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// 过滤掉无效的请求(例如,无效的 JSON)
|
||||||
|
const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
|
||||||
|
|
||||||
|
if (validRequests.length === 0 && requestQueue.length > 0) {
|
||||||
|
// 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
Promise.all(requestQueue)
|
Promise.all(validRequests)
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
if (requestQueue.length === 1) {
|
if (validRequests.length === 1) {
|
||||||
if (res.includes(undefined)) return;
|
if (res.includes(undefined)) return;
|
||||||
} else if (requestQueue.length > 1) {
|
} else if (validRequests.length > 1) {
|
||||||
if (res.includes(undefined))
|
if (res.includes(undefined))
|
||||||
return showError(t('部分保存失败,请重试'));
|
return showError(t('部分保存失败,请重试'));
|
||||||
}
|
}
|
||||||
showSuccess(t('保存成功'));
|
showSuccess(t('保存成功'));
|
||||||
props.refresh();
|
props.refresh();
|
||||||
|
// 更新 inputsRow 以反映保存后的状态
|
||||||
|
setInputsRow(structuredClone(inputs));
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch((error) => {
|
||||||
showError(t('保存失败,请重试'));
|
// 检查是否是由于无效 JSON 导致的错误
|
||||||
|
if (error !== 'Invalid JSON format') {
|
||||||
|
showError(t('保存失败,请重试'));
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
@@ -66,8 +91,11 @@ export default function RequestRateLimit(props) {
|
|||||||
}
|
}
|
||||||
setInputs(currentInputs);
|
setInputs(currentInputs);
|
||||||
setInputsRow(structuredClone(currentInputs));
|
setInputsRow(structuredClone(currentInputs));
|
||||||
refForm.current.setValues(currentInputs);
|
// 检查 refForm.current 是否存在
|
||||||
}, [props.options]);
|
if (refForm.current) {
|
||||||
|
refForm.current.setValues(currentInputs);
|
||||||
|
}
|
||||||
|
}, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -147,7 +175,41 @@ export default function RequestRateLimit(props) {
|
|||||||
/>
|
/>
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
{/* 用户组速率限制配置项 */}
|
||||||
<Row>
|
<Row>
|
||||||
|
<Col span={24}>
|
||||||
|
<Form.TextArea
|
||||||
|
label={t('用户组速率限制 (JSON)')}
|
||||||
|
field={'ModelRequestRateLimitGroup'}
|
||||||
|
placeholder={t( // 更新 placeholder
|
||||||
|
'请输入 JSON 格式的用户组限制,例如:\n{\n "default": [200, 100],\n "vip": [1000, 500]\n}',
|
||||||
|
)}
|
||||||
|
extraText={ // 更新 extraText
|
||||||
|
<div>
|
||||||
|
<p>{t('说明:')}</p>
|
||||||
|
<ul>
|
||||||
|
<li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
|
||||||
|
<li>{t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}</li>
|
||||||
|
<li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
|
||||||
|
<li>{t('此配置将优先于上方的全局限制设置。')}</li>
|
||||||
|
<li>{t('未在此处配置的用户组将使用全局限制。')}</li>
|
||||||
|
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
|
||||||
|
<li>{t('输入无效的 JSON 将无法保存。')}</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
autosize={{ minRows: 5, maxRows: 15 }}
|
||||||
|
style={{ fontFamily: 'monospace' }}
|
||||||
|
onChange={(value) => {
|
||||||
|
setInputs({
|
||||||
|
...inputs,
|
||||||
|
ModelRequestRateLimitGroup: value, // 直接更新字符串值
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
<Row style={{ marginTop: 15 }}>
|
||||||
<Button size='default' onClick={onSubmit}>
|
<Button size='default' onClick={onSubmit}>
|
||||||
{t('保存模型速率限制')}
|
{t('保存模型速率限制')}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
Reference in New Issue
Block a user