refactor: Enhance Claude response handling

This commit is contained in:
1808837298@qq.com
2025-03-16 19:11:58 +08:00
parent b3b1c803fc
commit 53b3599827
4 changed files with 74 additions and 64 deletions

View File

@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error"`
Error *OpenAIError `json:"error,omitempty"`
Usage `json:"usage"`
}

View File

@@ -84,22 +84,16 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(dto.ClaudeResponse)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
usage := dto.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -150,9 +144,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
err = claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if err != nil {
return wrapErr(err), nil
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
@@ -163,10 +157,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
if resp != nil {
resp.Body.Close()
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage
}

View File

@@ -478,12 +478,22 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
return true
}
func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error {
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return fmt.Errorf("error unmarshalling stream aws response: %w", err)
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
}
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Code: "stream_response_error",
Type: claudeResponse.Error.Type,
Message: claudeResponse.Error.Message,
},
StatusCode: http.StatusInternalServerError,
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
@@ -523,7 +533,7 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo
return nil
}
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
@@ -566,81 +576,90 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
var err error
var err *dto.OpenAIErrorWithStatusCode
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleResponseData(c, info, claudeInfo, data, requestMode)
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
return false
}
return true
})
if err != nil {
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError), nil
return err, nil
}
HandleFinalResponse(c, info, claudeInfo, requestMode)
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
return nil, claudeInfo.Usage
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
err := common.DecodeJson(data, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
}
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
StatusCode: http.StatusInternalServerError,
}
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
var responseData []byte
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = usage
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
case relaycommon.RelayFormatClaude:
responseData = responseBody
responseData = data
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(responseData)
return nil, &usage
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
if handleErr != nil {
return handleErr, nil
}
return nil, claudeInfo.Usage
}

View File

@@ -240,7 +240,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {