From 6c3fb7777ec3fe4874b249251120e68b5e22642f Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 07:31:54 +0800 Subject: [PATCH 01/16] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E5=88=86?= =?UTF-8?q?=E7=BB=84=E9=80=9F=E7=8E=87=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 37 +++++++-- model/option.go | 47 +++++++++++- setting/rate_limit.go | 70 +++++++++++++++++ web/src/components/RateLimitSetting.js | 1 + .../RateLimit/SettingsRequestRateLimit.js | 76 +++++++++++++++++-- 5 files changed, 214 insertions(+), 17 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..d4199ece 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -168,16 +168,39 @@ func ModelRequestRateLimit() func(c *gin.Context) { return } - // 计算限流参数 + // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - totalMaxCount := setting.ModelRequestRateLimitCount - successMaxCount := setting.ModelRequestRateLimitSuccessCount - // 根据存储类型选择并执行限流处理器 - if common.RedisEnabled { - redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + // 获取用户组 + group := c.GetString("token_group") + if group == "" { + group = c.GetString("group") + } + if group == "" { + group = "default" // 默认组 + } + + // 尝试获取用户组特定的限制 + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) + + // 确定最终的限制值 + finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 + + if found { + // 如果找到用户组特定限制,则使用它们 + finalTotalCount = groupTotalCount + finalSuccessCount = groupSuccessCount + common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } else { - memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) + common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + } + + // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 + if common.RedisEnabled { + redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + } else { + memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } } } diff --git a/model/option.go b/model/option.go index d575742f..1f5fb3aa 100644 --- a/model/option.go +++ b/model/option.go @@ -1,6 +1,8 @@ package model import ( + "encoding/json" + "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -96,6 +98,7 @@ func InitOptionMap() { common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -150,7 +153,32 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - // Save to database first + originalValue := value // 保存原始值以备后用 + + // Validate and format specific keys before saving + if key == setting.ModelRequestRateLimitGroupKey { + var cfg map[string][2]int + // Validate the JSON structure first using the original value + err := json.Unmarshal([]byte(originalValue), &cfg) + if err != nil { + // 提供更具体的错误信息 + return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) + } + // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 + // if !isValidModelRequestRateLimitGroupConfig(cfg) { + // return fmt.Errorf("无效的配置值 for %s", key) + // } + + // If valid, format the JSON before saving + formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") + if marshalErr != nil { + // This should ideally not happen if validation passed, but handle defensively + return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) + } + value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + } + + // Save to database option := Option{ Key: key, } @@ -160,8 +188,12 @@ func UpdateOption(key string, value string) error { // Save is a combination function. // If save value does not contain primary key, it will execute Create, // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap + if err := DB.Save(&option).Error; err != nil { + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + } + + // Update OptionMap in memory using the potentially formatted value + // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -372,6 +404,15 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + case setting.ModelRequestRateLimitGroupKey: + // Use the (potentially formatted) value passed from UpdateOption + // to update the actual configuration in memory. + // This is the single point where the memory state for this specific setting is updated. + err = setting.UpdateModelRequestRateLimitGroupConfig(value) + if err != nil { + // 添加错误上下文 + err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) + } } return err } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..c83885a6 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,76 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 + +// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 +const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" + +// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 +// map[groupName][2]int{totalCount, successCount} +var ModelRequestRateLimitGroupConfig map[string][2]int +var ModelRequestRateLimitGroupMutex sync.RWMutex + +// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 +func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { + ModelRequestRateLimitGroupMutex.Lock() + defer ModelRequestRateLimitGroupMutex.Unlock() + + var newConfig map[string][2]int + if jsonStr == "" || jsonStr == "{}" { + // 如果配置为空或空JSON对象,则清空内存配置 + ModelRequestRateLimitGroupConfig = make(map[string][2]int) + common.SysLog("Model request rate limit group config cleared") + return nil + } + + err := json.Unmarshal([]byte(jsonStr), &newConfig) + if err != nil { + return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + } + + // 校验配置值 + for group, limits := range newConfig { + if len(limits) != 2 { + return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) + } + if limits[1] <= 0 { // successCount must be greater than 0 + return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) + } + if limits[0] < 0 { // totalCount can be 0 (no limit) or positive + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) + } + if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount + return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) + } + } + + ModelRequestRateLimitGroupConfig = newConfig + common.SysLog("Model request rate limit group config updated") + return nil +} + +// GetGroupRateLimit 安全地获取指定用户组的速率限制值 +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitGroupMutex.RLock() + defer ModelRequestRateLimitGroupMutex.RUnlock() + + if ModelRequestRateLimitGroupConfig == nil { + return 0, 0, false // 配置尚未初始化 + } + + limits, found := ModelRequestRateLimitGroupConfig[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index e06038d6..ad6b53da 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,6 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: {}, }); let [loading, setLoading] = useState(false); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..ec1c2158 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -18,6 +18,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -32,25 +33,49 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } + // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + // 阻止请求发送 + return Promise.reject('Invalid JSON format'); + } + } return API.put('/api/option/', { key: item.key, value, }); }); + + // 过滤掉无效的请求(例如,无效的 JSON) + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 + return; + } + setLoading(true); - Promise.all(requestQueue) + Promise.all(validRequests) .then((res) => { - if (requestQueue.length === 1) { + if (validRequests.length === 1) { if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { + } else if (validRequests.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } showSuccess(t('保存成功')); props.refresh(); + // 更新 inputsRow 以反映保存后的状态 + setInputsRow(structuredClone(inputs)); }) - .catch(() => { - showError(t('保存失败,请重试')); + .catch((error) => { + // 检查是否是由于无效 JSON 导致的错误 + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } }) .finally(() => { setLoading(false); @@ -66,8 +91,11 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); - }, [props.options]); + // 检查 refForm.current 是否存在 + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 return ( <> @@ -147,7 +175,41 @@ export default function RequestRateLimit(props) { /> + {/* 用户组速率限制配置项 */} + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, // 直接更新字符串值 + }); + }} + /> + +
+ From 7e7d6112ca460be5c30a6c89fb4165346a6d5651 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 11:34:57 +0800 Subject: [PATCH 02/16] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E5=8E=BB=E9=99=A4=E5=A4=9A=E4=BD=99=E6=B3=A8?= =?UTF-8?q?=E9=87=8A=E5=92=8C=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .old/option.go | 402 ++++++++++++++++++ middleware/model-rate-limit.go | 43 +- model/option.go | 40 +- setting/rate_limit.go | 39 +- web/src/components/RateLimitSetting.js | 92 ++-- .../RateLimit/SettingsRequestRateLimit.js | 388 +++++++++-------- 6 files changed, 663 insertions(+), 341 deletions(-) create mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go new file mode 100644 index 00000000..f80f5cb3 --- /dev/null +++ b/.old/option.go @@ -0,0 +1,402 @@ +package model + +import ( + "one-api/common" + "one-api/setting" + "one-api/setting/config" + "one-api/setting/operation_setting" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + common.OptionMapRWMutex.Lock() + common.OptionMap = make(map[string]string) + + // 添加原有的系统配置 + common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) + common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) + common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) + common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) + common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) + common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) + common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) + common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) + common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) + common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) + common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) + common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) + common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) + common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) + common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) + common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) + common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) + common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) + common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") + common.OptionMap["SMTPServer"] = "" + common.OptionMap["SMTPFrom"] = "" + common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) + common.OptionMap["SMTPAccount"] = "" + common.OptionMap["SMTPToken"] = "" + common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) + common.OptionMap["Notice"] = "" + common.OptionMap["About"] = "" + common.OptionMap["HomePageContent"] = "" + common.OptionMap["Footer"] = common.Footer + common.OptionMap["SystemName"] = common.SystemName + common.OptionMap["Logo"] = common.Logo + common.OptionMap["ServerAddress"] = "" + common.OptionMap["WorkerUrl"] = setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey + common.OptionMap["PayAddress"] = "" + common.OptionMap["CustomCallbackAddress"] = "" + common.OptionMap["EpayId"] = "" + common.OptionMap["EpayKey"] = "" + common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) + common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() + common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["GitHubClientId"] = "" + common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["TelegramBotToken"] = "" + common.OptionMap["TelegramBotName"] = "" + common.OptionMap["WeChatServerAddress"] = "" + common.OptionMap["WeChatServerToken"] = "" + common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["TurnstileSiteKey"] = "" + common.OptionMap["TurnstileSecretKey"] = "" + common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) + common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) + common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) + common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) + common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) + common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() + common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() + common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() + common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() + common.OptionMap["TopUpLink"] = common.TopUpLink + //common.OptionMap["ChatLink"] = common.ChatLink + //common.OptionMap["ChatLink2"] = common.ChatLink2 + common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) + common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime + common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) + common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) + common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) + common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) + common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) + common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) + common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) + common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) + common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) + common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) + common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) + common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() + common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) + common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + + // 自动添加所有注册的模型配置 + modelConfigs := config.GlobalConfig.ExportAllConfigs() + for k, v := range modelConfigs { + common.OptionMap[k] = v + } + + common.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + err := updateOptionMap(option.Key, option.Value) + if err != nil { + common.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + common.OptionMap[key] = value + + // 检查是否是模型配置 - 使用更规范的方式处理 + if handleConfigUpdate(key, value) { + return nil // 已由配置系统处理 + } + + // 处理传统配置项... + if strings.HasSuffix(key, "Permission") { + intValue, _ := strconv.Atoi(value) + switch key { + case "FileUploadPermission": + common.FileUploadPermission = intValue + case "FileDownloadPermission": + common.FileDownloadPermission = intValue + case "ImageUploadPermission": + common.ImageUploadPermission = intValue + case "ImageDownloadPermission": + common.ImageDownloadPermission = intValue + } + } + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + common.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + common.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + common.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + common.GitHubOAuthEnabled = boolValue + case "LinuxDOOAuthEnabled": + common.LinuxDOOAuthEnabled = boolValue + case "WeChatAuthEnabled": + common.WeChatAuthEnabled = boolValue + case "TelegramOAuthEnabled": + common.TelegramOAuthEnabled = boolValue + case "TurnstileCheckEnabled": + common.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + common.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + common.EmailDomainRestrictionEnabled = boolValue + case "EmailAliasRestrictionEnabled": + common.EmailAliasRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue + case "LogConsumeEnabled": + common.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + common.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + common.DisplayTokenStatEnabled = boolValue + case "DrawingEnabled": + common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue + case "DataExportEnabled": + common.DataExportEnabled = boolValue + case "DefaultCollapseSidebar": + common.DefaultCollapseSidebar = boolValue + case "MjNotifyEnabled": + setting.MjNotifyEnabled = boolValue + case "MjAccountFilterEnabled": + setting.MjAccountFilterEnabled = boolValue + case "MjModeClearEnabled": + setting.MjModeClearEnabled = boolValue + case "MjForwardUrlEnabled": + setting.MjForwardUrlEnabled = boolValue + case "MjActionCheckSuccessEnabled": + setting.MjActionCheckSuccessEnabled = boolValue + case "CheckSensitiveEnabled": + setting.CheckSensitiveEnabled = boolValue + case "DemoSiteEnabled": + operation_setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + operation_setting.SelfUseModeEnabled = boolValue + case "CheckSensitiveOnPromptEnabled": + setting.CheckSensitiveOnPromptEnabled = boolValue + case "ModelRequestRateLimitEnabled": + setting.ModelRequestRateLimitEnabled = boolValue + case "StopOnSensitiveEnabled": + setting.StopOnSensitiveEnabled = boolValue + case "SMTPSSLEnabled": + common.SMTPSSLEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + common.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + common.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + common.SMTPPort = intValue + case "SMTPAccount": + common.SMTPAccount = value + case "SMTPFrom": + common.SMTPFrom = value + case "SMTPToken": + common.SMTPToken = value + case "ServerAddress": + setting.ServerAddress = value + case "WorkerUrl": + setting.WorkerUrl = value + case "WorkerValidKey": + setting.WorkerValidKey = value + case "PayAddress": + setting.PayAddress = value + case "Chats": + err = setting.UpdateChatsByJsonString(value) + case "CustomCallbackAddress": + setting.CustomCallbackAddress = value + case "EpayId": + setting.EpayId = value + case "EpayKey": + setting.EpayKey = value + case "Price": + setting.Price, _ = strconv.ParseFloat(value, 64) + case "MinTopUp": + setting.MinTopUp, _ = strconv.Atoi(value) + case "TopupGroupRatio": + err = common.UpdateTopupGroupRatioByJSONString(value) + case "GitHubClientId": + common.GitHubClientId = value + case "GitHubClientSecret": + common.GitHubClientSecret = value + case "LinuxDOClientId": + common.LinuxDOClientId = value + case "LinuxDOClientSecret": + common.LinuxDOClientSecret = value + case "Footer": + common.Footer = value + case "SystemName": + common.SystemName = value + case "Logo": + common.Logo = value + case "WeChatServerAddress": + common.WeChatServerAddress = value + case "WeChatServerToken": + common.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + common.WeChatAccountQRCodeImageURL = value + case "TelegramBotToken": + common.TelegramBotToken = value + case "TelegramBotName": + common.TelegramBotName = value + case "TurnstileSiteKey": + common.TurnstileSiteKey = value + case "TurnstileSecretKey": + common.TurnstileSecretKey = value + case "QuotaForNewUser": + common.QuotaForNewUser, _ = strconv.Atoi(value) + case "QuotaForInviter": + common.QuotaForInviter, _ = strconv.Atoi(value) + case "QuotaForInvitee": + common.QuotaForInvitee, _ = strconv.Atoi(value) + case "QuotaRemindThreshold": + common.QuotaRemindThreshold, _ = strconv.Atoi(value) + case "PreConsumedQuota": + common.PreConsumedQuota, _ = strconv.Atoi(value) + case "ModelRequestRateLimitCount": + setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitDurationMinutes": + setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) + case "ModelRequestRateLimitSuccessCount": + setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) + case "DataExportInterval": + common.DataExportInterval, _ = strconv.Atoi(value) + case "DataExportDefaultTime": + common.DataExportDefaultTime = value + case "ModelRatio": + err = operation_setting.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = setting.UpdateGroupRatioByJSONString(value) + case "UserUsableGroups": + err = setting.UpdateUserUsableGroupsByJSONString(value) + case "CompletionRatio": + err = operation_setting.UpdateCompletionRatioByJSONString(value) + case "ModelPrice": + err = operation_setting.UpdateModelPriceByJSONString(value) + case "CacheRatio": + err = operation_setting.UpdateCacheRatioByJSONString(value) + case "TopUpLink": + common.TopUpLink = value + //case "ChatLink": + // common.ChatLink = value + //case "ChatLink2": + // common.ChatLink2 = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + case "SensitiveWords": + setting.SensitiveWordsFromString(value) + case "AutomaticDisableKeywords": + operation_setting.AutomaticDisableKeywordsFromString(value) + case "StreamCacheQueueLength": + setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + } + return err +} + +// handleConfigUpdate 处理分层配置更新,返回是否已处理 +func handleConfigUpdate(key, value string) bool { + parts := strings.SplitN(key, ".", 2) + if len(parts) != 2 { + return false // 不是分层配置 + } + + configName := parts[0] + configKey := parts[1] + + // 获取配置对象 + cfg := config.GlobalConfig.Get(configName) + if cfg == nil { + return false // 未注册的配置 + } + + // 更新配置 + configMap := map[string]string{ + configKey: value, + } + config.UpdateConfigFromMap(cfg, configMap) + + return true // 已处理 +} \ No newline at end of file diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index d4199ece..b0047b70 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -19,25 +19,20 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) -// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { - // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } - // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } - // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } - // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -49,7 +44,6 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } - // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -59,9 +53,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } -// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { - // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -72,14 +64,12 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } -// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB - // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -92,9 +82,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } - //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -114,17 +102,14 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } - // 4. 处理请求 c.Next() - // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } -// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -133,15 +118,12 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId - // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } - // 2. 检查成功请求数限制 - // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -149,54 +131,47 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } - // 3. 处理请求 c.Next() - // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } -// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { - // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } - // 计算通用限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) - // 获取用户组 group := c.GetString("token_group") if group == "" { group = c.GetString("group") } if group == "" { - group = "default" // 默认组 + group = "default" } - // 尝试获取用户组特定的限制 + finalTotalCount := setting.ModelRequestRateLimitCount + finalSuccessCount := setting.ModelRequestRateLimitSuccessCount + foundGroupLimit := false + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) - - // 确定最终的限制值 - finalTotalCount := setting.ModelRequestRateLimitCount // 默认使用全局总次数限制 - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount // 默认使用全局成功次数限制 - if found { - // 如果找到用户组特定限制,则使用它们 finalTotalCount = groupTotalCount finalSuccessCount = groupSuccessCount + foundGroupLimit = true common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } else { + } + + if !foundGroupLimit { common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) } - // 根据存储类型选择并执行限流处理器,传入最终确定的限制值 if common.RedisEnabled { redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) } else { diff --git a/model/option.go b/model/option.go index 1f5fb3aa..79556737 100644 --- a/model/option.go +++ b/model/option.go @@ -94,11 +94,12 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + jsonBytes, _ := json.Marshal(map[string][2]int{}) + common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap[setting.ModelRequestRateLimitGroupKey] = "{}" // 添加用户组速率限制默认值 common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink @@ -153,47 +154,31 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value // 保存原始值以备后用 + originalValue := value - // Validate and format specific keys before saving - if key == setting.ModelRequestRateLimitGroupKey { + if key == "ModelRequestRateLimitGroup" { var cfg map[string][2]int - // Validate the JSON structure first using the original value err := json.Unmarshal([]byte(originalValue), &cfg) if err != nil { - // 提供更具体的错误信息 return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) } - // TODO: 可以添加更细致的结构验证,例如检查数组长度是否为2,值是否为非负数等。 - // if !isValidModelRequestRateLimitGroupConfig(cfg) { - // return fmt.Errorf("无效的配置值 for %s", key) - // } - // If valid, format the JSON before saving formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") if marshalErr != nil { - // This should ideally not happen if validation passed, but handle defensively return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) } - value = string(formattedValueBytes) // Use formatted JSON for saving and memory update + value = string(formattedValueBytes) } - // Save to database option := Option{ Key: key, } - // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // 添加错误上下文 + return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) } - // Update OptionMap in memory using the potentially formatted value - // updateOptionMap 会处理内存中 setting.ModelRequestRateLimitGroupConfig 的更新 return updateOptionMap(key, value) } @@ -370,6 +355,8 @@ func updateOptionMap(key string, value string) (err error) { setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitGroup": + err = setting.UpdateModelRequestRateLimitGroup(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": @@ -404,15 +391,6 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - case setting.ModelRequestRateLimitGroupKey: - // Use the (potentially formatted) value passed from UpdateOption - // to update the actual configuration in memory. - // This is the single point where the memory state for this specific setting is updated. - err = setting.UpdateModelRequestRateLimitGroupConfig(value) - if err != nil { - // 添加错误上下文 - err = fmt.Errorf("更新内存中的 %s 配置失败: %w", key, err) - } } return err } @@ -440,4 +418,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} +} \ No newline at end of file diff --git a/setting/rate_limit.go b/setting/rate_limit.go index c83885a6..5be75cc1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -11,24 +11,17 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup map[string][2]int -// ModelRequestRateLimitGroupKey 定义了模型请求按组速率限制的配置键 -const ModelRequestRateLimitGroupKey = "ModelRequestRateLimitGroup" - -// ModelRequestRateLimitGroupConfig 存储按用户组解析后的速率限制配置 -// map[groupName][2]int{totalCount, successCount} -var ModelRequestRateLimitGroupConfig map[string][2]int var ModelRequestRateLimitGroupMutex sync.RWMutex -// UpdateModelRequestRateLimitGroupConfig 解析、校验并更新内存中的用户组速率限制配置 -func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { +func UpdateModelRequestRateLimitGroup(jsonStr string) error { ModelRequestRateLimitGroupMutex.Lock() defer ModelRequestRateLimitGroupMutex.Unlock() var newConfig map[string][2]int if jsonStr == "" || jsonStr == "{}" { - // 如果配置为空或空JSON对象,则清空内存配置 - ModelRequestRateLimitGroupConfig = make(map[string][2]int) + ModelRequestRateLimitGroup = make(map[string][2]int) common.SysLog("Model request rate limit group config cleared") return nil } @@ -38,37 +31,19 @@ func UpdateModelRequestRateLimitGroupConfig(jsonStr string) error { return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) } - // 校验配置值 - for group, limits := range newConfig { - if len(limits) != 2 { - return fmt.Errorf("invalid config for group '%s': limits array length must be 2", group) - } - if limits[1] <= 0 { // successCount must be greater than 0 - return fmt.Errorf("invalid config for group '%s': successCount (limits[1]) must be greater than 0", group) - } - if limits[0] < 0 { // totalCount can be 0 (no limit) or positive - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) cannot be negative", group) - } - if limits[0] > 0 && limits[0] < limits[1] { // If totalCount is set, it must be >= successCount - return fmt.Errorf("invalid config for group '%s': totalCount (limits[0]) must be greater than or equal to successCount (limits[1]) when totalCount > 0", group) - } - } - - ModelRequestRateLimitGroupConfig = newConfig - common.SysLog("Model request rate limit group config updated") + ModelRequestRateLimitGroup = newConfig return nil } -// GetGroupRateLimit 安全地获取指定用户组的速率限制值 func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { ModelRequestRateLimitGroupMutex.RLock() defer ModelRequestRateLimitGroupMutex.RUnlock() - if ModelRequestRateLimitGroupConfig == nil { - return 0, 0, false // 配置尚未初始化 + if ModelRequestRateLimitGroup == nil { + return 0, 0, false } - limits, found := ModelRequestRateLimitGroupConfig[group] + limits, found := ModelRequestRateLimitGroup[group] if !found { return 0, 0, false } diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index ad6b53da..7e206672 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,59 +9,59 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: {}, + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); - + let [loading, setLoading] = useState(false); - + const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + // 检查 key 是否在初始 inputs 中定义 + if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true'; + } else { + newInputs[item.key] = item.value; + } + } + }); + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } - + useEffect(() => { - onRefresh(); + onRefresh(); }, []); - + return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + + + + + ); -}; - -export default RateLimitSetting; + }; + + export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ec1c2158..2434020e 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -14,209 +14,201 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', // 添加新字段并设置默认值 + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '{}', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - // 校验 ModelRequestRateLimitGroup 是否为有效的 JSON 对象字符串 - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - // 阻止请求发送 - return Promise.reject('Invalid JSON format'); - } - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - - // 过滤掉无效的请求(例如,无效的 JSON) - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - // 如果所有请求都被过滤掉了(因为 JSON 无效),则不继续执行 - return; - } - - setLoading(true); - Promise.all(validRequests) - .then((res) => { - if (validRequests.length === 1) { - if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } - showSuccess(t('保存成功')); - props.refresh(); - // 更新 inputsRow 以反映保存后的状态 - setInputsRow(structuredClone(inputs)); - }) - .catch((error) => { - // 检查是否是由于无效 JSON 导致的错误 - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } - }) - .finally(() => { - setLoading(false); - }); + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + if (item.key === 'ModelRequestRateLimitGroup') { + try { + JSON.parse(value); + } catch (e) { + showError(t('用户组速率限制配置不是有效的 JSON 格式!')); + return Promise.reject('Invalid JSON format'); + } + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + + const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); + + if (validRequests.length === 0 && requestQueue.length > 0) { + return; + } + + setLoading(true); + Promise.all(validRequests) + .then((res) => { + if (validRequests.length === 1) { + if (res.includes(undefined)) return; + } else if (validRequests.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + setInputsRow(structuredClone(inputs)); + }) + .catch((error) => { + if (error !== 'Invalid JSON format') { + showError(t('保存失败,请重试')); + } + }) + .finally(() => { + setLoading(false); + }); } - + useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - // 检查 refForm.current 是否存在 - if (refForm.current) { - refForm.current.setValues(currentInputs); - } - }, [props.options]); // 依赖项保持不变,因为 inputs 状态的结构已固定 - + const currentInputs = {}; + for (let key in props.options) { + if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + if (refForm.current) { + refForm.current.setValues(currentInputs); + } + }, [props.options]); + return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - {/* 用户组速率限制配置项 */} - - - -

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • -
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • -
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • -
  • {t('此配置将优先于上方的全局限制设置。')}
  • -
  • {t('未在此处配置的用户组将使用全局限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
  • {t('输入无效的 JSON 将无法保存。')}
  • -
- - } - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} - onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, // 直接更新字符串值 - }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + +

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • +
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • +
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • +
  • {t('此配置将优先于上方的全局限制设置。')}
  • +
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
  • {t('输入无效的 JSON 将无法保存。')}
  • +
+ + } + autosize={{ minRows: 5, maxRows: 15 }} + style={{ fontFamily: 'monospace' }} + onChange={(value) => { + setInputs({ + ...inputs, + ModelRequestRateLimitGroup: value, + }); + }} + /> + +
+ + + +
+
+
+ ); -} + } From 1e1d24d1b075042473902991cbc3610f6c8bfff8 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 17:57:02 +0800 Subject: [PATCH 03/16] fix: rm debug file --- .old/option.go | 402 ------------------------------------------------- 1 file changed, 402 deletions(-) delete mode 100644 .old/option.go diff --git a/.old/option.go b/.old/option.go deleted file mode 100644 index f80f5cb3..00000000 --- a/.old/option.go +++ /dev/null @@ -1,402 +0,0 @@ -package model - -import ( - "one-api/common" - "one-api/setting" - "one-api/setting/config" - "one-api/setting/operation_setting" - "strconv" - "strings" - "time" -) - -type Option struct { - Key string `json:"key" gorm:"primaryKey"` - Value string `json:"value"` -} - -func AllOption() ([]*Option, error) { - var options []*Option - var err error - err = DB.Find(&options).Error - return options, err -} - -func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - - // 添加原有的系统配置 - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled) - common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) - common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) - common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled) - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["WorkerUrl"] = setting.WorkerUrl - common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey - common.OptionMap["PayAddress"] = "" - common.OptionMap["CustomCallbackAddress"] = "" - common.OptionMap["EpayId"] = "" - common.OptionMap["EpayKey"] = "" - common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) - common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) - common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() - common.OptionMap["Chats"] = setting.Chats2JsonString() - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["TelegramBotToken"] = "" - common.OptionMap["TelegramBotName"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) - common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) - common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() - common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() - common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - //common.OptionMap["ChatLink"] = common.ChatLink - //common.OptionMap["ChatLink2"] = common.ChatLink2 - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) - common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime - common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) - common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled) - common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled) - common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled) - common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled) - common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) - common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) - common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled) - common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled) - common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) - common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) - common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) - common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() - common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) - common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() - - // 自动添加所有注册的模型配置 - modelConfigs := config.GlobalConfig.ExportAllConfigs() - for k, v := range modelConfigs { - common.OptionMap[k] = v - } - - common.OptionMapRWMutex.Unlock() - loadOptionsFromDatabase() -} - -func loadOptionsFromDatabase() { - options, _ := AllOption() - for _, option := range options { - err := updateOptionMap(option.Key, option.Value) - if err != nil { - common.SysError("failed to update option map: " + err.Error()) - } - } -} - -func SyncOptions(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") - loadOptionsFromDatabase() - } -} - -func UpdateOption(key string, value string) error { - // Save to database first - option := Option{ - Key: key, - } - // https://gorm.io/docs/update.html#Save-All-Fields - DB.FirstOrCreate(&option, Option{Key: key}) - option.Value = value - // Save is a combination function. - // If save value does not contain primary key, it will execute Create, - // otherwise it will execute Update (with all fields). - DB.Save(&option) - // Update OptionMap - return updateOptionMap(key, value) -} - -func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - - // 检查是否是模型配置 - 使用更规范的方式处理 - if handleConfigUpdate(key, value) { - return nil // 已由配置系统处理 - } - - // 处理传统配置项... - if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { - boolValue := value == "true" - switch key { - case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue - case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue - case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue - case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue - case "LinuxDOOAuthEnabled": - common.LinuxDOOAuthEnabled = boolValue - case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue - case "TelegramOAuthEnabled": - common.TelegramOAuthEnabled = boolValue - case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue - case "RegisterEnabled": - common.RegisterEnabled = boolValue - case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue - case "EmailAliasRestrictionEnabled": - common.EmailAliasRestrictionEnabled = boolValue - case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue - case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue - case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue - case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue - case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue - case "DrawingEnabled": - common.DrawingEnabled = boolValue - case "TaskEnabled": - common.TaskEnabled = boolValue - case "DataExportEnabled": - common.DataExportEnabled = boolValue - case "DefaultCollapseSidebar": - common.DefaultCollapseSidebar = boolValue - case "MjNotifyEnabled": - setting.MjNotifyEnabled = boolValue - case "MjAccountFilterEnabled": - setting.MjAccountFilterEnabled = boolValue - case "MjModeClearEnabled": - setting.MjModeClearEnabled = boolValue - case "MjForwardUrlEnabled": - setting.MjForwardUrlEnabled = boolValue - case "MjActionCheckSuccessEnabled": - setting.MjActionCheckSuccessEnabled = boolValue - case "CheckSensitiveEnabled": - setting.CheckSensitiveEnabled = boolValue - case "DemoSiteEnabled": - operation_setting.DemoSiteEnabled = boolValue - case "SelfUseModeEnabled": - operation_setting.SelfUseModeEnabled = boolValue - case "CheckSensitiveOnPromptEnabled": - setting.CheckSensitiveOnPromptEnabled = boolValue - case "ModelRequestRateLimitEnabled": - setting.ModelRequestRateLimitEnabled = boolValue - case "StopOnSensitiveEnabled": - setting.StopOnSensitiveEnabled = boolValue - case "SMTPSSLEnabled": - common.SMTPSSLEnabled = boolValue - } - } - switch key { - case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") - case "SMTPServer": - common.SMTPServer = value - case "SMTPPort": - intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue - case "SMTPAccount": - common.SMTPAccount = value - case "SMTPFrom": - common.SMTPFrom = value - case "SMTPToken": - common.SMTPToken = value - case "ServerAddress": - setting.ServerAddress = value - case "WorkerUrl": - setting.WorkerUrl = value - case "WorkerValidKey": - setting.WorkerValidKey = value - case "PayAddress": - setting.PayAddress = value - case "Chats": - err = setting.UpdateChatsByJsonString(value) - case "CustomCallbackAddress": - setting.CustomCallbackAddress = value - case "EpayId": - setting.EpayId = value - case "EpayKey": - setting.EpayKey = value - case "Price": - setting.Price, _ = strconv.ParseFloat(value, 64) - case "MinTopUp": - setting.MinTopUp, _ = strconv.Atoi(value) - case "TopupGroupRatio": - err = common.UpdateTopupGroupRatioByJSONString(value) - case "GitHubClientId": - common.GitHubClientId = value - case "GitHubClientSecret": - common.GitHubClientSecret = value - case "LinuxDOClientId": - common.LinuxDOClientId = value - case "LinuxDOClientSecret": - common.LinuxDOClientSecret = value - case "Footer": - common.Footer = value - case "SystemName": - common.SystemName = value - case "Logo": - common.Logo = value - case "WeChatServerAddress": - common.WeChatServerAddress = value - case "WeChatServerToken": - common.WeChatServerToken = value - case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value - case "TelegramBotToken": - common.TelegramBotToken = value - case "TelegramBotName": - common.TelegramBotName = value - case "TurnstileSiteKey": - common.TurnstileSiteKey = value - case "TurnstileSecretKey": - common.TurnstileSecretKey = value - case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) - case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) - case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) - case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) - case "ModelRequestRateLimitCount": - setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) - case "ModelRequestRateLimitDurationMinutes": - setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) - case "ModelRequestRateLimitSuccessCount": - setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) - case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) - case "DataExportInterval": - common.DataExportInterval, _ = strconv.Atoi(value) - case "DataExportDefaultTime": - common.DataExportDefaultTime = value - case "ModelRatio": - err = operation_setting.UpdateModelRatioByJSONString(value) - case "GroupRatio": - err = setting.UpdateGroupRatioByJSONString(value) - case "UserUsableGroups": - err = setting.UpdateUserUsableGroupsByJSONString(value) - case "CompletionRatio": - err = operation_setting.UpdateCompletionRatioByJSONString(value) - case "ModelPrice": - err = operation_setting.UpdateModelPriceByJSONString(value) - case "CacheRatio": - err = operation_setting.UpdateCacheRatioByJSONString(value) - case "TopUpLink": - common.TopUpLink = value - //case "ChatLink": - // common.ChatLink = value - //case "ChatLink2": - // common.ChatLink2 = value - case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) - case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) - case "SensitiveWords": - setting.SensitiveWordsFromString(value) - case "AutomaticDisableKeywords": - operation_setting.AutomaticDisableKeywordsFromString(value) - case "StreamCacheQueueLength": - setting.StreamCacheQueueLength, _ = strconv.Atoi(value) - } - return err -} - -// handleConfigUpdate 处理分层配置更新,返回是否已处理 -func handleConfigUpdate(key, value string) bool { - parts := strings.SplitN(key, ".", 2) - if len(parts) != 2 { - return false // 不是分层配置 - } - - configName := parts[0] - configKey := parts[1] - - // 获取配置对象 - cfg := config.GlobalConfig.Get(configName) - if cfg == nil { - return false // 未注册的配置 - } - - // 更新配置 - configMap := map[string]string{ - configKey: value, - } - config.UpdateConfigFromMap(cfg, configMap) - - return true // 已处理 -} \ No newline at end of file From 1513ed78477044999e066d5eb3b1fc1762dce531 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 19:32:22 +0800 Subject: [PATCH 04/16] =?UTF-8?q?refactor:=20=E8=B0=83=E6=95=B4=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E7=AC=A6=E5=90=88=E9=A1=B9=E7=9B=AE=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 54 +++++++++++++++++--------- model/option.go | 34 +++++----------- setting/rate_limit.go | 37 ++++++++---------- web/src/components/RateLimitSetting.js | 6 ++- 4 files changed, 65 insertions(+), 66 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index b0047b70..1ca5ace6 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/limiter" + "one-api/constant" "one-api/setting" "strconv" "time" @@ -19,20 +20,25 @@ const ( ModelRequestRateLimitSuccessCountMark = "MRRLS" ) +// 检查Redis中的请求限制 func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { + // 如果maxCount为0,表示不限制 if maxCount == 0 { return true, nil } + // 获取当前计数 length, err := rdb.LLen(ctx, key).Result() if err != nil { return false, err } + // 如果未达到限制,允许请求 if length < int64(maxCount) { return true, nil } + // 检查时间窗口 oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { @@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max if err != nil { return false, err } + // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) @@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max return true, nil } +// 记录Redis请求 func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { + // 如果maxCount为0,不记录请求 if maxCount == 0 { return } @@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } +// Redis限流处理器 func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) ctx := context.Background() rdb := common.RDB + // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { @@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 tb := limiter.New(ctx, rdb) allowed, err = tb.Allow( ctx, @@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) } + // 4. 处理请求 c.Next() + // 5. 如果请求成功,记录成功请求 if c.Writer.Status() < 400 { recordRedisRequest(ctx, rdb, successKey, successMaxCount) } } } +// 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) @@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) totalKey := ModelRequestRateLimitCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId + // 1. 检查总请求数限制(当totalMaxCount为0时跳过) if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } + // 2. 检查成功请求数限制 + // 使用一个临时key来检查限制,这样可以避免实际记录 checkKey := successKey + "_check" if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { c.Status(http.StatusTooManyRequests) @@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) return } + // 3. 处理请求 c.Next() + // 4. 如果请求成功,记录到实际的成功请求计数中 if c.Writer.Status() < 400 { inMemoryRateLimiter.Request(successKey, successMaxCount, duration) } } } +// ModelRequestRateLimit 模型请求限流中间件 func ModelRequestRateLimit() func(c *gin.Context) { return func(c *gin.Context) { + // 在每个请求时检查是否启用限流 if !setting.ModelRequestRateLimitEnabled { c.Next() return } + // 计算限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) + totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 获取分组 group := c.GetString("token_group") if group == "" { - group = c.GetString("group") - } - if group == "" { - group = "default" + group = c.GetString(constant.ContextKeyUserGroup) } - finalTotalCount := setting.ModelRequestRateLimitCount - finalSuccessCount := setting.ModelRequestRateLimitSuccessCount - foundGroupLimit := false - + //获取分组的限流配置 groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) if found { - finalTotalCount = groupTotalCount - finalSuccessCount = groupSuccessCount - foundGroupLimit = true - common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) - } - - if !foundGroupLimit { - common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount)) + totalMaxCount = groupTotalCount + successMaxCount = groupSuccessCount } + // 根据存储类型选择并执行限流处理器 if common.RedisEnabled { - redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } else { - memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) + memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 79556737..e9c129e1 100644 --- a/model/option.go +++ b/model/option.go @@ -1,8 +1,6 @@ package model import ( - "encoding/json" - "fmt" "one-api/common" "one-api/setting" "one-api/setting/config" @@ -94,8 +92,7 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - jsonBytes, _ := json.Marshal(map[string][2]int{}) - common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes) + common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() @@ -154,31 +151,18 @@ func SyncOptions(frequency int) { } func UpdateOption(key string, value string) error { - originalValue := value - - if key == "ModelRequestRateLimitGroup" { - var cfg map[string][2]int - err := json.Unmarshal([]byte(originalValue), &cfg) - if err != nil { - return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err) - } - - formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ") - if marshalErr != nil { - return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr) - } - value = string(formattedValueBytes) - } - + // Save to database first option := Option{ Key: key, } + // https://gorm.io/docs/update.html#Save-All-Fields DB.FirstOrCreate(&option, Option{Key: key}) option.Value = value - if err := DB.Save(&option).Error; err != nil { - return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) - } - + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap return updateOptionMap(key, value) } @@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) { case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) case "ModelRequestRateLimitGroup": - err = setting.UpdateModelRequestRateLimitGroup(value) + err = setting.UpdateModelRequestRateLimitGroupByJSONString(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 5be75cc1..aab030cd 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,7 +2,6 @@ package setting import ( "encoding/json" - "fmt" "one-api/common" "sync" ) @@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 -var ModelRequestRateLimitGroup map[string][2]int +var ModelRequestRateLimitGroup = map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex -var ModelRequestRateLimitGroupMutex sync.RWMutex +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() -func UpdateModelRequestRateLimitGroup(jsonStr string) error { - ModelRequestRateLimitGroupMutex.Lock() - defer ModelRequestRateLimitGroupMutex.Unlock() - - var newConfig map[string][2]int - if jsonStr == "" || jsonStr == "{}" { - ModelRequestRateLimitGroup = make(map[string][2]int) - common.SysLog("Model request rate limit group config cleared") - return nil - } - - err := json.Unmarshal([]byte(jsonStr), &newConfig) + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) + common.SysError("error marshalling model ratio: " + err.Error()) } + return string(jsonBytes) +} - ModelRequestRateLimitGroup = newConfig - return nil +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) } func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { - ModelRequestRateLimitGroupMutex.RLock() - defer ModelRequestRateLimitGroupMutex.RUnlock() + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() if ModelRequestRateLimitGroup == nil { return 0, 0, false diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 7e206672..309b94de 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -24,7 +24,6 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - // 检查 key 是否在初始 inputs 中定义 if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { if (item.key.endsWith('Enabled')) { newInputs[item.key] = item.value === 'true'; @@ -33,6 +32,7 @@ const RateLimitSetting = () => { } } }); + setInputs(newInputs); } else { showError(message); @@ -42,6 +42,7 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { @@ -56,6 +57,7 @@ const RateLimitSetting = () => { return ( <> + {/* AI请求速率限制 */} @@ -64,4 +66,4 @@ const RateLimitSetting = () => { ); }; - export default RateLimitSetting; + export default RateLimitSetting; \ No newline at end of file From 88ed83f41927eacc43526b5739592016d2ae4c10 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 20:00:06 +0800 Subject: [PATCH 05/16] feat: Modellimitgroup check --- controller/option.go | 9 +++++++++ setting/rate_limit.go | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/controller/option.go b/controller/option.go index 81ef463c..250f16bb 100644 --- a/controller/option.go +++ b/controller/option.go @@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "ModelRequestRateLimitGroup": + err = setting.CheckModelRequestRateLimitGroup(option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } } err = model.UpdateOption(option.Key, option.Value) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index aab030cd..14680791 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -2,6 +2,7 @@ package setting import ( "encoding/json" + "fmt" "one-api/common" "sync" ) @@ -46,3 +47,18 @@ func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) } return limits[0], limits[1], true } + +func CheckModelRequestRateLimitGroup(jsonStr string) error { + checkModelRequestRateLimitGroup := make(map[string][2]int) + err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup) + if err != nil { + return err + } + for group, limits := range checkModelRequestRateLimitGroup { + if limits[0] < 0 || limits[1] < 0 { + return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) + } + } + + return nil +} From 1cb4d750e471649da8fa5824942c43bffdc4705e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 22:06:16 +0800 Subject: [PATCH 06/16] =?UTF-8?q?feat:=20=E5=88=86=E7=BB=84=E9=80=9F?= =?UTF-8?q?=E7=8E=87=E5=89=8D=E7=AB=AF=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++-- .../RateLimit/SettingsRequestRateLimit.js | 83 ++++++++----------- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 309b94de..4671317f 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -13,7 +13,7 @@ const RateLimitSetting = () => { ModelRequestRateLimitCount: 0, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); let [loading, setLoading] = useState(false); @@ -24,12 +24,14 @@ const RateLimitSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true'; - } else { - newInputs[item.key] = item.value; - } + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; } }); diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 2434020e..b77c1e6a 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -6,6 +6,7 @@ import { showError, showSuccess, showWarning, + verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -18,7 +19,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '{}', + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -33,43 +34,32 @@ export default function RequestRateLimit(props) { } else { value = inputs[item.key]; } - if (item.key === 'ModelRequestRateLimitGroup') { - try { - JSON.parse(value); - } catch (e) { - showError(t('用户组速率限制配置不是有效的 JSON 格式!')); - return Promise.reject('Invalid JSON format'); - } - } return API.put('/api/option/', { key: item.key, value, }); }); - - const validRequests = requestQueue.filter(req => req !== null && req !== undefined && typeof req.then === 'function'); - - if (validRequests.length === 0 && requestQueue.length > 0) { - return; - } - setLoading(true); - Promise.all(validRequests) + Promise.all(requestQueue) .then((res) => { - if (validRequests.length === 1) { + if (requestQueue.length === 1) { if (res.includes(undefined)) return; - } else if (validRequests.length > 1) { + } else if (requestQueue.length > 1) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } + + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } + showSuccess(t('保存成功')); props.refresh(); - setInputsRow(structuredClone(inputs)); }) - .catch((error) => { - if (error !== 'Invalid JSON format') { - showError(t('保存失败,请重试')); - } + .catch(() => { + showError(t('保存失败,请重试')); }) .finally(() => { setLoading(false); @@ -79,15 +69,13 @@ export default function RequestRateLimit(props) { useEffect(() => { const currentInputs = {}; for (let key in props.options) { - if (Object.prototype.hasOwnProperty.call(inputs, key)) { // 使用 hasOwnProperty 检查 + if (Object.keys(inputs).includes(key)) { currentInputs[key] = props.options[key]; } } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - if (refForm.current) { refForm.current.setValues(currentInputs); - } }, [props.options]); return ( @@ -168,40 +156,41 @@ export default function RequestRateLimit(props) { />
- - + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} extraText={
-

{t('说明:')}

+

{t('说明:')}

    -
  • {t('使用 JSON 对象格式,键为用户组名 (字符串),值为包含两个整数的数组 [总次数限制, 成功次数限制]。')}
  • -
  • {t('总次数限制: 周期内允许的总请求次数 (含失败),0 代表不限制。')}
  • -
  • {t('成功次数限制: 周期内允许的成功请求次数 (HTTP < 400),必须大于 0。')}
  • -
  • {t('此配置将优先于上方的全局限制设置。')}
  • -
  • {t('未在此处配置的用户组将使用全局限制。')}
  • +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
  • {t('输入无效的 JSON 将无法保存。')}
} - autosize={{ minRows: 5, maxRows: 15 }} - style={{ fontFamily: 'monospace' }} onChange={(value) => { - setInputs({ - ...inputs, - ModelRequestRateLimitGroup: value, - }); + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); }} />
- + @@ -211,4 +200,4 @@ export default function RequestRateLimit(props) { ); - } + } \ No newline at end of file From 0be3678c9ca8d687920ba52ff7d17d65afba23ca Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:41:43 +0800 Subject: [PATCH 07/16] =?UTF-8?q?fix:=20=E8=AF=B7=E6=B1=82=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E6=95=B0=E5=BF=85=E9=A1=BB=E5=A4=A7=E4=BA=8E=E7=AD=89?= =?UTF-8?q?=E4=BA=8E1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setting/rate_limit.go | 2 +- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 14680791..53b53f88 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -55,7 +55,7 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { return err } for group, limits := range checkModelRequestRateLimitGroup { - if limits[0] < 0 || limits[1] < 0 { + if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } } diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index b77c1e6a..ae54b1ef 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,6 +179,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
From bbab729619820b49706af49a48596e8cab105bde Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:48:15 +0800 Subject: [PATCH 08/16] fix: text --- web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index ae54b1ef..7003c279 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -179,7 +179,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • -
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
From 87188cd7d458464c7e83e3502eb0a11126e6f94e Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:53:05 +0800 Subject: [PATCH 09/16] =?UTF-8?q?fix:=20=E7=BC=A9=E8=BF=9B=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=BF=98=E5=8E=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/model-rate-limit.go | 2 +- model/option.go | 2 +- web/src/components/RateLimitSetting.js | 92 ++--- .../RateLimit/SettingsRequestRateLimit.js | 344 +++++++++--------- 4 files changed, 220 insertions(+), 220 deletions(-) diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 1ca5ace6..03ef0ff3 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -194,4 +194,4 @@ func ModelRequestRateLimit() func(c *gin.Context) { memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) } } -} \ No newline at end of file +} diff --git a/model/option.go b/model/option.go index e9c129e1..d98a9d38 100644 --- a/model/option.go +++ b/model/option.go @@ -402,4 +402,4 @@ func handleConfigUpdate(key, value string) bool { config.UpdateConfigFromMap(cfg, configMap) return true // 已处理 -} \ No newline at end of file +} diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index 4671317f..a0953db7 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -9,62 +9,62 @@ import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimi const RateLimitSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: 0, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: 0, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); - - let [loading, setLoading] = useState(false); - - const getOptions = async () => { - const res = await API.get('/api/option/'); - const { success, message, data } = res.data; - if (success) { - let newInputs = {}; - data.forEach((item) => { - if (item.key === 'ModelRequestRateLimitGroup') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; - } - }); - - setInputs(newInputs); - } else { - showError(message); - } + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } }; async function onRefresh() { - try { - setLoading(true); - await getOptions(); - // showSuccess('刷新成功'); - } catch (error) { - showError('刷新失败'); - } finally { - setLoading(false); - } + try { + setLoading(true); + await getOptions(); + // showSuccess('刷新成功'); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } } useEffect(() => { - onRefresh(); + onRefresh(); }, []); return ( - <> - - {/* AI请求速率限制 */} - - - - - + <> + + {/* AI请求速率限制 */} + + + + + ); }; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7003c279..7c60bc47 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -15,190 +15,190 @@ export default function RequestRateLimit(props) { const [loading, setLoading] = useState(false); const [inputs, setInputs] = useState({ - ModelRequestRateLimitEnabled: false, - ModelRequestRateLimitCount: -1, - ModelRequestRateLimitSuccessCount: 1000, - ModelRequestRateLimitDurationMinutes: 1, - ModelRequestRateLimitGroup: '', + ModelRequestRateLimitEnabled: false, + ModelRequestRateLimitCount: -1, + ModelRequestRateLimitSuccessCount: 1000, + ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = ''; - if (typeof inputs[item.key] === 'boolean') { - value = String(inputs[item.key]); - } else { - value = inputs[item.key]; - } - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - setLoading(true); - Promise.all(requestQueue) - .then((res) => { - if (requestQueue.length === 1) { - if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = ''; + if (typeof inputs[item.key] === 'boolean') { + value = String(inputs[item.key]); + } else { + value = inputs[item.key]; + } + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } - for (let i = 0; i < res.length; i++) { - if (!res[i].data.success) { - return showError(res[i].data.message); - } - } + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } - showSuccess(t('保存成功')); - props.refresh(); - }) - .catch(() => { - showError(t('保存失败,请重试')); - }) - .finally(() => { - setLoading(false); - }); + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); } useEffect(() => { - const currentInputs = {}; - for (let key in props.options) { - if (Object.keys(inputs).includes(key)) { - currentInputs[key] = props.options[key]; - } - } - setInputs(currentInputs); - setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + const currentInputs = {}; + for (let key in props.options) { + if (Object.keys(inputs).includes(key)) { + currentInputs[key] = props.options[key]; + } + } + setInputs(currentInputs); + setInputsRow(structuredClone(currentInputs)); + refForm.current.setValues(currentInputs); }, [props.options]); return ( - <> - -
(refForm.current = formAPI)} - style={{ marginBottom: 15 }} - > - - - - { - setInputs({ - ...inputs, - ModelRequestRateLimitEnabled: value, - }); - }} - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitDurationMinutes: String(value), - }) - } - /> - - - - - - setInputs({ - ...inputs, - ModelRequestRateLimitCount: String(value), - }) - } - /> - - - - setInputs({ - ...inputs, - ModelRequestRateLimitSuccessCount: String(value), - }) - } - /> - - - - - verifyJSON(value), - message: t('不是合法的 JSON 字符串'), - }, - ]} - extraText={ -
-

{t('说明:')}

-
    -
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • -
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • -
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • -
  • {t('分组速率配置优先级高于全局速率限制。')}
  • -
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • -
-
- } - onChange={(value) => { - setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); - }} - /> - -
- - - -
-
-
- + <> + +
(refForm.current = formAPI)} + style={{ marginBottom: 15 }} + > + + + + { + setInputs({ + ...inputs, + ModelRequestRateLimitEnabled: value, + }); + }} + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitDurationMinutes: String(value), + }) + } + /> + + + + + + setInputs({ + ...inputs, + ModelRequestRateLimitCount: String(value), + }) + } + /> + + + + setInputs({ + ...inputs, + ModelRequestRateLimitSuccessCount: String(value), + }) + } + /> + + + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} + extraText={ +
+

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
+
+ } + onChange={(value) => { + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); + }} + /> + +
+ + + +
+
+
+ ); } \ No newline at end of file From 3d243c3ee2bc2a92d21d31f0155378ac5c188c39 Mon Sep 17 00:00:00 2001 From: tbphp Date: Mon, 5 May 2025 23:56:15 +0800 Subject: [PATCH 10/16] =?UTF-8?q?fix:=20=E6=A0=B7=E5=BC=8F=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/RateLimitSetting.js | 16 ++++++++-------- .../RateLimit/SettingsRequestRateLimit.js | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/web/src/components/RateLimitSetting.js b/web/src/components/RateLimitSetting.js index a0953db7..5f0200e1 100644 --- a/web/src/components/RateLimitSetting.js +++ b/web/src/components/RateLimitSetting.js @@ -34,7 +34,7 @@ const RateLimitSetting = () => { newInputs[item.key] = item.value; } }); - + setInputs(newInputs); } else { showError(message); @@ -44,28 +44,28 @@ const RateLimitSetting = () => { try { setLoading(true); await getOptions(); - // showSuccess('刷新成功'); + // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); } finally { setLoading(false); } } - + useEffect(() => { onRefresh(); }, []); - + return ( <> - {/* AI请求速率限制 */} + {/* AI请求速率限制 */} ); - }; - - export default RateLimitSetting; \ No newline at end of file +}; + +export default RateLimitSetting; diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 7c60bc47..73626351 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -23,7 +23,7 @@ export default function RequestRateLimit(props) { }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - + function onSubmit() { const updateArray = compareObjects(inputs, inputsRow); if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); @@ -65,7 +65,7 @@ export default function RequestRateLimit(props) { setLoading(false); }); } - + useEffect(() => { const currentInputs = {}; for (let key in props.options) { @@ -75,9 +75,9 @@ export default function RequestRateLimit(props) { } setInputs(currentInputs); setInputsRow(structuredClone(currentInputs)); - refForm.current.setValues(currentInputs); + refForm.current.setValues(currentInputs); }, [props.options]); - + return ( <> @@ -201,4 +201,4 @@ export default function RequestRateLimit(props) { ); - } \ No newline at end of file +} From 97b5ca809982751d9bb4c1f1dc3a5a45fafa21bd Mon Sep 17 00:00:00 2001 From: liusanp Date: Wed, 7 May 2025 16:17:22 +0800 Subject: [PATCH 11/16] fix: quality, size or style are not supported by xAI API --- relay/channel/xai/adaptor.go | 9 +++++++-- relay/channel/xai/dto.go | 13 +++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 12634c84..18c734ee 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -28,8 +28,13 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - request.Size = "" - return request, nil + xaiRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + N: request.N, + ResponseFormat: request.ResponseFormat, + } + return xaiRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go index 7036d5f1..b8098475 100644 --- a/relay/channel/xai/dto.go +++ b/relay/channel/xai/dto.go @@ -12,3 +12,16 @@ type ChatCompletionResponse struct { Usage *dto.Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` } + +// quality, size or style are not supported by xAI API at the moment. +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + // Size string `json:"size,omitempty"` + // Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + // Style string `json:"style,omitempty"` + // User string `json:"user,omitempty"` + // ExtraFields json.RawMessage `json:"extra_fields,omitempty"` +} \ No newline at end of file From 04f7d89399c71e731788ccdee18e7b39c4f3b0a4 Mon Sep 17 00:00:00 2001 From: liusanp Date: Wed, 7 May 2025 18:32:59 +0800 Subject: [PATCH 12/16] fix: xAi requestUrl --- relay/channel/xai/adaptor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 18c734ee..21eed20c 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { From 562448b44105e4e85d5d043ccaccddea2084aa41 Mon Sep 17 00:00:00 2001 From: liusanp Date: Wed, 7 May 2025 18:59:27 +0800 Subject: [PATCH 13/16] fix: xAi response --- relay/channel/xai/adaptor.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 21eed20c..b5896415 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -2,14 +2,16 @@ package xai import ( "errors" - "fmt" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "strings" + "one-api/relay/constant" + "github.com/gin-gonic/gin" ) @@ -94,15 +96,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = xAIStreamHandler(c, resp, info) - } else { - err, usage = xAIHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + default: + if info.IsStream { + err, usage = xAIStreamHandler(c, resp, info) + } else { + err, usage = xAIHandler(c, resp, info) + } } - //if _, ok := usage.(*dto.Usage); ok && usage != nil { - // usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens - //} - return } From ae254f5368ff3764ab0d23b4f6dd2c54411a359c Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Wed, 7 May 2025 19:33:32 +0800 Subject: [PATCH 14/16] fix: tool quota calculate --- relay/relay-text.go | 6 +++--- web/src/helpers/render.js | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/relay/relay-text.go b/relay/relay-text.go index 89a6a973..e0b6ad0e 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -364,11 +364,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, var webSearchPrice float64 if relayInfo.ResponsesUsageInfo != nil { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { - // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000) + // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率) webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize) dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)) + Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s", webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) } @@ -381,7 +381,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, fileSearchPrice = operation_setting.GetFileSearchPricePerThousand() dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice). Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). - Div(decimal.NewFromInt(1000)) + Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s", fileSearchTool.CallCount, dFileSearchQuota.String()) } diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index fb4c3dbd..5a59356b 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -354,8 +354,8 @@ export function renderModelPrice( let price = (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + (completionTokens / 1000000) * completionRatioPrice * groupRatio + - (webSearchCallCount / 1000) * webSearchPrice + - (fileSearchCallCount / 1000) * fileSearchPrice; + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; return ( <> @@ -446,7 +446,7 @@ export function renderModelPrice( ) : webSearch && webSearchCallCount > 0 && !image && !fileSearch ? i18next.t( - '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} = ${{total}}', + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}} = ${{total}}', { input: inputTokens, price: inputRatioPrice, @@ -458,9 +458,12 @@ export function renderModelPrice( total: price.toFixed(6), }, ) - : fileSearch && fileSearchCallCount > 0 && !image && !webSearch + : fileSearch && + fileSearchCallCount > 0 && + !image && + !webSearch ? i18next.t( - '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} = ${{total}}', + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}', { input: inputTokens, price: inputRatioPrice, @@ -472,9 +475,13 @@ export function renderModelPrice( total: price.toFixed(6), }, ) - : webSearch && webSearchCallCount > 0 && fileSearch && fileSearchCallCount > 0 && !image + : webSearch && + webSearchCallCount > 0 && + fileSearch && + fileSearchCallCount > 0 && + !image ? i18next.t( - '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} = ${{total}}', + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}}+ 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}', { input: inputTokens, price: inputRatioPrice, From d40429ad935ec63dc34e585475d3b2a96f75ad2c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 8 May 2025 21:34:31 +0800 Subject: [PATCH 15/16] fix: update OpenAI request handling to include 'o1-preview' model support #1029 --- relay/channel/openai/adaptor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index eb12a22a..c81b9366 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -173,7 +173,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn info.UpstreamModelName = request.Model // o系列模型developer适配(o1-mini除外) - if !strings.HasPrefix(request.Model, "o1-mini") { + if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") { //修改第一个Message的内容,将system改为developer if len(request.Messages) > 0 && request.Messages[0].Role == "system" { request.Messages[0].Role = "developer" From 90d85a6f0a04e1bd645f61ec06ec8af894282e3c Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 8 May 2025 22:39:55 +0800 Subject: [PATCH 16/16] feat: add AzureNoRemoveDotTime constant and update channel handling #1044 - Introduced a new constant `AzureNoRemoveDotTime` in `constant/azure.go` to manage model name formatting for channels created after May 10, 2025. - Updated `distributor.go` to set `channel_create_time` in the context. - Modified `adaptor.go` to conditionally remove dots from model names based on the channel creation time. - Enhanced `relay_info.go` to include `ChannelCreateTime` in the `RelayInfo` struct. - Updated English localization files to reflect changes in model name handling for new channels. --- constant/azure.go | 5 ++++ middleware/distributor.go | 1 + relay/channel/openai/adaptor.go | 8 ++++--- relay/common/relay_info.go | 18 ++++++++------- web/src/i18n/locales/en.json | 1 + web/src/pages/Channel/EditChannel.js | 34 ++++++++++++++-------------- 6 files changed, 39 insertions(+), 28 deletions(-) create mode 100644 constant/azure.go diff --git a/constant/azure.go b/constant/azure.go new file mode 100644 index 00000000..d84040ce --- /dev/null +++ b/constant/azure.go @@ -0,0 +1,5 @@ +package constant + +import "time" + +var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix() diff --git a/middleware/distributor.go b/middleware/distributor.go index 51fd8fd1..34882381 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -213,6 +213,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) + c.Set("channel_create_time", channel.CreatedTime) c.Set("channel_setting", channel.GetSetting()) c.Set("param_override", channel.GetParamOverride()) if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index c81b9366..da92692b 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "io" "mime/multipart" "net/http" + "net/textproto" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -25,8 +26,6 @@ import ( "path/filepath" "strings" - "net/textproto" - "github.com/gin-gonic/gin" ) @@ -93,7 +92,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") model_ := info.UpstreamModelName - model_ = strings.Replace(model_, ".", "", -1) + // 2025年5月10日后创建的渠道不移除. + if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime { + model_ = strings.Replace(model_, ".", "", -1) + } // https://github.com/songquanpeng/one-api/issues/67 requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) if info.RelayMode == constant.RelayModeRealtime { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 04e28980..f4fc3c1e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -98,6 +98,7 @@ type RelayInfo struct { UserQuota int RelayFormat string SendResponseCount int + ChannelCreateTime int64 ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo @@ -209,14 +210,15 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { OriginModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"), //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e9975f61..eedf1196 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -493,6 +493,7 @@ "默认": "default", "图片演示": "Image demo", "注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41", + "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "After May 10, 2025, channels added do not need to remove the dot in the model name during deployment", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", "取消无限额度": "Cancel unlimited quota", "取消": "Cancel", diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index a793e149..cba787fc 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -477,24 +477,24 @@ const EditChannel = (props) => { type={'warning'} description={ <> - {t('注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41')} -
- { - setModalImageUrl( - '/azure_model_name.png', - ); - setIsModalOpenurl(true) + {t('2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."')} + {/*
*/} + {/* {*/} + {/* setModalImageUrl(*/} + {/* '/azure_model_name.png',*/} + {/* );*/} + {/* setIsModalOpenurl(true)*/} - }} - > - {t('查看示例')} -
+ {/* }}*/} + {/*>*/} + {/* {t('查看示例')}*/} + {/**/} } >