feat: 增加分组速率功能

This commit is contained in:
tbphp
2025-05-05 07:31:54 +08:00
parent bae57c05c1
commit 6c3fb7777e
5 changed files with 214 additions and 17 deletions

View File

@@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) {
return
}
// 计算限流参数
// 计算通用限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 根据存储类型选择并执行限流处理器
if common.RedisEnabled {
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
// 获取用户组
group := c.GetString("token_group")
if group == "" {
group = c.GetString("group")
}
if group == "" {
group = "default" // 默认组
}
// 尝试获取用户组特定的限制
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
// 确定最终的限制值
finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制
if found {
// 如果找到用户组特定限制,则使用它们
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
} else {
memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
// 根据存储类型选择并执行限流处理器,传入最终确定的限制值
if common.RedisEnabled {
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
} else {
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c)
}
}
}

View File

@@ -1,6 +1,8 @@
package model
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/setting"
"one-api/setting/config"
@@ -96,6 +98,7 @@ func InitOptionMap() {
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
@@ -150,7 +153,32 @@ func SyncOptions(frequency int) {
}
func UpdateOption(key string, value string) error {
// Save to database first
originalValue := value // 保存原始值以备后用
// Validate and format specific keys before saving
if key == setting.ModelRequestRateLimitGroupKey {
var cfg map[string][2]int
// Validate the JSON structure first using the original value
err := json.Unmarshal([]byte(originalValue), &cfg)
if err != nil {
// 提供更具体的错误信息
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
}
// TODO: 可以添加更细致的结构验证例如检查数组长度是否为2值是否为非负数等。
// if !isValidModelRequestRateLimitGroupConfig(cfg) {
// return fmt.Errorf("无效的配置值 for %s", key)
// }
// If valid, format the JSON before saving
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
if marshalErr != nil {
// This should ideally not happen if validation passed, but handle defensively
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
}
value = string(formattedValueBytes) // Use formatted JSON for saving and memory update
}
// Save to database
option := Option{
Key: key,
}
@@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error {
// Save is a combination function.
// If save value does not contain primary key, it will execute Create,
// otherwise it will execute Update (with all fields).
DB.Save(&option)
// Update OptionMap
if err := DB.Save(&option).Error; err != nil {
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文
}
// Update OptionMap in memory using the potentially formatted value
// updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新
return updateOptionMap(key, value)
}
@@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case setting.ModelRequestRateLimitGroupKey:
// Use the (potentially formatted) value passed from UpdateOption
// to update the actual configuration in memory.
// This is the single point where the memory state for this specific setting is updated.
err = setting.UpdateModelRequestRateLimitGroupConfig(value)
if err != nil {
// 添加错误上下文
err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err)
}
}
return err
}

View File

@@ -1,6 +1,76 @@
package setting
import (
"encoding/json"
"fmt"
"one-api/common"
"sync"
)
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键
const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup"
// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置
// map[groupName][2]int{totalCount, successCount}
var ModelRequestRateLimitGroupConfig map[string][2]int
var ModelRequestRateLimitGroupMutex sync.RWMutex
// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置
func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error {
ModelRequestRateLimitGroupMutex.Lock()
defer ModelRequestRateLimitGroupMutex.Unlock()
var newConfig map[string][2]int
if jsonStr == "" || jsonStr == "{}" {
// 如果配置为空或空JSON对象则清空内存配置
ModelRequestRateLimitGroupConfig = make(map[string][2]int)
common.SysLog("Model request rate limit group config cleared")
return nil
}
err := json.Unmarshal([]byte(jsonStr), &newConfig)
if err != nil {
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err)
}
// 校验配置值
for group, limits := range newConfig {
if len(limits) != 2 {
return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group)
}
if limits[1] <= 0 { // successCount must be greater than 0
return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group)
}
if limits[0] < 0 { // totalCount can be 0 (no limit) or positive
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group)
}
if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount
return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group)
}
}
ModelRequestRateLimitGroupConfig = newConfig
common.SysLog("Model request rate limit group config updated")
return nil
}
// GetGroupRateLimit 安全地获取指定用户组的速率限制值
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
ModelRequestRateLimitGroupMutex.RLock()
defer ModelRequestRateLimitGroupMutex.RUnlock()
if ModelRequestRateLimitGroupConfig == nil {
return 0, 0, false // 配置尚未初始化
}
limits, found := ModelRequestRateLimitGroupConfig[group]
if !found {
return 0, 0, false
}
return limits[0], limits[1], true
}

View File

@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: {},
});
let [loading, setLoading] = useState(false);

View File

@@ -18,6 +18,7 @@ export default function RequestRateLimit(props) {
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -32,25 +33,49 @@ export default function RequestRateLimit(props) {
} else {
value = inputs[item.key];
}
// 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串
if (item.key === 'ModelRequestRateLimitGroup') {
try {
JSON.parse(value);
} catch (e) {
showError(t('用户组速率限制配置不是有效的 JSON 格式!'));
// 阻止请求发送
return Promise.reject('Invalid JSON format');
}
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
// 过滤掉无效的请求(例如,无效的 JSON
const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function');
if (validRequests.length === 0 && requestQueue.length > 0) {
// 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行
return;
}
setLoading(true);
Promise.all(requestQueue)
Promise.all(validRequests)
.then((res) => {
if (requestQueue.length === 1) {
if (validRequests.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
} else if (validRequests.length > 1) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
props.refresh();
// 更新 inputsRow 以反映保存后的状态
setInputsRow(structuredClone(inputs));
})
.catch(() => {
showError(t('保存失败,请重试'));
.catch((error) => {
// 检查是否是由于无效 JSON 导致的错误
if (error !== 'Invalid JSON format') {
showError(t('保存失败,请重试'));
}
})
.finally(() => {
setLoading(false);
@@ -66,8 +91,11 @@ export default function RequestRateLimit(props) {
}
setInputs(currentInputs);
setInputsRow(structuredClone(currentInputs));
refForm.current.setValues(currentInputs);
}, [props.options]);
// 检查 refForm.current 是否存在
if (refForm.current) {
refForm.current.setValues(currentInputs);
}
}, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定
return (
<>
@@ -147,7 +175,41 @@ export default function RequestRateLimit(props) {
/>
</Col>
</Row>
{/* 用户组速率限制配置项 */}
<Row>
<Col span={24}>
<Form.TextArea
label={t('用户组速率限制 (JSON)')}
field={'ModelRequestRateLimitGroup'}
placeholder={t( // 更新 placeholder
'请输入 JSON 格式的用户组限制,例如:\n{\n "default": [200, 100],\n "vip": [1000, 500]\n}',
)}
extraText={ // 更新 extraText
<div>
<p>{t('说明:')}</p>
<ul>
<li>{t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}</li>
<li>{t('总次数限制: 周期内允许的总请求次数 (含失败)0 代表不限制。')}</li>
<li>{t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}</li>
<li>{t('此配置将优先于上方的全局限制设置。')}</li>
<li>{t('未在此处配置的用户组将使用全局限制。')}</li>
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
<li>{t('输入无效的 JSON 将无法保存。')}</li>
</ul>
</div>
}
autosize={{ minRows: 5, maxRows: 15 }}
style={{ fontFamily: 'monospace' }}
onChange={(value) => {
setInputs({
...inputs,
ModelRequestRateLimitGroup: value, // 直接更新字符串值
});
}}
/>
</Col>
</Row>
<Row style={{ marginTop: 15 }}>
<Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')}
</Button>