From 1f4cf07b63deac8aebfcda39c083e28bf7e1a831 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 27 Jun 2025 23:35:56 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20refactor=20response=20bod?= =?UTF-8?q?y=20handling=20in=20multiple=20relay=20handlers=20to=20utilize?= =?UTF-8?q?=20IOCopyBytesGracefully?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/http.go | 35 +++++++++++++------------ relay/channel/ollama/relay-ollama.go | 24 +---------------- relay/channel/openai/relay-openai.go | 12 ++++----- relay/channel/openai/relay_responses.go | 23 ++++------------ relay/channel/xai/text.go | 21 +++++---------- relay/relay-mj.go | 5 +--- 6 files changed, 37 insertions(+), 83 deletions(-) diff --git a/common/http.go b/common/http.go index 315b19af..d2e824ef 100644 --- a/common/http.go +++ b/common/http.go @@ -3,9 +3,10 @@ package common import ( "bytes" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" + + "github.com/gin-gonic/gin" ) func CloseResponseBodyGracefully(httpResponse *http.Response) { @@ -19,37 +20,37 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) { } func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { - if src == nil || src.Body == nil { - return - } - - defer CloseResponseBodyGracefully(src) - if c.Writer == nil { return } - src.Body = io.NopCloser(bytes.NewBuffer(data)) + body := io.NopCloser(bytes.NewBuffer(data)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. // For example, Postman will report error, and we cannot check the response at all. - for k, v := range src.Header { - // avoid setting Content-Length - if k == "Content-Length" { - continue + if src != nil { + for k, v := range src.Header { + // avoid setting Content-Length + if k == "Content-Length" { + continue + } + c.Writer.Header().Set(k, v[0]) } - c.Writer.Header().Set(k, v[0]) } - // set Content-Length header manually + // set Content-Length header manually BEFORE calling WriteHeader c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) - c.Writer.WriteHeader(src.StatusCode) - c.Writer.WriteHeaderNow() + // Write header with status code (this sends the headers) + if src != nil { + c.Writer.WriteHeader(src.StatusCode) + } else { + c.Writer.WriteHeader(http.StatusOK) + } - _, err := io.Copy(c.Writer, src.Body) + _, err := io.Copy(c.Writer, body) if err != nil { LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index aa1ec441..bf7501e5 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,7 +1,6 @@ package ollama import ( - "bytes" "encoding/json" "fmt" "github.com/gin-gonic/gin" @@ -118,28 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - // Copy headers - for k, v := range resp.Header { - // 删除任何现有的相同头部,以防止重复添加头部 - c.Writer.Header().Del(k) - for _, vv := range v { - c.Writer.Header().Add(k, vv) - } - } - // reset content length - c.Writer.Header().Del("Content-Length") - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody))) - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - common.CloseResponseBodyGracefully(resp) + common.IOCopyBytesGracefully(c, resp, doResponseBody) return nil, usage } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index c2def5d9..cb8467e2 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -181,12 +181,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) err = common.DecodeJson(responseBody, &simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -264,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + // count tokens by audio file duration audioTokens, err := countAudioTokens(c) if err != nil { @@ -273,8 +276,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) - // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) @@ -553,6 +554,8 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -564,9 +567,6 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil } - // 关闭旧的 response body(已被读取,再次读取会导致错误) - common.CloseResponseBodyGracefully(resp) - // 写入新的 response body common.IOCopyBytesGracefully(c, resp, responseBody) diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index f7eae7d3..257cf68b 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "fmt" "io" "net/http" @@ -16,13 +15,14 @@ import ( ) func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + // read response body var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - common.CloseResponseBodyGracefully(resp) err = common.DecodeJson(responseBody, &responsesResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -38,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. }, nil } - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() + // 写入新的 response body + common.IOCopyBytesGracefully(c, resp, responseBody) + // compute usage usage := dto.Usage{} usage.PromptTokens = responsesResponse.Usage.InputTokens diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 9a300356..4825ca69 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -1,9 +1,7 @@ package xai import ( - "bytes" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -13,6 +11,8 @@ import ( "one-api/relay/helper" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse { @@ -78,8 +78,10 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, err := io.ReadAll(resp.Body) - var response *dto.TextResponse + var response *dto.SimpleResponse err = common.DecodeJson(responseBody, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -95,18 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo return nil, nil } - // set new body - resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - common.CloseResponseBodyGracefully(resp) + common.IOCopyBytesGracefully(c, resp, encodeJson) return nil, &response.Usage } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index b44890c1..5eb5922c 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") - } + common.IOCopyBytesGracefully(c, nil, respBody) return nil }