From 6b9237f868d8fbd33dcd9080b0956195cd3edba6 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 28 Jun 2025 00:02:07 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20refactor=20JSON=20unmarsh?= =?UTF-8?q?alling=20across=20multiple=20handlers=20to=20use=20UnmarshalJso?= =?UTF-8?q?n=20and=20UnmarshalJsonStr=20for=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase. --- common/gin.go | 3 +-- common/json.go | 12 ++++++++---- dto/openai_request.go | 2 +- relay/channel/claude/relay-claude.go | 14 +++++++------- relay/channel/gemini/relay-gemini-native.go | 18 ++++++------------ relay/channel/gemini/relay-gemini.go | 14 ++++++-------- relay/channel/openai/relay-openai.go | 10 +++++----- relay/channel/openai/relay_responses.go | 4 ++-- relay/channel/xai/text.go | 2 +- relay/common_handler/rerank.go | 4 ++-- 10 files changed, 39 insertions(+), 44 deletions(-) diff --git a/common/gin.go b/common/gin.go index 4a909dfc..0614f735 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,7 +2,6 @@ package common import ( "bytes" - "encoding/json" "github.com/gin-gonic/gin" "io" "strings" @@ -31,7 +30,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { } contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) + err = UnmarshalJson(requestBody, &v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this diff --git a/common/json.go b/common/json.go index cec8f16b..512ad0c3 100644 --- a/common/json.go +++ b/common/json.go @@ -5,12 +5,16 @@ import ( "encoding/json" ) -func DecodeJson(data []byte, v any) error { - return json.NewDecoder(bytes.NewReader(data)).Decode(v) +func UnmarshalJson(data []byte, v any) error { + return json.Unmarshal(data, v) } -func DecodeJsonStr(data string, v any) error { - return DecodeJson(StringToByteSlice(data), v) +func UnmarshalJsonStr(data string, v any) error { + return json.Unmarshal(StringToByteSlice(data), v) +} + +func DecodeJson(reader *bytes.Reader, v any) error { + return json.NewDecoder(reader).Decode(v) } func EncodeJson(v any) ([]byte, error) { diff --git a/dto/openai_request.go b/dto/openai_request.go index 0104f347..a6567542 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct { func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.EncodeJson(r) - _ = common.DecodeJson(data, &result) + _ = common.UnmarshalJson(data, &result) return result } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index f164fd4d..a8607d86 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla if textRequest.Reasoning != nil { var reasoning openrouter.RequestReasoning - if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil { + if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil { return nil, err } @@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJsonStr(data, &claudeResponse) + err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError) @@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJson(data, &claudeResponse) + err := common.UnmarshalJson(data, &claudeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError) } @@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud case relaycommon.RelayFormatClaude: responseData = data } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(http.StatusOK) - _, err = c.Writer.Write(responseData) + + common.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + claudeInfo := &ClaudeResponseInfo{ ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Created: common.GetTimestamp(), @@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - resp.Body.Close() if common.DebugEnabled { println("responseBody: ", string(responseBody)) } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 822d3097..52846c66 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -1,7 +1,6 @@ package gemini import ( - "encoding/json" "io" "net/http" "one-api/common" @@ -15,12 +14,13 @@ import ( ) func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } - common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) @@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela // 解析为 Gemini 原生响应格式 var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } @@ -51,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 直接返回 Gemini 原生格式的 JSON 响应 - jsonResponse, err := json.Marshal(geminiResponse) + jsonResponse, err := common.EncodeJson(geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } - // 设置响应头并写入响应 - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - if err != nil { - return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError) - } + common.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } @@ -77,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b01d46e4..1544e8cf 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false @@ -871,7 +871,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re println(string(responseBody)) } var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -917,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) } - _ = resp.Body.Close() var geminiResponse GeminiEmbeddingResponse if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { @@ -953,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm } openAIResponse.Usage = *usage.(*dto.Usage) - jsonResponse, jsonErr := json.Marshal(openAIResponse) + jsonResponse, jsonErr := common.EncodeJson(openAIResponse) if jsonErr != nil { return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - + common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index cb8467e2..8a7d55d5 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil { + if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil { return err } @@ -188,7 +188,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = common.DecodeJson(responseBody, &simpleResponse) + err = common.UnmarshalJson(responseBody, &simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -368,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } realtimeEvent := &dto.RealtimeEvent{} - err = common.DecodeJson(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -428,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } info.SetFirstResponseTime() realtimeEvent := &dto.RealtimeEvent{} - err = common.DecodeJson(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -562,7 +562,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } var usageResp dto.SimpleResponse - err = common.DecodeJson(responseBody, &usageResp) + err = common.UnmarshalJson(responseBody, &usageResp) if err != nil { return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 257cf68b..7f426c33 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -23,7 +23,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = common.DecodeJson(responseBody, &responsesResponse) + err = common.UnmarshalJson(responseBody, &responsesResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -66,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc // 检查当前数据是否包含 completed 状态和 usage 信息 var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil { sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4825ca69..4a030e48 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -82,7 +82,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo responseBody, err := io.ReadAll(resp.Body) var response *dto.SimpleResponse - err = common.DecodeJson(responseBody, &response) + err = common.UnmarshalJson(responseBody, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return nil, nil diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 63ab4769..d7033846 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -23,7 +23,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo var jinaResp dto.RerankResponse if info.ChannelType == common.ChannelTypeXinference { var xinRerankResponse xinference.XinRerankResponse - err = common.DecodeJson(responseBody, &xinRerankResponse) + err = common.UnmarshalJson(responseBody, &xinRerankResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -58,7 +58,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo }, } } else { - err = common.DecodeJson(responseBody, &jinaResp) + err = common.UnmarshalJson(responseBody, &jinaResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil }