🐛 fix: refactor response body handling in multiple relay handlers to utilize IOCopyBytesGracefully

This commit is contained in:
CaIon
2025-06-27 23:35:56 +08:00
parent 0a04a76c71
commit 1f4cf07b63
6 changed files with 37 additions and 83 deletions

View File

@@ -3,9 +3,10 @@ package common
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"github.com/gin-gonic/gin"
) )
func CloseResponseBodyGracefully(httpResponse *http.Response) { func CloseResponseBodyGracefully(httpResponse *http.Response) {
@@ -19,37 +20,37 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) {
} }
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
if src == nil || src.Body == nil {
return
}
defer CloseResponseBodyGracefully(src)
if c.Writer == nil { if c.Writer == nil {
return return
} }
src.Body = io.NopCloser(bytes.NewBuffer(data)) body := io.NopCloser(bytes.NewBuffer(data))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all. // For example, Postman will report error, and we cannot check the response at all.
for k, v := range src.Header { if src != nil {
// avoid setting Content-Length for k, v := range src.Header {
if k == "Content-Length" { // avoid setting Content-Length
continue if k == "Content-Length" {
continue
}
c.Writer.Header().Set(k, v[0])
} }
c.Writer.Header().Set(k, v[0])
} }
// set Content-Length header manually // set Content-Length header manually BEFORE calling WriteHeader
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
c.Writer.WriteHeader(src.StatusCode) // Write header with status code (this sends the headers)
c.Writer.WriteHeaderNow() if src != nil {
c.Writer.WriteHeader(src.StatusCode)
} else {
c.Writer.WriteHeader(http.StatusOK)
}
_, err := io.Copy(c.Writer, src.Body) _, err := io.Copy(c.Writer, body)
if err != nil { if err != nil {
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
} }

View File

@@ -1,7 +1,6 @@
package ollama package ollama
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -118,28 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) common.IOCopyBytesGracefully(c, resp, doResponseBody)
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header {
// 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
}
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
return nil, usage return nil, usage
} }

View File

@@ -181,12 +181,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp)
err = common.DecodeJson(responseBody, &simpleResponse) err = common.DecodeJson(responseBody, &simpleResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -264,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration // count tokens by audio file duration
audioTokens, err := countAudioTokens(c) audioTokens, err := countAudioTokens(c)
if err != nil { if err != nil {
@@ -273,8 +276,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp)
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -553,6 +554,8 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
} }
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -564,9 +567,6 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
} }
// 关闭旧的 response body已被读取再次读取会导致错误
common.CloseResponseBodyGracefully(resp)
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) common.IOCopyBytesGracefully(c, resp, responseBody)

View File

@@ -1,7 +1,6 @@
package openai package openai
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -16,13 +15,14 @@ import (
) )
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// read response body // read response body
var responsesResponse dto.OpenAIResponsesResponse var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp)
err = common.DecodeJson(responseBody, &responsesResponse) err = common.DecodeJson(responseBody, &responsesResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -38,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}, nil }, nil
} }
// reset response body // 写入新的 response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) common.IOCopyBytesGracefully(c, resp, responseBody)
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
// copy response body
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
// compute usage // compute usage
usage := dto.Usage{} usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens usage.PromptTokens = responsesResponse.Usage.InputTokens

View File

@@ -1,9 +1,7 @@
package xai package xai
import ( import (
"bytes"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -13,6 +11,8 @@ import (
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse { func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
@@ -78,8 +78,10 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
var response *dto.TextResponse var response *dto.SimpleResponse
err = common.DecodeJson(responseBody, &response) err = common.DecodeJson(responseBody, &response)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
@@ -95,18 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
return nil, nil return nil, nil
} }
// set new body common.IOCopyBytesGracefully(c, resp, encodeJson)
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
return nil, &response.Usage return nil, &response.Usage
} }

View File

@@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
if err != nil { if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
} }
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) common.IOCopyBytesGracefully(c, nil, respBody)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil return nil
} }