🐛 fix: refactor JSON unmarshalling across multiple handlers to use UnmarshalJson and UnmarshalJsonStr for consistency

This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase.
This commit is contained in:
CaIon
2025-06-28 00:02:07 +08:00
parent 1f4cf07b63
commit 6b9237f868
10 changed files with 39 additions and 44 deletions

View File

@@ -2,7 +2,6 @@ package common
import ( import (
"bytes" "bytes"
"encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings" "strings"
@@ -31,7 +30,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
} }
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v) err = UnmarshalJson(requestBody, &v)
} else { } else {
// skip for now // skip for now
// TODO: someday non json request have variant model, we will need to implementation this // TODO: someday non json request have variant model, we will need to implementation this

View File

@@ -5,12 +5,16 @@ import (
"encoding/json" "encoding/json"
) )
func DecodeJson(data []byte, v any) error { func UnmarshalJson(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v) return json.Unmarshal(data, v)
} }
func DecodeJsonStr(data string, v any) error { func UnmarshalJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v) return json.Unmarshal(StringToByteSlice(data), v)
}
func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
} }
func EncodeJson(v any) ([]byte, error) { func EncodeJson(v any) ([]byte, error) {

View File

@@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct {
func (r *GeneralOpenAIRequest) ToMap() map[string]any { func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any) result := make(map[string]any)
data, _ := common.EncodeJson(r) data, _ := common.EncodeJson(r)
_ = common.DecodeJson(data, &result) _ = common.UnmarshalJson(data, &result)
return result return result
} }

View File

@@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
if textRequest.Reasoning != nil { if textRequest.Reasoning != nil {
var reasoning openrouter.RequestReasoning var reasoning openrouter.RequestReasoning
if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil { if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
return nil, err return nil, err
} }
@@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode { func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse var claudeResponse dto.ClaudeResponse
err := common.DecodeJsonStr(data, &claudeResponse) err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
@@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode { func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse var claudeResponse dto.ClaudeResponse
err := common.DecodeJson(data, &claudeResponse) err := common.UnmarshalJson(data, &claudeResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
} }
@@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
case relaycommon.RelayFormatClaude: case relaycommon.RelayFormatClaude:
responseData = data responseData = data
} }
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK) common.IOCopyBytesGracefully(c, nil, responseData)
_, err = c.Writer.Write(responseData)
return nil return nil
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{ claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
@@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
resp.Body.Close()
if common.DebugEnabled { if common.DebugEnabled {
println("responseBody: ", string(responseBody)) println("responseBody: ", string(responseBody))
} }

View File

@@ -1,7 +1,6 @@
package gemini package gemini
import ( import (
"encoding/json"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -15,12 +14,13 @@ import (
) )
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
defer common.CloseResponseBodyGracefully(resp)
// 读取响应体 // 读取响应体
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled { if common.DebugEnabled {
println(string(responseBody)) println(string(responseBody))
@@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
// 解析为 Gemini 原生响应格式 // 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err = common.DecodeJson(responseBody, &geminiResponse) err = common.UnmarshalJson(responseBody, &geminiResponse)
if err != nil { if err != nil {
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }
@@ -51,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
} }
// 直接返回 Gemini 原生格式的 JSON 响应 // 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := json.Marshal(geminiResponse) jsonResponse, err := common.EncodeJson(geminiResponse)
if err != nil { if err != nil {
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
} }
// 设置响应头并写入响应 common.IOCopyBytesGracefully(c, resp, jsonResponse)
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
}
return &usage, nil return &usage, nil
} }
@@ -77,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false

View File

@@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
@@ -871,7 +871,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
println(string(responseBody)) println(string(responseBody))
} }
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err = common.DecodeJson(responseBody, &geminiResponse) err = common.UnmarshalJson(responseBody, &geminiResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
@@ -917,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
} }
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body) responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil { if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
} }
_ = resp.Body.Close()
var geminiResponse GeminiEmbeddingResponse var geminiResponse GeminiEmbeddingResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
@@ -953,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
} }
openAIResponse.Usage = *usage.(*dto.Usage) openAIResponse.Usage = *usage.(*dto.Usage)
jsonResponse, jsonErr := json.Marshal(openAIResponse) jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
if jsonErr != nil { if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
} }
c.Writer.Header().Set("Content-Type", "application/json") common.IOCopyBytesGracefully(c, resp, jsonResponse)
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return usage, nil return usage, nil
} }

View File

@@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
} }
var lastStreamResponse dto.ChatCompletionsStreamResponse var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil { if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
return err return err
} }
@@ -188,7 +188,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = common.DecodeJson(responseBody, &simpleResponse) err = common.UnmarshalJson(responseBody, &simpleResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
@@ -368,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
} }
realtimeEvent := &dto.RealtimeEvent{} realtimeEvent := &dto.RealtimeEvent{}
err = common.DecodeJson(message, realtimeEvent) err = common.UnmarshalJson(message, realtimeEvent)
if err != nil { if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err) errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return return
@@ -428,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
} }
info.SetFirstResponseTime() info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{} realtimeEvent := &dto.RealtimeEvent{}
err = common.DecodeJson(message, realtimeEvent) err = common.UnmarshalJson(message, realtimeEvent)
if err != nil { if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err) errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return return
@@ -562,7 +562,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
} }
var usageResp dto.SimpleResponse var usageResp dto.SimpleResponse
err = common.DecodeJson(responseBody, &usageResp) err = common.UnmarshalJson(responseBody, &usageResp)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
} }

View File

@@ -23,7 +23,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = common.DecodeJson(responseBody, &responsesResponse) err = common.UnmarshalJson(responseBody, &responsesResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
@@ -66,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
// 检查当前数据是否包含 completed 状态和 usage 信息 // 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse var streamResponse dto.ResponsesStreamResponse
if err := common.DecodeJsonStr(data, &streamResponse); err == nil { if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data) sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type { switch streamResponse.Type {
case "response.completed": case "response.completed":

View File

@@ -82,7 +82,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
var response *dto.SimpleResponse var response *dto.SimpleResponse
err = common.DecodeJson(responseBody, &response) err = common.UnmarshalJson(responseBody, &response)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil return nil, nil

View File

@@ -23,7 +23,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var jinaResp dto.RerankResponse var jinaResp dto.RerankResponse
if info.ChannelType == common.ChannelTypeXinference { if info.ChannelType == common.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse var xinRerankResponse xinference.XinRerankResponse
err = common.DecodeJson(responseBody, &xinRerankResponse) err = common.UnmarshalJson(responseBody, &xinRerankResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
@@ -58,7 +58,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
}, },
} }
} else { } else {
err = common.DecodeJson(responseBody, &jinaResp) err = common.UnmarshalJson(responseBody, &jinaResp)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }