diff --git a/common/json.go b/common/json.go new file mode 100644 index 00000000..5b2b1aac --- /dev/null +++ b/common/json.go @@ -0,0 +1,14 @@ +package common + +import ( + "bytes" + "encoding/json" +) + +func DecodeJson(data []byte, v any) error { + return json.NewDecoder(bytes.NewReader(data)).Decode(v) +} + +func DecodeJsonStr(data string, v any) error { + return DecodeJson(StringToByteSlice(data), v) +} diff --git a/dto/claude.go b/dto/claude.go index f9a6024a..2af43dae 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -165,8 +165,8 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage { } type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` + Type string `json:"type,omitempty"` + Message string `json:"message,omitempty"` } type ClaudeErrorWithStatusCode struct { diff --git a/dto/openai_response.go b/dto/openai_response.go index 53883bb4..52c1fdce 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -1,9 +1,8 @@ package dto type SimpleResponse struct { - Usage `json:"usage"` - Error OpenAIError `json:"error"` - Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` + Error *OpenAIError `json:"error"` } type TextResponse struct { @@ -27,6 +26,7 @@ type OpenAITextResponse struct { Object string `json:"object"` Created int64 `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` + Error *OpenAIError `json:"error"` Usage `json:"usage"` } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1214699b..dbb4a4da 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,7 +1,6 @@ package claude import ( - "bytes" "encoding/json" "fmt" "io" @@ -481,7 +480,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error { var claudeResponse dto.ClaudeResponse - err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) + err := common.DecodeJsonStr(data, &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return fmt.Errorf("error unmarshalling stream aws response: %w", err) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 30f927a7..c0080342 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { + if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil { return err } @@ -151,7 +151,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel shouldSendLastResp := true var lastStreamResponse dto.ChatCompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse) + err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse) if err == nil { responseId = lastStreamResponse.Id createAt = lastStreamResponse.Created @@ -196,7 +196,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var simpleResponse dto.SimpleResponse + var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil @@ -205,16 +205,29 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - err = json.Unmarshal(responseBody, &simpleResponse) + err = common.DecodeJson(responseBody, &simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - if simpleResponse.Error.Type != "" { + if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ - Error: simpleResponse.Error, + Error: *simpleResponse.Error, StatusCode: resp.StatusCode, }, nil } + + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + break + case relaycommon.RelayFormatClaude: + claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) + claudeRespStr, err := json.Marshal(claudeResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + responseBody = claudeRespStr + } + // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. diff --git a/service/error.go b/service/error.go index 9824a853..1bf5992b 100644 --- a/service/error.go +++ b/service/error.go @@ -60,7 +60,6 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError claudeError := dto.ClaudeError{ Message: text, Type: "new_api_error", - //Code: code, } return &dto.ClaudeErrorWithStatusCode{ Error: claudeError,