🐛 fix: refactor response body handling in multiple relay handlers to utilize IOCopyBytesGracefully
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -118,28 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// Copy headers
|
||||
for k, v := range resp.Header {
|
||||
// 删除任何现有的相同头部,以防止重复添加头部
|
||||
c.Writer.Header().Del(k)
|
||||
for _, vv := range v {
|
||||
c.Writer.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
|
||||
@@ -181,12 +181,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = common.DecodeJson(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -264,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
@@ -273,8 +276,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -553,6 +554,8 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -564,9 +567,6 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
|
||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
// 关闭旧的 response body(已被读取,再次读取会导致错误)
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,13 +15,14 @@ import (
|
||||
)
|
||||
|
||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// read response body
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = common.DecodeJson(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -38,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
// copy response body
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
resp.Body.Close()
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// compute usage
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -13,6 +11,8 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
@@ -78,8 +78,10 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
var response *dto.TextResponse
|
||||
var response *dto.SimpleResponse
|
||||
err = common.DecodeJson(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -95,18 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// set new body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, nil, respBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user