From 83a37e4653d30776b707ea604eb8972688c7ffeb Mon Sep 17 00:00:00 2001
From: "1808837298@qq.com" <1808837298@qq.com>
Date: Mon, 24 Feb 2025 16:20:55 +0800
Subject: [PATCH] feat: Add model request rate limiting functionality

---
 middleware/model-rate-limit.go                | 220 ++++++++++++++++++
 model/option.go                               |  13 ++
 router/relay-router.go                        |   1 +
 setting/rate_limit.go                         |   6 +
 web/src/components/RateLimitSetting.js        |  80 ++++++++
 .../RateLimit/SettingsRequestRateLimit.js     | 159 ++++++++++++++++
 web/src/pages/Setting/index.js                |   6 +
 7 files changed, 485 insertions(+)
 create mode 100644 middleware/model-rate-limit.go
 create mode 100644 setting/rate_limit.go
 create mode 100644 web/src/components/RateLimitSetting.js
 create mode 100644 web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js

diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
new file mode 100644
index 00000000..135e0005
--- /dev/null
+++ b/middleware/model-rate-limit.go
@@ -0,0 +1,220 @@
+package middleware
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/setting"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/go-redis/redis/v8"
+)
+
+// Key markers used to build the two per-user counters.
+const (
+	ModelRequestRateLimitCountMark        = "MRRL"
+	ModelRequestRateLimitSuccessCountMark = "MRRLS"
+)
+
+// memoryLimiterInit guards the one-time initialisation of the shared in-memory
+// limiter, since memoryRateLimitHandler is invoked once per request.
+var memoryLimiterInit sync.Once
+
+// initMemoryLimiter initialises the in-memory limiter with the common key expiration.
+func initMemoryLimiter() {
+	inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+}
+
+// totalLimitMessage builds the 429 message for the total-request limit.
+func totalLimitMessage(totalMaxCount int) string {
+	return fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)
+}
+
+// successLimitMessage builds the 429 message for the successful-request limit.
+func successLimitMessage(successMaxCount int) string {
+	return fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)
+}
+
+// checkRedisRateLimit reports whether one more request is allowed under key.
+//
+// The list at key holds one timestamp per recorded request, newest first (see
+// recordRedisRequest). A maxCount of 0 or less disables the limit. A request is
+// rejected only when the list already holds maxCount entries and the oldest of
+// them is less than duration seconds old.
+func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
+	if maxCount <= 0 {
+		return true, nil
+	}
+
+	length, err := rdb.LLen(ctx, key).Result()
+	if err != nil {
+		return false, err
+	}
+	if length < int64(maxCount) {
+		return true, nil
+	}
+
+	// The tail of the list is the oldest request still being tracked.
+	oldTimeStr, err := rdb.LIndex(ctx, key, -1).Result()
+	if errors.Is(err, redis.Nil) {
+		// The key vanished between LLen and LIndex: nothing left to count.
+		return true, nil
+	}
+	if err != nil {
+		return false, err
+	}
+	oldTime, err := time.Parse(timeFormat, oldTimeStr)
+	if err != nil {
+		return false, err
+	}
+
+	// Round-trip "now" through timeFormat so it is parsed exactly like the stored
+	// timestamps and the two values are directly comparable.
+	nowTime, err := time.Parse(timeFormat, time.Now().Format(timeFormat))
+	if err != nil {
+		return false, err
+	}
+
+	if int64(nowTime.Sub(oldTime).Seconds()) < duration {
+		// Still limited: refresh the TTL so the window is not dropped early.
+		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+		return false, nil
+	}
+	return true, nil
+}
+
+// recordRedisRequest pushes the current time onto the list at key, trims the
+// list to the newest maxCount entries and refreshes its TTL. It does nothing
+// when maxCount is 0 or less. Recording is best-effort: a Redis failure is
+// logged and the request is not failed because of it.
+func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
+	if maxCount <= 0 {
+		return
+	}
+
+	// Send the three commands in one round trip.
+	pipe := rdb.Pipeline()
+	pipe.LPush(ctx, key, time.Now().Format(timeFormat))
+	pipe.LTrim(ctx, key, 0, int64(maxCount-1))
+	pipe.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+	if _, err := pipe.Exec(ctx); err != nil {
+		fmt.Println("记录限流请求失败:", err.Error())
+	}
+}
+
+// redisRateLimitHandler returns a handler that enforces both limits for the
+// current user (context value "id") using Redis lists. Every admitted request
+// counts towards the total limit; only requests that finish with a status
+// below 400 count towards the success limit.
+func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		userID := strconv.Itoa(c.GetInt("id"))
+		ctx := context.Background()
+		rdb := common.RDB
+
+		// 1. Total-request limit (skipped inside the check when it is disabled).
+		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userID)
+		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查总请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, totalLimitMessage(totalMaxCount))
+			// Stop here so a rejected request is neither recorded nor answered twice.
+			return
+		}
+
+		// 2. Successful-request limit.
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userID)
+		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查成功请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, successLimitMessage(successMaxCount))
+			return
+		}
+
+		// 3. Count the admitted request towards the total limit.
+		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
+
+		// 4. Run the rest of the chain.
+		c.Next()
+
+		// 5. Count it towards the success limit only if it succeeded.
+		if c.Writer.Status() < 400 {
+			recordRedisRequest(ctx, rdb, successKey, successMaxCount)
+		}
+	}
+}
+
+// memoryRateLimitHandler is the in-process counterpart of
+// redisRateLimitHandler. A limit of 0 or less is treated as disabled, matching
+// the Redis path.
+func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	memoryLimiterInit.Do(initMemoryLimiter)
+
+	return func(c *gin.Context) {
+		userID := strconv.Itoa(c.GetInt("id"))
+		totalKey := ModelRequestRateLimitCountMark + userID
+		successKey := ModelRequestRateLimitSuccessCountMark + userID
+
+		// 1. Total-request limit.
+		if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, totalLimitMessage(totalMaxCount))
+			return
+		}
+
+		// 2. Successful-request limit, gated through a separate "_check" key.
+		// NOTE(review): Request presumably records the call as well as testing it,
+		// so this key would count every attempt, not only successes - confirm.
+		if successMaxCount > 0 && !inMemoryRateLimiter.Request(successKey+"_check", successMaxCount, duration) {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, successLimitMessage(successMaxCount))
+			return
+		}
+
+		// 3. Run the rest of the chain.
+		c.Next()
+
+		// 4. Count the request on the real success key only if it succeeded.
+		if successMaxCount > 0 && c.Writer.Status() < 400 {
+			inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
+		}
+	}
+}
+
+// ModelRequestRateLimit returns the per-user model request rate limiting
+// middleware. The setting package values are read on every request, so runtime
+// changes to them take effect without re-registering the router.
+func ModelRequestRateLimit() func(c *gin.Context) {
+	if !common.RedisEnabled {
+		// Prepare the in-memory limiter once, at router setup time.
+		memoryLimiterInit.Do(initMemoryLimiter)
+	}
+
+	return func(c *gin.Context) {
+		if !setting.ModelRequestRateLimitEnabled {
+			defNext(c)
+			return
+		}
+
+		duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
+		totalMaxCount := setting.ModelRequestRateLimitCount
+		successMaxCount := setting.ModelRequestRateLimitSuccessCount
+
+		newHandler := memoryRateLimitHandler
+		if common.RedisEnabled {
+			newHandler = redisRateLimitHandler
+		}
+		newHandler(duration, totalMaxCount, successMaxCount)(c)
+	}
+}
diff --git a/model/option.go b/model/option.go index 24935c69..3e9e9541 100644 --- a/model/option.go +++ b/model/option.go @@ -85,6 +85,9 @@ func InitOptionMap() { common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = 
strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() @@ -105,6 +108,7 @@ func InitOptionMap() { common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) @@ -226,6 +230,9 @@ func updateOptionMap(key string, value string) (err error) { setting.DemoSiteEnabled = boolValue case "CheckSensitiveOnPromptEnabled": setting.CheckSensitiveOnPromptEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue + //case "CheckSensitiveOnCompletionEnabled": // constant.CheckSensitiveOnCompletionEnabled = boolValue case "StopOnSensitiveEnabled": @@ -308,6 +315,12 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaRemindThreshold, _ = strconv.Atoi(value) case 
"ShouldPreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/router/relay-router.go b/router/relay-router.go index 63f5c36d..32e0c682 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) { } relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth()) + relayV1Router.Use(middleware.ModelRequestRateLimit()) { // WebSocket 路由 wsRouter := relayV1Router.Group("") diff --git a/setting/rate_limit.go b/setting/rate_limit.go new file mode 100644 index 00000000..4b216948 --- /dev/null +++ b/setting/rate_limit.go @@ -0,0 +1,6 @@ +package setting + +var ModelRequestRateLimitEnabled = false +var ModelRequestRateLimitDurationMinutes = 1 +var ModelRequestRateLimitCount = 0 +var ModelRequestRateLimitSuccessCount = 1000 diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js new file mode 100644 index 00000000..b6c92917 --- /dev/null +++ b/web/src/components/RateLimitSetting.js @@ -0,0 +1,80 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; +import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js'; +import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js'; +import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js'; +import SettingsLog from '../pages/Setting/Operation/SettingsLog.js'; +import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js'; +import 
SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js'; +import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js'; +import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js'; +import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js'; +import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js'; +import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js'; + + +import { API, showError, showSuccess } from '../helpers'; +import SettingsChats from '../pages/Setting/Operation/SettingsChats.js'; +import { useTranslation } from 'react-i18next'; +import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js'; + +const RateLimitSetting = () => { + const { t } = useTranslation(); + let [inputs, setInputs] = useState({ + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key.endsWith('Enabled') + ) { + newInputs[item.key] = item.value === 'true' ? 
true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* AI请求速率限制 */} + + + + + + ); +}; + +export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js new file mode 100644 index 00000000..6f4a5571 --- /dev/null +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -0,0 +1,159 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui'; +import { + compareObjects, + API, + showError, + showSuccess, + showWarning, +} from '../../../helpers'; +import { useTranslation } from 'react-i18next'; + +export default function RequestRateLimit(props) { + const { t } = useTranslation(); + + const [loading, setLoading] = useState(false); + const [inputs, setInputs] = useState({ + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1 + }); + const refForm = useRef(); + const [inputsRow, setInputsRow] = useState(inputs); + + function onSubmit() { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if 
(res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); + } + + useEffect(() => { + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); + }, [props.options]); + + return ( + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + + +
+
+ + ); +} diff --git a/web/src/pages/Setting/index.js b/web/src/pages/Setting/index.js index 385fbfeb..b5c5e268 100644 --- a/web/src/pages/Setting/index.js +++ b/web/src/pages/Setting/index.js @@ -8,6 +8,7 @@ import { isRoot } from '../../helpers'; import OtherSetting from '../../components/OtherSetting'; import PersonalSetting from '../../components/PersonalSetting'; import OperationSetting from '../../components/OperationSetting'; +import RateLimitSetting from '../../components/RateLimitSetting.js'; const Setting = () => { const { t } = useTranslation(); @@ -28,6 +29,11 @@ const Setting = () => { content: , itemKey: 'operation', }); + panes.push({ + tab: t('速率限制设置'), + content: , + itemKey: 'ratelimit', + }); panes.push({ tab: t('系统设置'), content: ,