From 0cd93d67ff9e398c3f44f251b2d701bef5199cc3 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 30 Jul 2025 18:39:19 +0800 Subject: [PATCH] fix: auto ban --- relay/channel/ali/image.go | 4 +- relay/channel/ali/rerank.go | 4 +- relay/channel/ali/text.go | 6 +- relay/channel/gemini/adaptor.go | 56 ---------------- relay/channel/gemini/relay-gemini.go | 66 +++++++++++++++++-- relay/channel/jimeng/image.go | 4 +- relay/channel/openai/relay-openai.go | 12 ++-- relay/channel/openai/relay_responses.go | 4 +- relay/channel/palm/relay-palm.go | 4 +- .../channel/siliconflow/relay-siliconflow.go | 4 +- relay/channel/tencent/relay-tencent.go | 4 +- relay/channel/vertex/adaptor.go | 12 +++- relay/channel/zhipu/relay-zhipu.go | 4 +- relay/common_handler/rerank.go | 6 +- relay/gemini_handler.go | 3 +- 15 files changed, 99 insertions(+), 94 deletions(-) diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 0d430c62..754f29c8 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela var aliTaskResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliTaskResponse.Message != "" { diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index 59cb0a11..4f448e01 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest { func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 6d90fa71..fcf63854 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro var fullTextResponse dto.FlexibleEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) @@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil + return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil } if aliResponse.Code != "" { return types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 2e31ec55..da291aa9 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -1,12 +1,10 @@ package gemini import ( - "encoding/json" "errors" "fmt" "io" "net/http" - "one-api/common" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" @@ -212,60 +210,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) } -func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - responseBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) - } - _ = resp.Body.Close() - - var geminiResponse GeminiImageResponse - if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - if len(geminiResponse.Predictions) == 0 { - return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody) - } - - // convert to openai format response - openAIResponse := dto.ImageResponse{ - Created: common.GetTimestamp(), - Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), - } - - for _, prediction := range geminiResponse.Predictions { - if prediction.RaiFilteredReason != "" { - continue // skip filtered image - } - openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ - B64Json: prediction.BytesBase64Encoded, - }) - } - - jsonResponse, jsonErr := json.Marshal(openAIResponse) - if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) - } - - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - - // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb - // each image has fixed 258 tokens - const imageTokens = 258 - generatedImages := len(openAIResponse.Data) - - usage := &dto.Usage{ - PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens - CompletionTokens: 0, // image generation does not calculate completion tokens - TotalTokens: imageTokens * generatedImages, - } - - return usage, nil -} - func (a *Adaptor) GetModelList() []string { return ModelList } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 7e57bdac..5dac0ce5 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var geminiResponse GeminiChatResponse err = common.Unmarshal(responseBody, &geminiResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if len(geminiResponse.Candidates) == 0 { - return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName @@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } var geminiResponse GeminiEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // convert to openai format response @@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h jsonResponse, jsonErr := common.Marshal(openAIResponse) if jsonErr != nil { - return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } + +func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + _ = resp.Body.Close() + + var geminiResponse GeminiImageResponse + if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { + return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if len(geminiResponse.Predictions) == 0 { + return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + // convert to openai format response + openAIResponse := dto.ImageResponse{ + Created: common.GetTimestamp(), + Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)), + } + + for _, prediction := range geminiResponse.Predictions { + if prediction.RaiFilteredReason != "" { + continue // skip filtered image + } + openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + jsonResponse, jsonErr := json.Marshal(openAIResponse) + if jsonErr != nil { + return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb + // each image has fixed 258 tokens + const imageTokens = 258 + generatedImages := len(openAIResponse.Data) + + usage := &dto.Usage{ + PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens + CompletionTokens: 0, // image generation does not calculate completion tokens + TotalTokens: imageTokens * generatedImages, + } + + return usage, nil +} diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go index 3c6a1d99..28af1866 100644 --- a/relay/channel/jimeng/image.go +++ b/relay/channel/jimeng/image.go @@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R var jimengResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &jimengResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // Check if the response indicates an error diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 82bd2d26..2252b407 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { if resp == nil || resp.Body == nil { common.LogError(c, "invalid response or response body") - return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) + return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) } defer common.CloseResponseBodyGracefully(resp) @@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) @@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } responseBody, err := io.ReadAll(resp.Body) if err != nil { - return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil + return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil } // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) @@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } var usageResp dto.SimpleResponse err = common.Unmarshal(responseBody, &usageResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } // 写入新的 response body diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index d9dd96b9..fd57924b 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } err = common.Unmarshal(responseBody, &responsesResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if responsesResponse.Error != nil { return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 4db31573..cbd60f5e 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index fabaf9c6..2e37ad15 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -15,13 +15,13 @@ import ( func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } usage := &dto.Usage{ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index c3d96c49..78ce6238 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp var tencentSb TencentChatResponseSB responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if tencentSb.Response.Error.Code != 0 { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index fa895de0..40c9ca89 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -67,11 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude") { a.RequestMode = RequestModeClaude - } else if strings.HasPrefix(info.UpstreamModelName, "gemini") { - a.RequestMode = RequestModeGemini } else if strings.Contains(info.UpstreamModelName, "llama") { a.RequestMode = RequestModeLlama } + a.RequestMode = RequestModeGemini } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -83,6 +82,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 if strings.Contains(info.UpstreamModelName, "-thinking-") { @@ -100,6 +100,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } else { suffix = "generateContent" } + + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + suffix = "predict" + } + if region == "global" { return fmt.Sprintf( "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", @@ -231,6 +236,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.RelayMode == constant.RelayModeGemini { usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) } else { + if strings.HasPrefix(info.UpstreamModelName, "imagen") { + return gemini.GeminiImageHandler(c, info, resp) + } usage, err = gemini.GeminiChatHandler(c, info, resp) } case RequestModeLlama: diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 916a200d..35882ed5 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if !zhipuResponse.Success { return nil, types.WithOpenAIError(types.OpenAIError{ diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index ce823b3a..57df5fe3 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -16,7 +16,7 @@ import ( func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { @@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var xinRerankResponse xinference.XinRerankResponse err = common.Unmarshal(responseBody, &xinRerankResponse) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) for i, result := range xinRerankResponse.Results { @@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { err = common.Unmarshal(responseBody, &jinaResp) if err != nil { - return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 3b27bfe2..6da8b131 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -203,7 +202,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) { } requestBody = bytes.NewReader(body) } else { - jsonData, err := json.Marshal(req) + jsonData, err := common.Marshal(req) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed) }