diff --git a/dto/openai_response.go b/dto/openai_response.go
index 52c1fdce..fe2609bf 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
 	Object  string                     `json:"object"`
 	Created int64                      `json:"created"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
-	Error   *OpenAIError               `json:"error"`
+	Error   *OpenAIError               `json:"error,omitempty"`
 	Usage   `json:"usage"`
 }
 
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
index da4bab89..b8e2e624 100644
--- a/relay/channel/aws/relay-aws.go
+++ b/relay/channel/aws/relay-aws.go
@@ -84,22 +84,20 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 		return wrapErr(errors.Wrap(err, "InvokeModel")), nil
 	}
 
-	claudeResponse := new(dto.ClaudeResponse)
-	err = json.Unmarshal(awsResp.Body, claudeResponse)
-	if err != nil {
-		return wrapErr(errors.Wrap(err, "unmarshal response")), nil
+	claudeInfo := &claude.ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
 	}
 
-	openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
-	usage := dto.Usage{
-		PromptTokens:     claudeResponse.Usage.InputTokens,
-		CompletionTokens: claudeResponse.Usage.OutputTokens,
-		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
-	}
-	openaiResp.Usage = usage
-
-	c.JSON(http.StatusOK, openaiResp)
-	return nil, &usage
+	// Surface decoding and upstream errors instead of discarding them.
+	handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+	if handlerErr != nil {
+		return handlerErr, nil
+	}
+	return nil, claudeInfo.Usage
 }
 
 func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -150,9 +148,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		switch v := event.(type) {
 		case *types.ResponseStreamMemberChunk:
 			info.SetFirstResponseTime()
-			err = claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
-			if err != nil {
-				return wrapErr(err), nil
+			respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+			if respErr != nil {
+				return respErr, nil
 			}
 		case *types.UnknownUnionMember:
 			fmt.Println("unknown tag:", v.Tag)
@@ -163,10 +161,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		}
 	}
 
-	claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
-
-	if resp != nil {
-		resp.Body.Close()
-	}
+	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	return nil, claudeInfo.Usage
 }
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index dbb4a4da..4d4da247 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -478,12 +478,24 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 	return true
 }
 
-func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error {
+// HandleStreamResponseData decodes one Claude stream chunk and processes it according to info.RelayFormat.
+// It returns a non-nil error when the chunk cannot be decoded or carries an upstream error.
+func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
 	err := common.DecodeJsonStr(data, &claudeResponse)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
-		return fmt.Errorf("error unmarshalling stream aws response: %w", err)
+		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
+	}
+	if claudeResponse.Error.Type != "" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Code:    "stream_response_error",
+				Type:    claudeResponse.Error.Type,
+				Message: claudeResponse.Error.Message,
+			},
+			StatusCode: http.StatusInternalServerError,
+		}
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
@@ -523,7 +535,7 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo
 	return nil
 }
 
-func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
 			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
@@ -566,81 +578,100 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		ResponseText: strings.Builder{},
 		Usage:        &dto.Usage{},
 	}
-	var err error
+	var err *dto.OpenAIErrorWithStatusCode
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		err = HandleResponseData(c, info, claudeInfo, data, requestMode)
+		err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
 		if err != nil {
 			return false
 		}
 		return true
 	})
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError), nil
+		return err, nil
 	}
 
-	HandleFinalResponse(c, info, claudeInfo, requestMode)
-
+	HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
 	return nil, claudeInfo.Usage
 }
 
-func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	if common.DebugEnabled {
-		println("responseBody: ", string(responseBody))
-	}
+// HandleClaudeResponseData decodes a non-stream Claude response body, fills
+// claudeInfo.Usage and writes the body to the client in the format selected by
+// info.RelayFormat. It returns a non-nil error when the body cannot be decoded,
+// when it carries an upstream error, or when token counting or marshalling fails.
+func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
-	err = json.Unmarshal(responseBody, &claudeResponse)
+	err := common.DecodeJson(data, &claudeResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
 	}
 	if claudeResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
-				Param:   "",
 				Code:    claudeResponse.Error.Type,
 			},
-			StatusCode: resp.StatusCode,
-		}, nil
+			StatusCode: http.StatusInternalServerError,
+		}
 	}
-	usage := dto.Usage{}
 	if requestMode == RequestModeCompletion {
 		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
+			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
 		}
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens = completionTokens
-		usage.TotalTokens = info.PromptTokens + completionTokens
+		claudeInfo.Usage.PromptTokens = info.PromptTokens
+		claudeInfo.Usage.CompletionTokens = completionTokens
+		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
 	} else {
-		usage.PromptTokens = claudeResponse.Usage.InputTokens
-		usage.CompletionTokens = claudeResponse.Usage.OutputTokens
-		usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
-		usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
-		usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
+		claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
+		claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+		claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
+		claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
+		claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
 	}
 	var responseData []byte
 	switch info.RelayFormat {
 	case relaycommon.RelayFormatOpenAI:
 		openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-		openaiResponse.Usage = usage
+		openaiResponse.Usage = *claudeInfo.Usage
 		responseData, err = json.Marshal(openaiResponse)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+			return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 		}
 	case relaycommon.RelayFormatClaude:
-		responseData = responseBody
+		responseData = data
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(responseData)
-	return nil, &usage
+	c.Writer.WriteHeader(http.StatusOK)
+	if _, err = c.Writer.Write(responseData); err != nil {
+		// The status line is already written, so a failed write can only be logged.
+		common.SysError("error writing response body: " + err.Error())
+	}
+	return nil
+}
+
+// ClaudeHandler reads the whole upstream response body, delegates to
+// HandleClaudeResponseData and returns the usage it collected.
+func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	claudeInfo := &ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	// Close before inspecting err so a failed read does not leak the body.
+	resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if common.DebugEnabled {
+		println("responseBody: ", string(responseBody))
+	}
+	handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
+	if handleErr != nil {
+		return handleErr, nil
+	}
+	return nil, claudeInfo.Usage
 }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index c0080342..b20d66f3 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -240,7 +240,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = io.Copy(c.Writer, resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+		// The status line is already written, so a copy failure can only be logged.
+		common.SysError("error copying response body: " + err.Error())
 	}
 	resp.Body.Close()
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {