Merge branch 'tbphp-tbphp_model_request_rate_limit_for_group'
This commit is contained in:
@@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "ModelRequestRateLimitGroup":
|
||||||
|
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
err = model.UpdateOption(option.Key, option.Value)
|
err = model.UpdateOption(option.Key, option.Value)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/limiter"
|
"one-api/common/limiter"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -175,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
|||||||
totalMaxCount := setting.ModelRequestRateLimitCount
|
totalMaxCount := setting.ModelRequestRateLimitCount
|
||||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||||
|
|
||||||
|
// 获取分组
|
||||||
|
group := c.GetString("token_group")
|
||||||
|
if group == "" {
|
||||||
|
group = c.GetString(constant.ContextKeyUserGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
//获取分组的限流配置
|
||||||
|
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
|
||||||
|
if found {
|
||||||
|
totalMaxCount = groupTotalCount
|
||||||
|
successMaxCount = groupSuccessCount
|
||||||
|
}
|
||||||
|
|
||||||
// 根据存储类型选择并执行限流处理器
|
// 根据存储类型选择并执行限流处理器
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
|
||||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||||
|
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||||
@@ -338,6 +339,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
|
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
|
||||||
case "ModelRequestRateLimitSuccessCount":
|
case "ModelRequestRateLimitSuccessCount":
|
||||||
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
|
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
|
||||||
|
case "ModelRequestRateLimitGroup":
|
||||||
|
err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
|
||||||
case "RetryTimes":
|
case "RetryTimes":
|
||||||
common.RetryTimes, _ = strconv.Atoi(value)
|
common.RetryTimes, _ = strconv.Atoi(value)
|
||||||
case "DataExportInterval":
|
case "DataExportInterval":
|
||||||
|
|||||||
@@ -1,6 +1,64 @@
|
|||||||
package setting
|
package setting
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
var ModelRequestRateLimitEnabled = false
|
var ModelRequestRateLimitEnabled = false
|
||||||
var ModelRequestRateLimitDurationMinutes = 1
|
var ModelRequestRateLimitDurationMinutes = 1
|
||||||
var ModelRequestRateLimitCount = 0
|
var ModelRequestRateLimitCount = 0
|
||||||
var ModelRequestRateLimitSuccessCount = 1000
|
var ModelRequestRateLimitSuccessCount = 1000
|
||||||
|
var ModelRequestRateLimitGroup = map[string][2]int{}
|
||||||
|
var ModelRequestRateLimitMutex sync.RWMutex
|
||||||
|
|
||||||
|
func ModelRequestRateLimitGroup2JSONString() string {
|
||||||
|
ModelRequestRateLimitMutex.RLock()
|
||||||
|
defer ModelRequestRateLimitMutex.RUnlock()
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling model ratio: " + err.Error())
|
||||||
|
}
|
||||||
|
return string(jsonBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
|
||||||
|
ModelRequestRateLimitMutex.RLock()
|
||||||
|
defer ModelRequestRateLimitMutex.RUnlock()
|
||||||
|
|
||||||
|
ModelRequestRateLimitGroup = make(map[string][2]int)
|
||||||
|
return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
|
||||||
|
ModelRequestRateLimitMutex.RLock()
|
||||||
|
defer ModelRequestRateLimitMutex.RUnlock()
|
||||||
|
|
||||||
|
if ModelRequestRateLimitGroup == nil {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
limits, found := ModelRequestRateLimitGroup[group]
|
||||||
|
if !found {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
return limits[0], limits[1], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckModelRequestRateLimitGroup(jsonStr string) error {
|
||||||
|
checkModelRequestRateLimitGroup := make(map[string][2]int)
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for group, limits := range checkModelRequestRateLimitGroup {
|
||||||
|
if limits[0] < 0 || limits[1] < 1 {
|
||||||
|
return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
|
|||||||
ModelRequestRateLimitCount: 0,
|
ModelRequestRateLimitCount: 0,
|
||||||
ModelRequestRateLimitSuccessCount: 1000,
|
ModelRequestRateLimitSuccessCount: 1000,
|
||||||
ModelRequestRateLimitDurationMinutes: 1,
|
ModelRequestRateLimitDurationMinutes: 1,
|
||||||
|
ModelRequestRateLimitGroup: '',
|
||||||
});
|
});
|
||||||
|
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
@@ -23,10 +24,14 @@ const RateLimitSetting = () => {
|
|||||||
if (success) {
|
if (success) {
|
||||||
let newInputs = {};
|
let newInputs = {};
|
||||||
data.forEach((item) => {
|
data.forEach((item) => {
|
||||||
if (item.key.endsWith('Enabled')) {
|
if (item.key === 'ModelRequestRateLimitGroup') {
|
||||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||||
} else {
|
}
|
||||||
newInputs[item.key] = item.value;
|
|
||||||
|
if (item.key.endsWith('Enabled')) {
|
||||||
|
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||||
|
} else {
|
||||||
|
newInputs[item.key] = item.value;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import {
|
|||||||
showError,
|
showError,
|
||||||
showSuccess,
|
showSuccess,
|
||||||
showWarning,
|
showWarning,
|
||||||
|
verifyJSON,
|
||||||
} from '../../../helpers';
|
} from '../../../helpers';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
@@ -18,6 +19,7 @@ export default function RequestRateLimit(props) {
|
|||||||
ModelRequestRateLimitCount: -1,
|
ModelRequestRateLimitCount: -1,
|
||||||
ModelRequestRateLimitSuccessCount: 1000,
|
ModelRequestRateLimitSuccessCount: 1000,
|
||||||
ModelRequestRateLimitDurationMinutes: 1,
|
ModelRequestRateLimitDurationMinutes: 1,
|
||||||
|
ModelRequestRateLimitGroup: '',
|
||||||
});
|
});
|
||||||
const refForm = useRef();
|
const refForm = useRef();
|
||||||
const [inputsRow, setInputsRow] = useState(inputs);
|
const [inputsRow, setInputsRow] = useState(inputs);
|
||||||
@@ -46,6 +48,13 @@ export default function RequestRateLimit(props) {
|
|||||||
if (res.includes(undefined))
|
if (res.includes(undefined))
|
||||||
return showError(t('部分保存失败,请重试'));
|
return showError(t('部分保存失败,请重试'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (let i = 0; i < res.length; i++) {
|
||||||
|
if (!res[i].data.success) {
|
||||||
|
return showError(res[i].data.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
showSuccess(t('保存成功'));
|
showSuccess(t('保存成功'));
|
||||||
props.refresh();
|
props.refresh();
|
||||||
})
|
})
|
||||||
@@ -147,6 +156,41 @@ export default function RequestRateLimit(props) {
|
|||||||
/>
|
/>
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
<Row>
|
||||||
|
<Col xs={24} sm={16}>
|
||||||
|
<Form.TextArea
|
||||||
|
label={t('分组速率限制')}
|
||||||
|
placeholder={t(
|
||||||
|
'{\n "default": [200, 100],\n "vip": [0, 1000]\n}',
|
||||||
|
)}
|
||||||
|
field={'ModelRequestRateLimitGroup'}
|
||||||
|
autosize={{ minRows: 5, maxRows: 15 }}
|
||||||
|
trigger='blur'
|
||||||
|
stopValidateWithError
|
||||||
|
rules={[
|
||||||
|
{
|
||||||
|
validator: (rule, value) => verifyJSON(value),
|
||||||
|
message: t('不是合法的 JSON 字符串'),
|
||||||
|
},
|
||||||
|
]}
|
||||||
|
extraText={
|
||||||
|
<div>
|
||||||
|
<p style={{ marginBottom: -15 }}>{t('说明:')}</p>
|
||||||
|
<ul>
|
||||||
|
<li>{t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}</li>
|
||||||
|
<li>{t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}</li>
|
||||||
|
<li>{t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}</li>
|
||||||
|
<li>{t('分组速率配置优先级高于全局速率限制。')}</li>
|
||||||
|
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
onChange={(value) => {
|
||||||
|
setInputs({ ...inputs, ModelRequestRateLimitGroup: value });
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
<Row>
|
<Row>
|
||||||
<Button size='default' onClick={onSubmit}>
|
<Button size='default' onClick={onSubmit}>
|
||||||
{t('保存模型速率限制')}
|
{t('保存模型速率限制')}
|
||||||
|
|||||||
Reference in New Issue
Block a user