🐛 fix: refactor response body handling in multiple relay handlers to utilize IOCopyBytesGracefully
This commit is contained in:
@@ -3,9 +3,10 @@ package common
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||||
@@ -19,37 +20,37 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||||
if src == nil || src.Body == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer CloseResponseBodyGracefully(src)
|
|
||||||
|
|
||||||
if c.Writer == nil {
|
if c.Writer == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
src.Body = io.NopCloser(bytes.NewBuffer(data))
|
body := io.NopCloser(bytes.NewBuffer(data))
|
||||||
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
// So the httpClient will be confused by the response.
|
// So the httpClient will be confused by the response.
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
// For example, Postman will report error, and we cannot check the response at all.
|
||||||
for k, v := range src.Header {
|
if src != nil {
|
||||||
// avoid setting Content-Length
|
for k, v := range src.Header {
|
||||||
if k == "Content-Length" {
|
// avoid setting Content-Length
|
||||||
continue
|
if k == "Content-Length" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set(k, v[0])
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// set Content-Length header manually
|
// set Content-Length header manually BEFORE calling WriteHeader
|
||||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||||
|
|
||||||
c.Writer.WriteHeader(src.StatusCode)
|
// Write header with status code (this sends the headers)
|
||||||
c.Writer.WriteHeaderNow()
|
if src != nil {
|
||||||
|
c.Writer.WriteHeader(src.StatusCode)
|
||||||
|
} else {
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
_, err := io.Copy(c.Writer, src.Body)
|
_, err := io.Copy(c.Writer, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -118,28 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
// Copy headers
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
// 删除任何现有的相同头部,以防止重复添加头部
|
|
||||||
c.Writer.Header().Del(k)
|
|
||||||
for _, vv := range v {
|
|
||||||
c.Writer.Header().Add(k, vv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// reset content length
|
|
||||||
c.Writer.Header().Del("Content-Length")
|
|
||||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -181,12 +181,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
var simpleResponse dto.OpenAITextResponse
|
var simpleResponse dto.OpenAITextResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
err = common.DecodeJson(responseBody, &simpleResponse)
|
err = common.DecodeJson(responseBody, &simpleResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -264,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// count tokens by audio file duration
|
// count tokens by audio file duration
|
||||||
audioTokens, err := countAudioTokens(c)
|
audioTokens, err := countAudioTokens(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -273,8 +276,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
@@ -553,6 +554,8 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -564,9 +567,6 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
|
|||||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 关闭旧的 response body(已被读取,再次读取会导致错误)
|
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -16,13 +15,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// read response body
|
// read response body
|
||||||
var responsesResponse dto.OpenAIResponsesResponse
|
var responsesResponse dto.OpenAIResponsesResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
err = common.DecodeJson(responseBody, &responsesResponse)
|
err = common.DecodeJson(responseBody, &responsesResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
@@ -38,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// reset response body
|
// 写入新的 response body
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
|
||||||
// So the httpClient will be confused by the response.
|
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
// copy response body
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error copying response body: " + err.Error())
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
// compute usage
|
// compute usage
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package xai
|
package xai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -13,6 +11,8 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||||
@@ -78,8 +78,10 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
var response *dto.TextResponse
|
var response *dto.SimpleResponse
|
||||||
err = common.DecodeJson(responseBody, &response)
|
err = common.DecodeJson(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
@@ -95,18 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// set new body
|
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
c.Writer.Header().Set(k, v[0])
|
|
||||||
}
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
common.CloseResponseBodyGracefully(resp)
|
|
||||||
|
|
||||||
return nil, &response.Usage
|
return nil, &response.Usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||||
}
|
}
|
||||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
common.IOCopyBytesGracefully(c, nil, respBody)
|
||||||
if err != nil {
|
|
||||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user