diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
new file mode 100644
index 00000000..135e0005
--- /dev/null
+++ b/middleware/model-rate-limit.go
@@ -0,0 +1,172 @@
+package middleware
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/setting"
+ "strconv"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/go-redis/redis/v8"
+)
+
+const ( // Markers embedded in rate-limit keys to keep the two per-user counters apart.
+ ModelRequestRateLimitCountMark = "MRRL" // Counter of all requests admitted by the limiter, whatever their outcome.
+ ModelRequestRateLimitSuccessCountMark = "MRRLS" // Counter of requests that finished with an HTTP status below 400.
+)
+
+// checkRedisRateLimit reports whether one more request may be admitted for the
+// Redis list stored at key.
+//
+// The list holds timestamps formatted with timeFormat, newest first
+// (recordRedisRequest pushes at the head and trims the list to maxCount
+// entries). A request is admitted when the list holds fewer than maxCount
+// entries, or when its oldest entry is at least duration seconds old. A
+// maxCount of 0 disables the limit. When the request is rejected the TTL of
+// the key is refreshed.
+//
+// A non-nil error means Redis could not be queried or the stored timestamp
+// could not be parsed; the boolean is meaningless in that case.
+func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
+	// A limit of 0 means "unlimited".
+	if maxCount == 0 {
+		return true, nil
+	}
+
+	// Fewer than maxCount recorded requests: still under the limit.
+	length, err := rdb.LLen(ctx, key).Result()
+	if err != nil {
+		return false, err
+	}
+	if length < int64(maxCount) {
+		return true, nil
+	}
+
+	// The list is full; its tail is the oldest request still remembered.
+	oldTimeStr, err := rdb.LIndex(ctx, key, -1).Result()
+	if err == redis.Nil {
+		// The key vanished between LLen and LIndex (for example its TTL ran
+		// out), so nothing is left in the window.
+		return true, nil
+	}
+	if err != nil {
+		return false, err
+	}
+	oldTime, err := time.Parse(timeFormat, oldTimeStr)
+	if err != nil {
+		return false, err
+	}
+
+	// Round-trip the current time through timeFormat so that both instants
+	// are produced by the same Format/Parse path before being subtracted.
+	nowTime, err := time.Parse(timeFormat, time.Now().Format(timeFormat))
+	if err != nil {
+		return false, err
+	}
+
+	// The oldest of the last maxCount requests is still inside the window:
+	// reject, and keep the key alive (best effort, the result is not checked).
+	if int64(nowTime.Sub(oldTime).Seconds()) < duration {
+		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+		return false, nil
+	}
+
+	return true, nil
+}
+
+// recordRedisRequest appends the current time to the Redis list stored at key,
+// trims the list to its newest maxCount entries and refreshes the key's TTL.
+//
+// A maxCount of 0 means the limit is disabled and nothing is recorded. The
+// three commands are sent in a single pipeline (one round trip). Recording is
+// best effort: a failure is logged and otherwise ignored so that it never
+// affects the request being served.
+func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
+	if maxCount == 0 {
+		return
+	}
+
+	now := time.Now().Format(timeFormat)
+	_, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
+		pipe.LPush(ctx, key, now)
+		// Keep indexes 0..maxCount-1, i.e. the newest maxCount timestamps.
+		pipe.LTrim(ctx, key, 0, int64(maxCount-1))
+		pipe.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+		return nil
+	})
+	if err != nil {
+		fmt.Println("记录限流请求失败:", err.Error())
+	}
+}
+
+// redisRateLimitHandler returns a gin middleware that enforces two per-user
+// limits stored in Redis lists: one on all admitted requests (totalMaxCount)
+// and one on successful requests (successMaxCount), both over a window of
+// duration seconds. A limit of 0 disables the corresponding check.
+//
+// The user is taken from the "id" value of the gin context. For each request
+// the middleware checks the total limit, then the success limit, records the
+// request in the total list, runs the rest of the chain, and finally records
+// the request in the success list when the response status is below 400.
+// A failed Redis check aborts with 500; an exceeded limit aborts with 429.
+func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		userID := strconv.Itoa(c.GetInt("id"))
+		// Not tied to the request context, so the bookkeeping after c.Next()
+		// is not cancelled together with the request.
+		ctx := context.Background()
+		rdb := common.RDB
+
+		// 1. Total request limit (skipped inside the check when totalMaxCount is 0).
+		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userID)
+		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查总请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+			// Stop here: a rejected request must not be checked again,
+			// recorded, or answered a second time.
+			return
+		}
+
+		// 2. Successful request limit.
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userID)
+		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查成功请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+			return
+		}
+
+		// 3. Count the admitted request (no-op when totalMaxCount is 0).
+		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
+
+		// 4. Run the remaining handlers.
+		c.Next()
+
+		// 5. Count the request as successful only if it did not end in an error status.
+		if c.Writer.Status() < 400 {
+			recordRedisRequest(ctx, rdb, successKey, successMaxCount)
+		}
+	}
+}
+
+// memoryRateLimitHandler returns a gin middleware that enforces the same two
+// per-user limits as the Redis handler, using the process-local
+// inMemoryRateLimiter. duration is the window in seconds; a limit of 0
+// disables the corresponding check.
+//
+// Rejections are answered with the same 429 messages as the Redis handler.
+// abortWithOpenAiMessage is assumed to abort the handler chain, as the Redis
+// handler already relies on it — TODO confirm.
+//
+// NOTE(review): inMemoryRateLimiter.Request presumably records a hit whenever
+// it admits one. If so, the "_check" key counts every request that reaches the
+// success check, not only successful ones, and the hits recorded under
+// successKey are never consulted — confirm against the limiter implementation.
+func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+
+	return func(c *gin.Context) {
+		userID := strconv.Itoa(c.GetInt("id"))
+		totalKey := ModelRequestRateLimitCountMark + userID
+		successKey := ModelRequestRateLimitSuccessCountMark + userID
+
+		// 1. Total request limit (skipped when totalMaxCount is 0).
+		if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+			return
+		}
+
+		// 2. Successful request limit (skipped when successMaxCount is 0),
+		// checked against a separate "_check" key rather than successKey.
+		if successMaxCount > 0 {
+			checkKey := successKey + "_check"
+			if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
+				abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+				return
+			}
+		}
+
+		// 3. Run the remaining handlers.
+		c.Next()
+
+		// 4. Record the request under successKey when it did not end in an error status.
+		if successMaxCount > 0 && c.Writer.Status() < 400 {
+			inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
+		}
+	}
+}
+
+// ModelRequestRateLimit returns the gin middleware that applies the model
+// request rate limit.
+//
+// The middleware is installed once when the router is built, while the
+// settings it depends on can be changed at runtime through the option map.
+// The settings are therefore read on every request, so enabling or disabling
+// the limiter or changing its limits takes effect without a restart.
+//
+// NOTE(review): memoryRateLimitHandler calls inMemoryRateLimiter.Init each
+// time it is constructed; this assumes Init is idempotent — confirm.
+func ModelRequestRateLimit() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		// Limiter switched off: pass the request straight through.
+		if !setting.ModelRequestRateLimitEnabled {
+			c.Next()
+			return
+		}
+
+		// The window is configured in minutes; the handlers work in seconds.
+		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
+		totalMaxCount := setting.ModelRequestRateLimitCount
+		successMaxCount := setting.ModelRequestRateLimitSuccessCount
+
+		// Pick the backend that matches the configured storage and run it.
+		if common.RedisEnabled {
+			redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+			return
+		}
+		memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+	}
+}
diff --git a/model/option.go b/model/option.go
index 24935c69..3e9e9541 100644
--- a/model/option.go
+++ b/model/option.go
@@ -85,6 +85,9 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+ common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
+ common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
+ common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,6 +108,7 @@ func InitOptionMap() {
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
+ common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
@@ -226,6 +230,9 @@ func updateOptionMap(key string, value string) (err error) {
setting.DemoSiteEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
setting.CheckSensitiveOnPromptEnabled = boolValue
+ case "ModelRequestRateLimitEnabled":
+ setting.ModelRequestRateLimitEnabled = boolValue
+
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
@@ -308,6 +315,12 @@ func updateOptionMap(key string, value string) (err error) {
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "ShouldPreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitCount":
+ setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitDurationMinutes":
+ setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitSuccessCount":
+ setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
diff --git a/router/relay-router.go b/router/relay-router.go
index 63f5c36d..32e0c682 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth())
+ relayV1Router.Use(middleware.ModelRequestRateLimit())
{
// WebSocket 路由
wsRouter := relayV1Router.Group("")
diff --git a/setting/rate_limit.go b/setting/rate_limit.go
new file mode 100644
index 00000000..4b216948
--- /dev/null
+++ b/setting/rate_limit.go
@@ -0,0 +1,6 @@
+package setting
+
+// Settings of the model request rate limiter.
+var (
+	// ModelRequestRateLimitEnabled switches the limiter on or off.
+	ModelRequestRateLimitEnabled = false
+	// ModelRequestRateLimitDurationMinutes is the length of the counting window, in minutes.
+	ModelRequestRateLimitDurationMinutes = 1
+	// ModelRequestRateLimitCount caps all requests per window; 0 means no cap.
+	ModelRequestRateLimitCount = 0
+	// ModelRequestRateLimitSuccessCount caps successful requests per window.
+	ModelRequestRateLimitSuccessCount = 1000
+)
diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js
new file mode 100644
index 00000000..b6c92917
--- /dev/null
+++ b/web/src/components/RateLimitSetting.js
@@ -0,0 +1,80 @@
+import React, { useEffect, useState } from 'react';
+import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
+import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
+import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
+import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
+import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
+import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
+import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
+import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
+import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
+import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
+import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
+import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
+
+
+import { API, showError, showSuccess } from '../helpers';
+import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
+import { useTranslation } from 'react-i18next';
+import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
+
+const RateLimitSetting = () => { // Container for the rate-limit settings tab: loads the options from the backend into `inputs`.
+ const { t } = useTranslation();
+ let [inputs, setInputs] = useState({ // Defaults used until the first fetch completes.
+ ModelRequestRateLimitEnabled: false,
+ ModelRequestRateLimitCount: 0,
+ ModelRequestRateLimitSuccessCount: 1000,
+ ModelRequestRateLimitDurationMinutes: 1,
+ });
+
+ let [loading, setLoading] = useState(false); // True while onRefresh is in flight.
+
+ const getOptions = async () => { // Fetches /api/option/ and rebuilds `inputs` from the returned key/value items.
+ const res = await API.get('/api/option/');
+ const { success, message, data } = res.data;
+ if (success) {
+ let newInputs = {};
+ data.forEach((item) => {
+ if (
+ item.key.endsWith('Enabled') // Keys ending in "Enabled" are compared with the string 'true' and stored as booleans.
+ ) {
+ newInputs[item.key] = item.value === 'true' ? true : false;
+ } else {
+ newInputs[item.key] = item.value; // Every other value is stored exactly as received.
+ }
+ });
+
+ setInputs(newInputs); // Replaces the whole state, so it holds every returned option, not only the four rate-limit keys.
+ } else {
+ showError(message);
+ }
+ };
+ async function onRefresh() { // Reloads the options while toggling `loading`; any thrown error is shown as a generic refresh failure.
+ try {
+ setLoading(true);
+ await getOptions();
+ // showSuccess('刷新成功');
+ } catch (error) {
+ showError('刷新失败');
+ } finally {
+ setLoading(false);
+ }
+ }
+
+ useEffect(() => { // Empty dependency list: load the options once, on mount.
+ onRefresh();
+ }, []);
+
+ return ( // NOTE(review): the JSX element tags appear to be missing from this view; only the fragment markers and a comment remain.
+ <>
+
+ {/* AI request rate limit */}
+
+
+
+
+ >
+ );
+};
+
+export default RateLimitSetting;
diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
new file mode 100644
index 00000000..6f4a5571
--- /dev/null
+++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js
@@ -0,0 +1,159 @@
+import React, { useEffect, useState, useRef } from 'react';
+import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
+import {
+ compareObjects,
+ API,
+ showError,
+ showSuccess,
+ showWarning,
+} from '../../../helpers';
+import { useTranslation } from 'react-i18next';
+
+export default function RequestRateLimit(props) { // Form for the model request rate-limit options; reads props.options and calls props.refresh after a successful save.
+ const { t } = useTranslation();
+
+ const [loading, setLoading] = useState(false); // True while the save requests are pending.
+ const [inputs, setInputs] = useState({ // Current form values; replaced by the matching keys of props.options in the effect below.
+ ModelRequestRateLimitEnabled: false,
+ ModelRequestRateLimitCount: -1, // NOTE(review): the parent component and the Go setting both default this to 0 — confirm -1 is intended.
+ ModelRequestRateLimitSuccessCount: 1000,
+ ModelRequestRateLimitDurationMinutes: 1
+ });
+ const refForm = useRef();
+ const [inputsRow, setInputsRow] = useState(inputs); // Baseline copy of the last values taken from props.options, used to detect edits.
+
+ function onSubmit() { // Saves each changed field with its own PUT /api/option/ request.
+ const updateArray = compareObjects(inputs, inputsRow); // Presumably one entry per key whose value differs — confirm in helpers.
+ if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+ const requestQueue = updateArray.map((item) => {
+ let value = '';
+ if (typeof inputs[item.key] === 'boolean') { // Booleans are sent as the strings 'true' / 'false'.
+ value = String(inputs[item.key]);
+ } else {
+ value = inputs[item.key]; // Other values are sent as stored in the form state.
+ }
+ return API.put('/api/option/', {
+ key: item.key,
+ value,
+ });
+ });
+ setLoading(true);
+ Promise.all(requestQueue)
+ .then((res) => {
+ if (requestQueue.length === 1) {
+ if (res.includes(undefined)) return; // Single request with an undefined result: stop without a success message.
+ } else if (requestQueue.length > 1) {
+ if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); // Several requests and at least one undefined result: report a partial failure.
+ }
+ showSuccess(t('保存成功'));
+ props.refresh(); // Invoke the refresh callback supplied by the parent.
+ })
+ .catch(() => {
+ showError(t('保存失败,请重试'));
+ })
+ .finally(() => {
+ setLoading(false);
+ });
+ }
+
+ useEffect(() => { // Re-sync the form whenever props.options changes.
+ const currentInputs = {};
+ for (let key in props.options) {
+ if (Object.keys(inputs).includes(key)) { // Keep only the keys already present in the local form state.
+ currentInputs[key] = props.options[key];
+ }
+ }
+ setInputs(currentInputs);
+ setInputsRow(structuredClone(currentInputs)); // Independent copy, so later edits to `inputs` do not alter the baseline.
+ refForm.current.setValues(currentInputs); // Push the values into the form through its ref.
+ }, [props.options]);
+
+ return ( // NOTE(review): the JSX element tags appear to be missing from this view; only the change callbacks of the four fields remain.
+ <>
+
+
+
+
+ {
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitEnabled: value,
+ });
+ }}
+ />
+
+
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitDurationMinutes: String(value),
+ })
+ }
+ />
+
+
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitCount: String(value),
+ })
+ }
+ />
+
+
+
+ setInputs({
+ ...inputs,
+ ModelRequestRateLimitSuccessCount: String(value),
+ })
+ }
+ />
+
+
+
+
+
+
+
+
+ >
+ );
+}
diff --git a/web/src/pages/Setting/index.js b/web/src/pages/Setting/index.js
index 385fbfeb..b5c5e268 100644
--- a/web/src/pages/Setting/index.js
+++ b/web/src/pages/Setting/index.js
@@ -8,6 +8,7 @@ import { isRoot } from '../../helpers';
import OtherSetting from '../../components/OtherSetting';
import PersonalSetting from '../../components/PersonalSetting';
import OperationSetting from '../../components/OperationSetting';
+import RateLimitSetting from '../../components/RateLimitSetting.js';
const Setting = () => {
const { t } = useTranslation();
@@ -28,6 +29,11 @@ const Setting = () => {
content: ,
itemKey: 'operation',
});
+ panes.push({
+ tab: t('速率限制设置'),
+ content: ,
+ itemKey: 'ratelimit',
+ });
panes.push({
tab: t('系统设置'),
content: ,