diff --git a/controller/option.go b/controller/option.go
index 81ef463c..250f16bb 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) {
})
return
}
+ case "ModelRequestRateLimitGroup":
+ err = setting.CheckModelRequestRateLimitGroup(option.Value)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
}
err = model.UpdateOption(option.Key, option.Value)
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index f81160fc..34caa59b 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/common/limiter"
+ "one-api/constant"
"one-api/setting"
"strconv"
"time"
@@ -175,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
+ // 获取分组
+ group := c.GetString("token_group")
+ if group == "" {
+ group = c.GetString(constant.ContextKeyUserGroup)
+ }
+
+ //获取分组的限流配置
+ groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+ if found {
+ totalMaxCount = groupTotalCount
+ successMaxCount = groupSuccessCount
+ }
+
// 根据存储类型选择并执行限流处理器
if common.RedisEnabled {
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
diff --git a/model/option.go b/model/option.go
index d575742f..d98a9d38 100644
--- a/model/option.go
+++ b/model/option.go
@@ -92,6 +92,7 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+ common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -338,6 +339,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitGroup":
+ err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
diff --git a/setting/rate_limit.go b/setting/rate_limit.go
index 4b216948..53b53f88 100644
--- a/setting/rate_limit.go
+++ b/setting/rate_limit.go
@@ -1,6 +1,64 @@
package setting
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/common"
+ "sync"
+)
+
var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000
+// ModelRequestRateLimitGroup holds per-group rate limit overrides keyed by
+// group name. GetGroupRateLimit returns element 0 of each entry as the total
+// count and element 1 as the success count.
+var ModelRequestRateLimitGroup = map[string][2]int{}
+// ModelRequestRateLimitMutex is taken by the functions in this file around
+// their accesses to ModelRequestRateLimitGroup.
+var ModelRequestRateLimitMutex sync.RWMutex
+
+// ModelRequestRateLimitGroup2JSONString serializes the per-group rate limit
+// table to a JSON string.
+//
+// The table is read while holding the read lock. If marshalling fails, the
+// error is reported through common.SysError and the empty string is returned.
+func ModelRequestRateLimitGroup2JSONString() string {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
+
+	jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
+	if err != nil {
+		// Name the setting that actually failed so the log line is searchable.
+		common.SysError("error marshalling model request rate limit group: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+// UpdateModelRequestRateLimitGroupByJSONString replaces the per-group rate
+// limit table with the one decoded from jsonStr.
+//
+// The payload is decoded into a fresh map before anything is published, so a
+// malformed payload returns the decode error and leaves the current table
+// untouched. The swap happens under the exclusive write lock: the other
+// accessors in this file only take the read lock, and a read lock does not
+// exclude other readers.
+func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
+	parsed := make(map[string][2]int)
+	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
+		return err
+	}
+
+	ModelRequestRateLimitMutex.Lock()
+	defer ModelRequestRateLimitMutex.Unlock()
+	ModelRequestRateLimitGroup = parsed
+	return nil
+}
+
+// GetGroupRateLimit looks up the rate limit override configured for group.
+//
+// When an entry exists, it returns the entry's first element as totalCount,
+// its second element as successCount, and found == true. Otherwise it returns
+// (0, 0, false). The lookup runs while holding the read lock.
+func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
+
+	// Indexing a nil map yields the zero value with ok == false, so a missing
+	// table and a missing key take the same path.
+	if pair, ok := ModelRequestRateLimitGroup[group]; ok {
+		return pair[0], pair[1], true
+	}
+	return 0, 0, false
+}
+
+// CheckModelRequestRateLimitGroup validates a JSON-encoded per-group rate
+// limit table without applying it.
+//
+// jsonStr must be a JSON object that maps each group name to an array of
+// exactly two integers, [total count, success count], where the total count
+// is >= 0 and the success count is >= 1. It returns the JSON decode error for
+// malformed input, an error naming an offending group when an entry breaks
+// these rules, and nil when every entry is valid.
+func CheckModelRequestRateLimitGroup(jsonStr string) error {
+	// Decode into slices rather than [2]int: when the target is a fixed-size
+	// array, encoding/json zero-fills short JSON arrays and discards extra
+	// elements, which would let malformed entries slip through unnoticed.
+	parsed := make(map[string][]int)
+	if err := json.Unmarshal([]byte(jsonStr), &parsed); err != nil {
+		return err
+	}
+	for group, limits := range parsed {
+		if len(limits) != 2 {
+			return fmt.Errorf("group %s must have exactly 2 rate limit values, got %d", group, len(limits))
+		}
+		if limits[0] < 0 || limits[1] < 1 {
+			return fmt.Errorf("group %s has invalid rate limit values [%d, %d]: total count must be >= 0 and success count must be >= 1", group, limits[0], limits[1])
+		}
+	}
+	return nil
+}
diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js
index e06038d6..5f0200e1 100644
--- a/web/src/components/RateLimitSetting.js
+++ b/web/src/components/RateLimitSetting.js
@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: '',
});
let [loading, setLoading] = useState(false);
@@ -23,10 +24,14 @@ const RateLimitSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
- if (item.key.endsWith('Enabled')) {
- newInputs[item.key] = item.value === 'true' ? true : false;
- } else {
- newInputs[item.key] = item.value;
+ if (item.key === 'ModelRequestRateLimitGroup') {
+ item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+ }
+
+ if (item.key.endsWith('Enabled')) {
+ newInputs[item.key] = item.value === 'true' ? true : false;
+ } else {
+ newInputs[item.key] = item.value;
}
});
diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
index 800e9636..73626351 100644
--- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
+++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
@@ -6,6 +6,7 @@ import {
showError,
showSuccess,
showWarning,
+ verifyJSON,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
@@ -18,6 +19,7 @@ export default function RequestRateLimit(props) {
ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1,
+ ModelRequestRateLimitGroup: '',
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -46,6 +48,13 @@ export default function RequestRateLimit(props) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
+
+ for (let i = 0; i < res.length; i++) {
+ if (!res[i].data.success) {
+ return showError(res[i].data.message);
+ }
+ }
+
showSuccess(t('保存成功'));
props.refresh();
})
@@ -147,6 +156,41 @@ export default function RequestRateLimit(props) {
/>
+
+
+ verifyJSON(value),
+ message: t('不是合法的 JSON 字符串'),
+ },
+ ]}
+ extraText={
+
+
{t('说明:')}
+
+ - {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
+ - {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
+ - {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
+ - {t('分组速率配置优先级高于全局速率限制。')}
+ - {t('限制周期统一使用上方配置的“限制周期”值。')}
+
+
+ }
+ onChange={(value) => {
+ setInputs({ ...inputs, ModelRequestRateLimitGroup: value });
+ }}
+ />
+
+