From 562175565582900a9fc03ddfc0502a885f3cf05f Mon Sep 17 00:00:00 2001 From: Glaxy <90437693+QingyeSC@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:00:33 +0800 Subject: [PATCH 01/24] Update relay-claude.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 优化claude messages接口启用思考时的参数设置 --- relay/channel/claude/relay-claude.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index a8607d86..9638853f 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -118,7 +118,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla } // TODO: 临时处理 // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking - claudeRequest.TopP = 0 + // Anthropic 要求去掉 top_k + claudeRequest.TopK = nil + //top_p值可以在0.95-1之间 + claudeRequest.TopP = 0.95 claudeRequest.Temperature = common.GetPointer[float64](1.0) claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") } From 3ac54b2178f7dbda26fe3cef487091ba7068ef5d Mon Sep 17 00:00:00 2001 From: RedwindA Date: Fri, 18 Jul 2025 23:38:35 +0800 Subject: [PATCH 02/24] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20DisablePing=20?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E4=BB=A5=E6=8E=A7=E5=88=B6=E6=98=AF=E5=90=A6?= =?UTF-8?q?=E5=8F=91=E9=80=81=E8=87=AA=E5=AE=9A=E4=B9=89=20Ping?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/api_request.go | 2 +- relay/common/relay_info.go | 1 + relay/helper/stream_scanner.go | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ff7c63fa..3ccd2d78 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http helper.SetEventStreamHeaders(c) // 处理流式请求的 ping 保活 generalSettings := operation_setting.GetGeneralSetting() - if generalSettings.PingIntervalEnabled { + if generalSettings.PingIntervalEnabled && !info.DisablePing { pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second stopPinger = startPingKeepAlive(c, pingInterval) // 使用defer确保在任何情况下都能停止ping goroutine diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 5b7dee80..26f668ab 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -88,6 +88,7 @@ type RelayInfo struct { BaseUrl string SupportStreamOptions bool ShouldIncludeUsage bool + DisablePing bool // 是否禁止向下游发送自定义 Ping IsModelMapped bool ClientWs *websocket.Conn TargetWs *websocket.Conn diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index b526b1c0..64919020 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon ) generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled + pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second if pingInterval <= 0 { pingInterval = DefaultPingInterval From 7af3fb5ae47cd0f03a3ff255f864f854227f8127 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Fri, 18 Jul 2025 23:39:01 +0800 Subject: [PATCH 03/24] =?UTF-8?q?=E7=A6=81=E7=94=A8=E5=8E=9F=E7=94=9FGemin?= =?UTF-8?q?i=E6=A8=A1=E5=BC=8F=E4=B8=AD=E7=9A=84ping=E4=BF=9D=E6=B4=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/gemini/adaptor.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 71eb9ba4..a97e9b76 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -171,6 +171,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { if info.IsStream { + info.DisablePing = true return GeminiTextGenerationStreamHandler(c, info, resp) } else { return GeminiTextGenerationHandler(c, info, resp) From 6103888610c5d53e62e171c37ae820b9b1b53255 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Sun, 20 Jul 2025 17:35:34 +0800 Subject: [PATCH 04/24] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3nothinking?= =?UTF-8?q?=E5=88=A4=E6=96=AD=E9=80=BB=E8=BE=91=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E4=BB=85=E5=BD=93=E9=A2=84=E7=AE=97=E4=B8=BA=E9=9B=B6=E6=97=B6?= =?UTF-8?q?=E8=BF=94=E5=9B=9Etrue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/gemini_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e448b491..730983e0 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -80,7 +80,7 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { - return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0 + return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0 } return false } From 19df2ac234d69d00ecdefa3e11c0a85d64361b4b Mon Sep 17 00:00:00 2001 From: Raymond <240029725@qq.com> Date: Sat, 26 Jul 2025 17:09:38 +0800 Subject: [PATCH 05/24] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E9=80=9F=E7=8E=87=E9=99=90=E5=88=B6=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9=E8=AF=B7=E6=B1=82=E6=AC=A1=E6=95=B0=E6=9C=80=E5=A4=A7?= =?UTF-8?q?=E5=80=BC=E7=9A=84=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setting/rate_limit.go | 3 +++ web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js | 3 +++ 2 files changed, 6 insertions(+) diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 53b53f88..535fd831 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -58,6 +58,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error { if limits[0] < 0 || limits[1] < 1 { return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) } + if limits[0] > 2147483647 || limits[1] > 2147483647 { + return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1]) + } } return nil diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 85473ec9..d5b1c637 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -128,6 +128,7 @@ export default function RequestRateLimit(props) { label={t('用户每周期最多请求次数')} step={1} min={0} + max={2147483647} suffix={t('次')} extraText={t('包括失败请求的次数,0代表不限制')} field={'ModelRequestRateLimitCount'} @@ -144,6 +145,7 @@ export default function RequestRateLimit(props) { label={t('用户每周期最多请求完成次数')} step={1} min={1} + max={2147483647} suffix={t('次')} extraText={t('只包括请求成功的次数')} field={'ModelRequestRateLimitSuccessCount'} @@ -180,6 +182,7 @@ export default function RequestRateLimit(props) {
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('[最多请求次数]和[最多请求完成次数]的最大值为2147483647。')}
  • {t('分组速率配置优先级高于全局速率限制。')}
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • From a8a42cbfa8eb51ef9ecf9b02b53369e65291c8dd Mon Sep 17 00:00:00 2001 From: t0ng7u Date: Sat, 26 Jul 2025 17:18:47 +0800 Subject: [PATCH 06/24] =?UTF-8?q?=F0=9F=92=84=20style(ui):=20show=20"Force?= =?UTF-8?q?=20Format"=20toggle=20only=20for=20OpenAI=20channels?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, the "Force Format" switch was displayed for every channel type although it only applies to OpenAI (type === 1). This change wraps the switch in a conditional so it renders exclusively when the selected channel type is OpenAI. Why: - Prevents user confusion when configuring non-OpenAI channels - Keeps the UI clean and context-relevant Scope: - web/src/components/table/channels/modals/EditChannelModal.jsx No backend logic affected. --- .../table/channels/modals/EditChannelModal.jsx | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index a4c8ea76..248307c4 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -1592,14 +1592,16 @@ const EditChannelModal = (props) => { - handleChannelSettingsChange('force_format', value)} - extraText={t('强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)')} - /> + {inputs.type === 1 && ( + handleChannelSettingsChange('force_format', value)} + extraText={t('强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)')} + /> + )} Date: Sun, 27 Jul 2025 00:01:12 +0800 Subject: [PATCH 07/24] =?UTF-8?q?=F0=9F=94=8D=20fix:=20select=20search=20f?= =?UTF-8?q?ilter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary • Introduced a unified `selectFilter` helper that matches both `option.value` and `option.label`, ensuring all ` -export const modelSelectFilter = (input, option) => { +// 使用方式: } + placeholder={t('搜索模型')} + value={keyword} + onChange={(v) => setKeyword(v)} + showClear + /> + + +
    + {filteredModels.length === 0 ? ( + } + darkModeImage={} + description={t('暂无匹配模型')} + style={{ padding: 30 }} + /> + ) : ( + setCheckedList(vals)}> + {activeTab === 'new' && newModels.length > 0 && ( +
    + {renderModelsByCategory(newModelsByCategory, 'new')} +
    + )} + {activeTab === 'existing' && existingModels.length > 0 && ( +
    + {renderModelsByCategory(existingModelsByCategory, 'existing')} +
    + )} +
    + )} +
    +
    + + +
    + {(() => { + const currentModels = activeTab === 'new' ? newModels : existingModels; + const currentSelected = currentModels.filter(model => checkedList.includes(model)).length; + const isAllSelected = currentModels.length > 0 && currentSelected === currentModels.length; + const isIndeterminate = currentSelected > 0 && currentSelected < currentModels.length; + + return ( + <> + + {t('已选择 {{selected}} / {{total}}', { + selected: currentSelected, + total: currentModels.length + })} + + { + handleCategorySelectAll(currentModels, e.target.checked); + }} + /> + + ); + })()} +
    +
    + + ); +}; + +export default ModelSelectModal; \ No newline at end of file diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index a1bf619d..29190b13 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1799,5 +1799,10 @@ "显示第": "Showing", "条 - 第": "to", "条,共": "of", - "条": "items" + "条": "items", + "选择模型": "Select model", + "已选择 {{selected}} / {{total}}": "Selected {{selected}} / {{total}}", + "新获取的模型": "New models", + "已有的模型": "Existing models", + "搜索模型": "Search models" } \ No newline at end of file From 010f27678d6409d3922c18728c1d2d4ca916de06 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 29 Jul 2025 15:20:08 +0800 Subject: [PATCH 13/24] fix: auto ban --- relay/channel/gemini/relay-gemini-native.go | 6 +++--- relay/gemini_handler.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 0870e3fa..7d459cc2 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -20,7 +20,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if common.DebugEnabled { @@ -31,7 +31,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re var geminiResponse GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 计算使用量(基于 UsageMetadata) @@ -54,7 +54,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re // 直接返回 Gemini 原生格式的 JSON 响应 jsonResponse, err := common.Marshal(geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, jsonResponse) diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 0f1aa5bf..3b27bfe2 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -230,7 +230,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return types.NewError(err, types.ErrorCodeDoRequestFailed) + return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") From 95d46d1dfc4c21b203eaa14a0c6efb0d5c1d6154 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 29 Jul 2025 23:08:16 +0800 Subject: [PATCH 14/24] fix: auto ban --- controller/channel-test.go | 10 +++++----- service/error.go | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index c1c3c21d..75fec463 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -209,7 +209,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError), } } var httpResp *http.Response @@ -220,7 +220,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeBadResponse), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError), } } } @@ -236,7 +236,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: errors.New("usage is nil"), - newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody), + newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), } } usage := usageA.(*dto.Usage) @@ -246,7 +246,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { return testResult{ context: c, localErr: err, - newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed), + newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), } } info.PromptTokens = usage.PromptTokens @@ -417,7 +417,7 @@ func testAllChannels(notify bool) error { if common.AutomaticDisableChannelEnabled && !shouldBanChannel { if milliseconds > disableThreshold { err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded) + newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) shouldBanChannel = true } } diff --git a/service/error.go b/service/error.go index 83979add..94d9c250 100644 --- a/service/error.go +++ b/service/error.go @@ -1,7 +1,6 @@ package service import ( - "encoding/json" "errors" "fmt" "io" @@ -112,7 +111,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) return } statusCodeMapping := make(map[string]string) - err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) + err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) if err != nil { return } From a8462c1b70c73769ee452989487693039f1da3a8 Mon Sep 17 00:00:00 2001 From: IcedTangerine Date: Wed, 30 Jul 2025 12:17:56 +0800 Subject: [PATCH 15/24] Revert "Update relay-claude.go" --- relay/channel/claude/relay-claude.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 7d6c973f..f20b573d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -185,10 +185,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla } // TODO: 临时处理 // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking - // Anthropic 要求去掉 top_k - claudeRequest.TopK = nil - //top_p值可以在0.95-1之间 - claudeRequest.TopP = 0.95 + claudeRequest.TopP = 0 claudeRequest.Temperature = common.GetPointer[float64](1.0) claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") } From 0cd93d67ff9e398c3f44f251b2d701bef5199cc3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 18:39:19 +0800 Subject: [PATCH 16/24] fix: auto ban --- relay/channel/ali/image.go | 4 +- relay/channel/ali/rerank.go | 4 +- relay/channel/ali/text.go | 6 +- relay/channel/gemini/adaptor.go | 56 ---------------- relay/channel/gemini/relay-gemini.go | 66 +++++++++++++++++-- relay/channel/jimeng/image.go | 4 +- relay/channel/openai/relay-openai.go | 12 ++-- relay/channel/openai/relay_responses.go | 4 +- relay/channel/palm/relay-palm.go | 4 +- .../channel/siliconflow/relay-siliconflow.go | 4 +- relay/channel/tencent/relay-tencent.go | 4 +- relay/channel/vertex/adaptor.go | 12 +++- relay/channel/zhipu/relay-zhipu.go | 4 +- relay/common_handler/rerank.go | 6 +- relay/gemini_handler.go | 3 +- 15 files changed, 99 insertions(+), 94 deletions(-) diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 0d430c62..754f29c8 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela var aliTaskResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 59cb0a11..4f448e01 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest { func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 6d90fa71..fcf63854 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro var fullTextResponse dto.FlexibleEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) @@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { return types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 2e31ec55..da291aa9 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,12 +1,10 @@ package gemini import ( - "encoding/json" "errors" "fmt" "io" "net/http" - "one-api/common" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" @@ -212,60 +210,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) } -func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - responseBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) - } - _ = resp.Body.Close() - - var geminiResponse GeminiImageResponse - if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - if len(geminiResponse.Predictions) == 0 { - return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody) - } - - // convert to openai format response - openAIResponse := dto.ImageResponse{ - Created: common.GetTimestamp(), - Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), - } - - for _, prediction := range geminiResponse.Predictions { - if prediction.RaiFilteredReason != "" { - continue // skip filtered image - } - openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ - B64Json: prediction.BytesBase64Encoded, - }) - } - - jsonResponse, jsonErr := json.Marshal(openAIResponse) - if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - - // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb - // each image has fixed 258 tokens - const imageTokens = 258 - generatedImages := len(openAIResponse.Data) - - usage := &dto.Usage{ - PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens - CompletionTokens: 0, // image generation does not calculate completion tokens - TotalTokens: imageTokens * generatedImages, - } - - return usage, nil -} - func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 7e57bdac..5dac0ce5 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var geminiResponse GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } var geminiResponse GeminiEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // convert to openai format response @@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h jsonResponse, jsonErr := common.Marshal(openAIResponse) if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } + +func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse GeminiImageResponse + if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) + if jsonErr != nil { + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage := &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + + return usage, nil +} diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 3c6a1d99..28af1866 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R var jimengResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // Check if the response indicates an error diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 82bd2d26..2252b407 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { common.LogError(c, "invalid response or response body") - return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } defer common.CloseResponseBodyGracefully(resp) @@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) @@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) @@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } var usageResp dto.SimpleResponse err = common.Unmarshal(responseBody, &usageResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 写入新的 response body diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index d9dd96b9..fd57924b 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &responsesResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if responsesResponse.Error != nil { return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 4db31573..cbd60f5e 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index fabaf9c6..2e37ad15 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -15,13 +15,13 @@ import ( func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } usage := &dto.Usage{ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index c3d96c49..78ce6238 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp var tencentSb TencentChatResponseSB responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if tencentSb.Response.Error.Code != 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index fa895de0..40c9ca89 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -67,11 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude") { a.RequestMode = RequestModeClaude - } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { - a.RequestMode = RequestModeGemini } else if strings.Contains(info.UpstreamModelName, "llama") { a.RequestMode = RequestModeLlama } + a.RequestMode = RequestModeGemini } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -83,6 +82,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -100,6 +100,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } else { suffix = "generateContent" } + + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + suffix = "predict" + } + if region == "global" { return fmt.Sprintf( "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", @@ -231,6 +236,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.RelayMode == constant.RelayModeGemini { usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) } else { + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return gemini.GeminiImageHandler(c, info, resp) + } usage, err = gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 916a200d..35882ed5 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if !zhipuResponse.Success { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index ce823b3a..57df5fe3 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -16,7 +16,7 @@ import ( func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var xinRerankResponse xinference.XinRerankResponse err = common.Unmarshal(responseBody, &xinRerankResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) for i, result := range xinRerankResponse.Results { @@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { err = common.Unmarshal(responseBody, &jinaResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 3b27bfe2..6da8b131 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -203,7 +202,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(body) } else { - jsonData, err := json.Marshal(req) + jsonData, err := common.Marshal(req) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed) } From f7b284ad73dc69265ef642652a48446836999d31 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 19:08:35 +0800 Subject: [PATCH 17/24] =?UTF-8?q?feat:=20=E9=94=99=E8=AF=AF=E5=86=85?= =?UTF-8?q?=E5=AE=B9=E8=84=B1=E6=95=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/str.go | 95 +++++++++++++++++++++++++++++++++++++++++++++ controller/relay.go | 3 +- types/error.go | 29 +++++++++++--- 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/common/str.go b/common/str.go index 88b58c72..f5399eab 100644 --- a/common/str.go +++ b/common/str.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "encoding/json" "math/rand" + "net/url" + "regexp" "strconv" + "strings" "unsafe" ) @@ -95,3 +98,95 @@ func GetJsonString(data any) string { b, _ := json.Marshal(data) return string(b) } + +// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string +// Example: +// http://example.com -> http://***.com +// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** +// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** +// 192.168.1.1 -> ***.***.***.*** +func MaskSensitiveInfo(str string) string { + // Mask URLs + urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) + str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + return urlStr + } + + host := u.Host + if host == "" { + return urlStr + } + + // Split host by dots + parts := strings.Split(host, ".") + if len(parts) < 2 { + // If less than 2 parts, just mask the whole host + return u.Scheme + "://***" + u.Path + } + + // Keep the TLD (Top Level Domain) and mask the rest + var maskedHost string + if len(parts) == 2 { + // example.com -> ***.com + maskedHost = "***." + parts[len(parts)-1] + } else { + // Handle cases like sub.domain.co.uk or api.example.com + // Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.) + lastPart := parts[len(parts)-1] + secondLastPart := parts[len(parts)-2] + + if len(lastPart) == 2 && len(secondLastPart) <= 3 { + // Likely country code TLD like co.uk, com.cn + maskedHost = "***." + secondLastPart + "." + lastPart + } else { + // Regular TLD like .com, .org + maskedHost = "***." + lastPart + } + } + + result := u.Scheme + "://" + maskedHost + + // Mask path + if u.Path != "" && u.Path != "/" { + pathParts := strings.Split(strings.Trim(u.Path, "/"), "/") + maskedPathParts := make([]string, len(pathParts)) + for i := range pathParts { + if pathParts[i] != "" { + maskedPathParts[i] = "***" + } + } + if len(maskedPathParts) > 0 { + result += "/" + strings.Join(maskedPathParts, "/") + } + } else if u.Path == "/" { + result += "/" + } + + // Mask query parameters + if u.RawQuery != "" { + values, err := url.ParseQuery(u.RawQuery) + if err != nil { + // If can't parse query, just mask the whole query string + result += "?***" + } else { + maskedParams := make([]string, 0, len(values)) + for key := range values { + maskedParams = append(maskedParams, key+"=***") + } + if len(maskedParams) > 0 { + result += "?" + strings.Join(maskedParams, "&") + } + } + } + + return result + }) + + // Mask IP addresses + ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) + str = ipPattern.ReplaceAllString(str, "***.***.***.***") + + return str +} diff --git a/controller/relay.go b/controller/relay.go index d4b5fd18..01081d3d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,8 +62,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { other["channel_id"] = channelId other["channel_name"] = c.GetString("channel_name") other["channel_type"] = c.GetInt("channel_type") - - model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other) + model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) } return err diff --git a/types/error.go b/types/error.go index c94bd001..2a8105c7 100644 --- a/types/error.go +++ b/types/error.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "one-api/common" "strings" ) @@ -107,19 +108,30 @@ func (e *NewAPIError) Error() string { return e.Err.Error() } +func (e *NewAPIError) MaskSensitiveError() string { + if e == nil { + return "" + } + if e.Err == nil { + return string(e.errorCode) + } + return common.MaskSensitiveInfo(e.Err.Error()) +} + func (e *NewAPIError) SetMessage(message string) { e.Err = errors.New(message) } func (e *NewAPIError) ToOpenAIError() OpenAIError { + var result OpenAIError switch e.errorType { case ErrorTypeOpenAIError: if openAIError, ok := e.RelayError.(OpenAIError); ok { - return openAIError + result = openAIError } case ErrorTypeClaudeError: if claudeError, ok := e.RelayError.(ClaudeError); ok { - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: claudeError.Type, Param: "", @@ -127,30 +139,35 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { } } } - return OpenAIError{ + result = OpenAIError{ Message: e.Error(), Type: string(e.errorType), Param: "", Code: e.errorCode, } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func (e *NewAPIError) ToClaudeError() ClaudeError { + var result ClaudeError switch e.errorType { case ErrorTypeOpenAIError: openAIError := e.RelayError.(OpenAIError) - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: fmt.Sprintf("%v", openAIError.Code), } case ErrorTypeClaudeError: - return e.RelayError.(ClaudeError) + result = e.RelayError.(ClaudeError) default: - return ClaudeError{ + result = ClaudeError{ Message: e.Error(), Type: string(e.errorType), } } + result.Message = common.MaskSensitiveInfo(result.Message) + return result } func NewError(err error, errorCode ErrorCode) *NewAPIError { From e3d3e697d324bd55760acd424a425468be705d08 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 20:31:51 +0800 Subject: [PATCH 18/24] fix: WriteContentType panic --- common/custom-event.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/common/custom-event.go b/common/custom-event.go index d8f9ec9f..256db546 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync" ) type stringWriter interface { @@ -52,6 +53,8 @@ type CustomEvent struct { Id string Retry uint Data interface{} + + Mutex sync.Mutex } func encode(writer io.Writer, event CustomEvent) error { @@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error { } func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + r.Mutex.Lock() + defer r.Mutex.Unlock() header := w.Header() header["Content-Type"] = contentType From 1f5ef24ecdaf1cb4bd437bab628a67e04322d66e Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Wed, 30 Jul 2025 22:35:31 +0800 Subject: [PATCH 19/24] =?UTF-8?q?feat:=20=E6=98=BE=E5=BC=8F=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=20error=20=E8=B7=B3=E8=BF=87=E9=87=8D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/playground.go | 10 +++---- controller/relay.go | 8 +++--- middleware/distributor.go | 2 +- model/channel.go | 2 +- relay/audio_handler.go | 10 +++---- relay/claude_handler.go | 18 ++++++------- relay/embedding_handler.go | 15 +++++------ relay/gemini_handler.go | 16 +++++------ relay/image_handler.go | 20 +++++++------- relay/relay-text.go | 32 +++++++++++----------- relay/rerank_handler.go | 20 +++++++------- relay/responses_handler.go | 20 +++++++------- relay/websocket.go | 6 ++--- service/channel.go | 2 +- types/error.go | 54 ++++++++++++++++++++++++++++---------- 15 files changed, 129 insertions(+), 106 deletions(-) diff --git a/controller/playground.go b/controller/playground.go index 0073cf06..64c0e1ce 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -28,19 +28,19 @@ func Playground(c *gin.Context) { useAccessToken := c.GetBool("use_access_token") if useAccessToken { - newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } playgroundRequest := &dto.PlayGroundRequest{} err := common.UnmarshalBodyReusable(c, playgroundRequest) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } if playgroundRequest.Model == "" { - newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest) + newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return } c.Set("original_model", playgroundRequest.Model) @@ -51,7 +51,7 @@ func Playground(c *gin.Context) { group = userGroup } else { if !setting.GroupInUserUsableGroups(group) && group != userGroup { - newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied) + newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) return } c.Set("group", group) @@ -62,7 +62,7 @@ func Playground(c *gin.Context) { // Write user context to ensure acceptUnsetRatio is available userCache, err := model.GetUserCache(userId) if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeQueryDataError) + newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return } userCache.WriteContext(c) diff --git a/controller/relay.go b/controller/relay.go index 01081d3d..e7318e9b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -127,7 +127,7 @@ func WssRelay(c *gin.Context) { defer ws.Close() if err != nil { - helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError()) + helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError()) return } @@ -258,10 +258,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } if channel == nil { - return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed) + return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) if newAPIError != nil { @@ -277,7 +277,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b if types.IsChannelError(openaiErr) { return true } - if types.IsLocalError(openaiErr) { + if types.IsSkipRetryError(openaiErr) { return false } if retryTimes <= 0 { diff --git a/middleware/distributor.go b/middleware/distributor.go index cba9b521..fb4a6645 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -247,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError { c.Set("original_model", modelName) // for retry if channel == nil { - return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed) + return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) diff --git a/model/channel.go b/model/channel.go index 6277fcda..ea263c84 100644 --- a/model/channel.go +++ b/model/channel.go @@ -138,7 +138,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { - return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed) + return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) } //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index f39dbd82..88777838 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if err != nil { common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } promptTokens := 0 @@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } resp, err := adaptor.DoRequest(c, relayInfo, ioReader) diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 2c60a91e..b4bf78ff 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateClaudeRequest(c) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.Stream { @@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index be11bb2b..fef8d2c9 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { err := common.UnmarshalBodyReusable(c, &embeddingRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getEmbeddingPromptToken(*embeddingRequest) @@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest) - if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } requestBody := bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 2e2c1480..43c7ca58 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -109,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateGeminiRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoGemini(c) @@ -121,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkGeminiInputSensitive(req) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } // model mapped 模型映射 err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota @@ -175,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -198,13 +198,13 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewReader(body) } else { jsonData, err := common.Marshal(req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -216,7 +216,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/image_handler.go b/relay/image_handler.go index c97eb48e..f0b69699 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -115,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } var preConsumedQuota int var quota int @@ -173,16 +173,16 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota-quota < 0 { - return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota) + return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry()) } } adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -191,20 +191,20 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits { requestBody = convertedRequest.(io.Reader) } else { jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -216,7 +216,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/relay-text.go b/relay/relay-text.go index 1856a2a1..97313be6 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -90,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateTextRequest(c, relayInfo) - if err != nil { - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if textRequest.WebSearchOptions != nil { @@ -103,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } // 获取 promptTokens,如果上下文中已经存在,则直接使用 @@ -121,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { promptTokens, err = getPromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed) + return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry()) } c.Set("prompt_tokens", promptTokens) } priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens)))) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -165,7 +164,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader @@ -173,7 +172,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } if common.DebugEnabled { println("requestBody: ", string(body)) @@ -182,7 +181,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } if relayInfo.ChannelSetting.SystemPrompt != "" { @@ -207,7 +206,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -219,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } @@ -231,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) - if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } @@ -304,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError) + return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } if userQuota-preConsumedQuota < 0 { - return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } relayInfo.UserQuota = userQuota if userQuota > 100*preConsumedQuota { @@ -334,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden) + return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError) + return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } } return preConsumedQuota, userQuota, nil diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 0190cf08..1e547e2a 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -31,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError err := common.UnmarshalBodyReusable(c, &rerankRequest) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest) if rerankRequest.Query == "" { - return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if len(rerankRequest.Documents) == 0 { - return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest) + return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } promptToken := getRerankPromptToken(*rerankRequest) @@ -53,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -68,7 +68,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) @@ -76,17 +76,17 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest) + return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -98,7 +98,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError } jsonData, err = common.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 52d1db6e..65c240b2 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) - return types.NewError(err, types.ErrorCodeInvalidRequest) + return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } relayInfo := relaycommon.GenRelayInfoResponses(c, req) @@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) - return types.NewError(err, types.ErrorCodeSensitiveWordsDetected) + return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry()) } } err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } if value, exists := c.Get("prompt_tokens"); exists { @@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre consume quota preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) @@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) { }() adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled { body, err := common.GetRequestBody(c) if err != nil { - return types.NewError(err, types.ErrorCodeReadRequestBodyFailed) + return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) } requestBody = bytes.NewBuffer(body) } else { convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override if len(relayInfo.ParamOverride) > 0 { reqMap := make(map[string]interface{}) err = json.Unmarshal(jsonData, &reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid) + return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry()) } for key, value := range relayInfo.ParamOverride { reqMap[key] = value } jsonData, err = json.Marshal(reqMap) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed) + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } } diff --git a/relay/websocket.go b/relay/websocket.go index 659e27d5..3715b237 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr err := helper.ModelMappedHelper(c, relayInfo, nil) if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError) + return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError) + return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) } // pre-consume quota 预消耗配额 @@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType) + return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(relayInfo) //var requestBody io.Reader diff --git a/service/channel.go b/service/channel.go index 4d38e6ed..faac6d10 100644 --- a/service/channel.go +++ b/service/channel.go @@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool { if types.IsChannelError(err) { return true } - if types.IsLocalError(err) { + if types.IsSkipRetryError(err) { return false } if err.StatusCode == http.StatusUnauthorized { diff --git a/types/error.go b/types/error.go index 2a8105c7..74c3bae5 100644 --- a/types/error.go +++ b/types/error.go @@ -78,6 +78,7 @@ const ( type NewAPIError struct { Err error RelayError any + skipRetry bool errorType ErrorType errorCode ErrorCode StatusCode int @@ -170,33 +171,39 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { return result } -func NewError(err error, errorCode ErrorCode) *NewAPIError { - return &NewAPIError{ +type NewAPIErrorOptions func(*NewAPIError) + +func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: nil, errorType: ErrorTypeNewAPIError, StatusCode: http.StatusInternalServerError, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + return e } -func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError { +func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Message: err.Error(), Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError { +func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Type: string(errorCode), } - return WithOpenAIError(openaiError, statusCode) + return WithOpenAIError(openaiError, statusCode, ops...) } -func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError { - return &NewAPIError{ +func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + e := &NewAPIError{ Err: err, RelayError: OpenAIError{ Message: err.Error(), @@ -206,9 +213,14 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New StatusCode: statusCode, errorCode: errorCode, } + for _, op := range ops { + op(e) + } + + return e } -func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { +func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { code, ok := openAIError.Code.(string) if !ok { code = fmt.Sprintf("%v", openAIError.Code) @@ -216,26 +228,34 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError { if openAIError.Type == "" { openAIError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: openAIError, errorType: ErrorTypeOpenAIError, StatusCode: statusCode, Err: errors.New(openAIError.Message), errorCode: ErrorCode(code), } + for _, op := range ops { + op(e) + } + return e } -func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError { +func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { if claudeError.Type == "" { claudeError.Type = "upstream_error" } - return &NewAPIError{ + e := &NewAPIError{ RelayError: claudeError, errorType: ErrorTypeClaudeError, StatusCode: statusCode, Err: errors.New(claudeError.Message), errorCode: ErrorCode(claudeError.Type), } + for _, op := range ops { + op(e) + } + return e } func IsChannelError(err *NewAPIError) bool { @@ -245,10 +265,16 @@ func IsChannelError(err *NewAPIError) bool { return strings.HasPrefix(string(err.errorCode), "channel:") } -func IsLocalError(err *NewAPIError) bool { +func IsSkipRetryError(err *NewAPIError) bool { if err == nil { return false } - return err.errorType == ErrorTypeNewAPIError + return err.skipRetry +} + +func ErrOptionWithSkipRetry() NewAPIErrorOptions { + return func(e *NewAPIError) { + e.skipRetry = true + } } From fc09051d8b9fbf55abc6997f4ed278a1ca19a8e0 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 23:26:09 +0800 Subject: [PATCH 20/24] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E5=BC=80=E5=90=AF=E4=B8=8B=E8=87=AA=E5=8A=A8=E7=A6=81?= =?UTF-8?q?=E7=94=A8=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/channel.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/model/channel.go b/model/channel.go index ea263c84..e3535d64 100644 --- a/model/channel.go +++ b/model/channel.go @@ -75,7 +75,7 @@ func (channel *Channel) getKeys() []string { // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { res := make([]string, len(arr)) for i, v := range arr { res[i] = string(v) @@ -197,7 +197,7 @@ func (channel *Channel) GetGroups() []string { func (channel *Channel) GetOtherInfo() map[string]interface{} { otherInfo := make(map[string]interface{}) if channel.OtherInfo != "" { - err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo) + err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { common.SysError("failed to unmarshal other info: " + err.Error()) } @@ -425,7 +425,7 @@ func (channel *Channel) Update() error { trimmed := strings.TrimSpace(keyStr) if strings.HasPrefix(trimmed, "[") { var arr []json.RawMessage - if err := json.Unmarshal([]byte(trimmed), &arr); err == nil { + if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { keys = make([]string, len(arr)) for i, v := range arr { keys[i] = string(v) @@ -553,6 +553,7 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { } func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { + println("UpdateChannelStatus called with channelId:", channelId, "usingKey:", usingKey, "status:", status, "reason:", reason) if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() @@ -571,10 +572,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channelCache.Status == status { return false } - // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 - if status != common.ChannelStatusEnabled { - return false - } CacheUpdateChannelStatus(channelId, status) } } @@ -778,7 +775,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str func (channel *Channel) ValidateSettings() error { channelParams := &dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), channelParams) + err := common.Unmarshal([]byte(*channel.Setting), channelParams) if err != nil { return err } @@ -789,7 +786,7 @@ func (channel *Channel) ValidateSettings() error { func (channel *Channel) GetSetting() dto.ChannelSettings { setting := dto.ChannelSettings{} if channel.Setting != nil && *channel.Setting != "" { - err := json.Unmarshal([]byte(*channel.Setting), &setting) + err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { common.SysError("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 @@ -800,7 +797,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { } func (channel *Channel) SetSetting(setting dto.ChannelSettings) { - settingBytes, err := json.Marshal(setting) + settingBytes, err := common.Marshal(setting) if err != nil { common.SysError("failed to marshal setting: " + err.Error()) return @@ -811,7 +808,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) { func (channel *Channel) GetParamOverride() map[string]interface{} { paramOverride := make(map[string]interface{}) if channel.ParamOverride != nil && *channel.ParamOverride != "" { - err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) + err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { common.SysError("failed to unmarshal param override: " + err.Error()) } From 54447bf227237ecbe9a2f8a5cfa7b95ede49cb45 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 23:29:45 +0800 Subject: [PATCH 21/24] fix: remove debug print statement --- model/channel.go | 1 - 1 file changed, 1 deletion(-) diff --git a/model/channel.go b/model/channel.go index e3535d64..58f0a064 100644 --- a/model/channel.go +++ b/model/channel.go @@ -553,7 +553,6 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { } func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { - println("UpdateChannelStatus called with channelId:", channelId, "usingKey:", usingKey, "status:", status, "reason:", reason) if common.MemoryCacheEnabled { channelStatusLock.Lock() defer channelStatusLock.Unlock() From f20b558e22b8a72b390319e590278d3e7419fb63 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 23:32:20 +0800 Subject: [PATCH 22/24] fix: correct request mode assignment logic in adaptor --- relay/channel/vertex/adaptor.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 40c9ca89..c88b4359 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -69,8 +69,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.RequestMode = RequestModeClaude } else if strings.Contains(info.UpstreamModelName, "llama") { a.RequestMode = RequestModeLlama + } else { + a.RequestMode = RequestModeGemini } - a.RequestMode = RequestModeGemini } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { From 196bafff0387e9c7adffb4964e9718152182bd02 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 31 Jul 2025 10:56:51 +0800 Subject: [PATCH 23/24] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=A2=AB?= =?UTF-8?q?=E7=A6=81=E7=94=A8=E7=9A=84=E6=B8=A0=E9=81=93=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 7 +++++-- model/channel_cache.go | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 75fec463..3a7c582b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -332,8 +332,11 @@ func TestChannel(c *gin.Context) { } channel, err := model.CacheGetChannel(channelId) if err != nil { - common.ApiError(c, err) - return + channel, err = model.GetChannelById(channelId, true) + if err != nil { + common.ApiError(c, err) + return + } } //defer func() { // if channel.ChannelInfo.IsMultiKey { diff --git a/model/channel_cache.go b/model/channel_cache.go index 45069ba0..1abc8b85 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -239,6 +239,20 @@ func CacheUpdateChannelStatus(id int, status int) { if channel, ok := channelsIDM[id]; ok { channel.Status = status } + if status != common.ChannelStatusEnabled { + // delete the channel from group2model2channels + for group, model2channels := range group2model2channels { + for model, channels := range model2channels { + for i, channelId := range channels { + if channelId == id { + // remove the channel from the slice + group2model2channels[group][model] = append(channels[:i], channels[i+1:]...) + break + } + } + } + } + } } func CacheUpdateChannel(channel *Channel) { From bd6b811183129709969227f1d4f10121f0df92a0 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 31 Jul 2025 12:54:07 +0800 Subject: [PATCH 24/24] feat: add JSONEditor component for enhanced JSON input handling --- web/src/components/common/JSONEditor.js | 609 ++++++++++++++++++ .../channels/modals/EditChannelModal.jsx | 61 +- 2 files changed, 637 insertions(+), 33 deletions(-) create mode 100644 web/src/components/common/JSONEditor.js diff --git a/web/src/components/common/JSONEditor.js b/web/src/components/common/JSONEditor.js new file mode 100644 index 00000000..d0c159b2 --- /dev/null +++ b/web/src/components/common/JSONEditor.js @@ -0,0 +1,609 @@ +import React, { useState, useEffect, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Space, + Button, + Form, + Card, + Typography, + Banner, + Row, + Col, + InputNumber, + Switch, + Select, + Input, +} from '@douyinfe/semi-ui'; +import { + IconCode, + IconEdit, + IconPlus, + IconDelete, + IconSetting, +} from '@douyinfe/semi-icons'; + +const { Text } = Typography; + +const JSONEditor = ({ + value = '', + onChange, + field, + label, + placeholder, + extraText, + showClear = true, + template, + templateLabel, + editorType = 'keyValue', // keyValue, object, region + autosize = true, + rules = [], + formApi = null, + ...props +}) => { + const { t } = useTranslation(); + + // 初始化JSON数据 + const [jsonData, setJsonData] = useState(() => { + // 初始化时解析JSON数据 + if (value && value.trim()) { + try { + const parsed = JSON.parse(value); + return parsed; + } catch (error) { + return {}; + } + } + return {}; + }); + + // 根据键数量决定默认编辑模式 + const [editMode, setEditMode] = useState(() => { + // 如果初始JSON数据的键数量大于10个,则默认使用手动模式 + if (value && value.trim()) { + try { + const parsed = JSON.parse(value); + const keyCount = Object.keys(parsed).length; + return keyCount > 10 ? 'manual' : 'visual'; + } catch (error) { + return 'visual'; + } + } + return 'visual'; + }); + const [jsonError, setJsonError] = useState(''); + + // 数据同步 - 当value变化时总是更新jsonData(如果JSON有效) + useEffect(() => { + try { + const parsed = value && value.trim() ? JSON.parse(value) : {}; + setJsonData(parsed); + setJsonError(''); + } catch (error) { + console.log('JSON解析失败:', error.message); + setJsonError(error.message); + // JSON格式错误时不更新jsonData + } + }, [value]); + + + // 处理可视化编辑的数据变化 + const handleVisualChange = useCallback((newData) => { + setJsonData(newData); + setJsonError(''); + const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2); + + // 通过formApi设置值(如果提供的话) + if (formApi && field) { + formApi.setValue(field, jsonString); + } + + onChange?.(jsonString); + }, [onChange, formApi, field]); + + // 处理手动编辑的数据变化 + const handleManualChange = useCallback((newValue) => { + onChange?.(newValue); + // 验证JSON格式 + if (newValue && newValue.trim()) { + try { + const parsed = JSON.parse(newValue); + setJsonError(''); + // 预先准备可视化数据,但不立即应用 + // 这样切换到可视化模式时数据已经准备好了 + } catch (error) { + setJsonError(error.message); + } + } else { + setJsonError(''); + } + }, [onChange]); + + // 切换编辑模式 + const toggleEditMode = useCallback(() => { + if (editMode === 'visual') { + // 从可视化模式切换到手动模式 + setEditMode('manual'); + } else { + // 从手动模式切换到可视化模式,需要验证JSON + try { + const parsed = value && value.trim() ? JSON.parse(value) : {}; + setJsonData(parsed); + setJsonError(''); + setEditMode('visual'); + } catch (error) { + setJsonError(error.message); + // JSON格式错误时不切换模式 + return; + } + } + }, [editMode, value]); + + // 添加键值对 + const addKeyValue = useCallback(() => { + const newData = { ...jsonData }; + const keys = Object.keys(newData); + let newKey = 'key'; + let counter = 1; + while (newData.hasOwnProperty(newKey)) { + newKey = `key${counter}`; + counter++; + } + newData[newKey] = ''; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 删除键值对 + const removeKeyValue = useCallback((keyToRemove) => { + const newData = { ...jsonData }; + delete newData[keyToRemove]; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 更新键名 + const updateKey = useCallback((oldKey, newKey) => { + if (oldKey === newKey) return; + const newData = { ...jsonData }; + const value = newData[oldKey]; + delete newData[oldKey]; + newData[newKey] = value; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 更新值 + const updateValue = useCallback((key, newValue) => { + const newData = { ...jsonData }; + newData[key] = newValue; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 填入模板 + const fillTemplate = useCallback(() => { + if (template) { + const templateString = JSON.stringify(template, null, 2); + + // 通过formApi设置值(如果提供的话) + if (formApi && field) { + formApi.setValue(field, templateString); + } + + // 无论哪种模式都要更新值 + onChange?.(templateString); + + // 如果是可视化模式,同时更新jsonData + if (editMode === 'visual') { + setJsonData(template); + } + + // 清除错误状态 + setJsonError(''); + } + }, [template, onChange, editMode, formApi, field]); + + // 渲染键值对编辑器 + const renderKeyValueEditor = () => { + const entries = Object.entries(jsonData); + + return ( +
    + {entries.length === 0 && ( +
    +
    + +
    + + {t('暂无数据,点击下方按钮添加键值对')} + +
    + )} + + {entries.map(([key, value], index) => ( + + + +
    + {t('键名')} + updateKey(key, newKey)} + size="small" + /> +
    + + +
    + {t('值')} + updateValue(key, newValue)} + size="small" + /> +
    + + +
    +
    + +
    +
    + ))} + +
    + +
    +
    + ); + }; + + // 渲染对象编辑器(用于复杂JSON) + const renderObjectEditor = () => { + const entries = Object.entries(jsonData); + + return ( +
    + {entries.length === 0 && ( +
    +
    + +
    + + {t('暂无参数,点击下方按钮添加请求参数')} + +
    + )} + + {entries.map(([key, value], index) => ( + + + +
    + {t('参数名')} + updateKey(key, newKey)} + size="small" + /> +
    + + +
    + {t('参数值')} ({typeof value}) + {renderValueInput(key, value)} +
    + + +
    +
    + +
    +
    + ))} + +
    + +
    +
    + ); + }; + + // 渲染参数值输入控件 + const renderValueInput = (key, value) => { + const valueType = typeof value; + + if (valueType === 'boolean') { + return ( +
    + updateValue(key, newValue)} + size="small" + /> + + {value ? t('true') : t('false')} + +
    + ); + } + + if (valueType === 'number') { + return ( + updateValue(key, newValue)} + size="small" + style={{ width: '100%' }} + step={key === 'temperature' ? 0.1 : 1} + precision={key === 'temperature' ? 2 : 0} + placeholder={t('输入数字')} + /> + ); + } + + // 字符串类型或其他类型 + return ( + { + // 尝试转换为适当的类型 + let convertedValue = newValue; + if (newValue === 'true') convertedValue = true; + else if (newValue === 'false') convertedValue = false; + else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') { + convertedValue = Number(newValue); + } + + updateValue(key, convertedValue); + }} + size="small" + /> + ); + }; + + // 渲染区域编辑器(特殊格式) + const renderRegionEditor = () => { + const entries = Object.entries(jsonData); + const defaultEntry = entries.find(([key]) => key === 'default'); + const modelEntries = entries.filter(([key]) => key !== 'default'); + + return ( +
    + {/* 默认区域 */} + +
    + {t('默认区域')} +
    + updateValue('default', value)} + size="small" + /> +
    + + {/* 模型专用区域 */} +
    + {t('模型专用区域')} + {modelEntries.map(([modelName, region], index) => ( + + + +
    + {t('模型名称')} + updateKey(modelName, newKey)} + size="small" + /> +
    + + +
    + {t('区域')} + updateValue(modelName, newValue)} + size="small" + /> +
    + + +
    +
    + +
    +
    + ))} + +
    + +
    +
    +
    + ); + }; + + // 渲染可视化编辑器 + const renderVisualEditor = () => { + switch (editorType) { + case 'region': + return renderRegionEditor(); + case 'object': + return renderObjectEditor(); + case 'keyValue': + default: + return renderKeyValueEditor(); + } + }; + + const hasJsonError = jsonError && jsonError.trim() !== ''; + + return ( +
    + {/* Label统一显示在上方 */} + {label && ( +
    + {label} +
    + )} + + {/* 编辑模式切换 */} +
    +
    + {editMode === 'visual' && ( + + {t('可视化模式')} + + )} + {editMode === 'manual' && ( + + {t('手动编辑模式')} + + )} +
    +
    + {template && templateLabel && ( + + )} + + + + +
    +
    + + {/* JSON错误提示 */} + {hasJsonError && ( + + )} + + {/* 编辑器内容 */} + {editMode === 'visual' ? ( +
    + + {renderVisualEditor()} + + {/* 可视化模式下的额外文本显示在下方 */} + {extraText && ( +
    + {extraText} +
    + )} + {/* 隐藏的Form字段用于验证和数据绑定 */} + +
    + ) : ( + + )} + + {/* 额外文本在手动编辑模式下显示 */} + {extraText && editMode === 'manual' && ( +
    + {extraText} +
    + )} +
    + ); +}; + +export default JSONEditor; \ No newline at end of file diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index a3f09166..37e9af75 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -48,6 +48,7 @@ import { } from '@douyinfe/semi-ui'; import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers'; import ModelSelectModal from './ModelSelectModal'; +import JSONEditor from '../../../common/JSONEditor'; import { IconSave, IconClose, @@ -69,7 +70,9 @@ const STATUS_CODE_MAPPING_EXAMPLE = { }; const REGION_EXAMPLE = { - default: 'us-central1', + "default": 'global', + "gemini-1.5-pro-002": "europe-west2", + "gemini-1.5-flash-002": "europe-west2", 'claude-3-5-sonnet-20240620': 'europe-west1', }; @@ -1174,24 +1177,24 @@ const EditChannelModal = (props) => { )} {inputs.type === 41 && ( - handleInputChange('other', value)} rules={[{ required: true, message: t('请填写部署地区') }]} + template={REGION_EXAMPLE} + templateLabel={t('填入模板')} + editorType="region" + formApi={formApiRef.current} extraText={ - handleInputChange('other', JSON.stringify(REGION_EXAMPLE, null, 2))} - > - {t('填入模板')} + + {t('设置默认地区和特定模型的专用地区')} } - showClear /> )} @@ -1447,24 +1450,24 @@ const EditChannelModal = (props) => { showClear /> - handleInputChange('model_mapping', value)} + template={MODEL_MAPPING_EXAMPLE} + templateLabel={t('填入模板')} + editorType="keyValue" + formApi={formApiRef.current} extraText={ - handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))} - > - {t('填入模板')} + + {t('键为请求中的模型名称,值为要替换的模型名称')} } - showClear /> @@ -1554,7 +1557,7 @@ const EditChannelModal = (props) => { showClear /> - { '\n' + JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2) } - autosize + value={inputs.status_code_mapping || ''} onChange={(value) => handleInputChange('status_code_mapping', value)} + template={STATUS_CODE_MAPPING_EXAMPLE} + templateLabel={t('填入模板')} + editorType="keyValue" + formApi={formApiRef.current} extraText={ - handleInputChange('status_code_mapping', JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2))} - > - {t('填入模板')} + + {t('键为原状态码,值为要复写的状态码,仅影响本地判断')} } - showClear /> @@ -1585,14 +1588,6 @@ const EditChannelModal = (props) => {
    {t('渠道额外设置')} -
    - window.open('https://github.com/QuantumNous/new-api/blob/main/docs/channel/other_setting.md')} - > - {t('设置说明')} - -