diff --git a/common/gin.go b/common/gin.go index 4a909dfc..0614f735 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,7 +2,6 @@ package common import ( "bytes" - "encoding/json" "github.com/gin-gonic/gin" "io" "strings" @@ -31,7 +30,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { } contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) + err = UnmarshalJson(requestBody, &v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this diff --git a/common/json.go b/common/json.go index cec8f16b..512ad0c3 100644 --- a/common/json.go +++ b/common/json.go @@ -5,12 +5,16 @@ import ( "encoding/json" ) -func DecodeJson(data []byte, v any) error { - return json.NewDecoder(bytes.NewReader(data)).Decode(v) +func UnmarshalJson(data []byte, v any) error { + return json.Unmarshal(data, v) } -func DecodeJsonStr(data string, v any) error { - return DecodeJson(StringToByteSlice(data), v) +func UnmarshalJsonStr(data string, v any) error { + return json.Unmarshal(StringToByteSlice(data), v) +} + +func DecodeJson(reader *bytes.Reader, v any) error { + return json.NewDecoder(reader).Decode(v) } func EncodeJson(v any) ([]byte, error) { diff --git a/dto/openai_request.go b/dto/openai_request.go index 0104f347..a6567542 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct { func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.EncodeJson(r) - _ = common.DecodeJson(data, &result) + _ = common.UnmarshalJson(data, &result) return result } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index f164fd4d..a8607d86 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla if textRequest.Reasoning != nil { var reasoning openrouter.RequestReasoning - if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil { + if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil { return nil, err } @@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJsonStr(data, &claudeResponse) + err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError) @@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJson(data, &claudeResponse) + err := common.UnmarshalJson(data, &claudeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError) } @@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud case relaycommon.RelayFormatClaude: responseData = data } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(http.StatusOK) - _, err = c.Writer.Write(responseData) + + common.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + claudeInfo := &ClaudeResponseInfo{ ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Created: common.GetTimestamp(), @@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - resp.Body.Close() if common.DebugEnabled { println("responseBody: ", string(responseBody)) } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 822d3097..52846c66 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -1,7 +1,6 @@ package gemini import ( - "encoding/json" "io" "net/http" "one-api/common" @@ -15,12 +14,13 @@ import ( ) func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) @@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela // 解析为 Gemini 原生响应格式 var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } @@ -51,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 直接返回 Gemini 原生格式的 JSON 响应 - jsonResponse, err := json.Marshal(geminiResponse) + jsonResponse, err := common.EncodeJson(geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } - // 设置响应头并写入响应 - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - if err != nil { - return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError) - } + common.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } @@ -77,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b01d46e4..1544e8cf 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false @@ -871,7 +871,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re println(string(responseBody)) } var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -917,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) } - _ = resp.Body.Close() var geminiResponse GeminiEmbeddingResponse if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { @@ -953,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm } openAIResponse.Usage = *usage.(*dto.Usage) - jsonResponse, jsonErr := json.Marshal(openAIResponse) + jsonResponse, jsonErr := common.EncodeJson(openAIResponse) if jsonErr != nil { return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - + common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index cb8467e2..8a7d55d5 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil { + if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil { return err } @@ -188,7 +188,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = common.DecodeJson(responseBody, &simpleResponse) + err = common.UnmarshalJson(responseBody, &simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -368,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } realtimeEvent := &dto.RealtimeEvent{} - err = common.DecodeJson(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -428,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } info.SetFirstResponseTime() realtimeEvent := &dto.RealtimeEvent{} - err = common.DecodeJson(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -562,7 +562,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } var usageResp dto.SimpleResponse - err = common.DecodeJson(responseBody, &usageResp) + err = common.UnmarshalJson(responseBody, &usageResp) if err != nil { return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 257cf68b..7f426c33 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -23,7 +23,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = common.DecodeJson(responseBody, &responsesResponse) + err = common.UnmarshalJson(responseBody, &responsesResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -66,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc // 检查当前数据是否包含 completed 状态和 usage 信息 var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil { sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4825ca69..4a030e48 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -82,7 +82,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo responseBody, err := io.ReadAll(resp.Body) var response *dto.SimpleResponse - err = common.DecodeJson(responseBody, &response) + err = common.UnmarshalJson(responseBody, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return nil, nil diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 63ab4769..d7033846 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -23,7 +23,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var jinaResp dto.RerankResponse if info.ChannelType == common.ChannelTypeXinference { var xinRerankResponse xinference.XinRerankResponse - err = common.DecodeJson(responseBody, &xinRerankResponse) + err = common.UnmarshalJson(responseBody, &xinRerankResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -58,7 +58,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo }, } } else { - err = common.DecodeJson(responseBody, &jinaResp) + err = common.UnmarshalJson(responseBody, &jinaResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil }