diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 976f97ce..e1270606 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -144,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel defer stream.Close() c.Writer.Header().Set("Content-Type", "text/event-stream") - var usage relaymodel.Usage - var id string - var model string + claudeInfo := &claude.ClaudeResponseInfo{ + ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + ResponseText: strings.Builder{}, + Usage: &relaymodel.Usage{}, + } isFirst := true - createdTime := common.GetTimestamp() c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -161,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel isFirst = false info.FirstResponseTime = time.Now() } - claudeResp := new(claude.ClaudeResponse) - err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + claudeResponse := new(claude.ClaudeResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return false } - response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp) - if claudeUsage != nil { - usage.PromptTokens += claudeUsage.InputTokens - usage.CompletionTokens += claudeUsage.OutputTokens - } + response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse) - if response == nil { + if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) { return true } - if response.Id != "" { - id = response.Id - } - if response.Model != "" { - model = response.Model - } - response.Created = createdTime - response.Id = id - response.Model = model - jsonStr, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) @@ -203,8 +192,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return false } }) + + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 + } + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) + response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) @@ -217,5 +214,5 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil } } - return nil, &usage + return nil, claudeInfo.Usage } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 40659020..fb4f5b7e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,6 +1,7 @@ package claude import ( + "bytes" "encoding/json" "fmt" "io" @@ -290,9 +291,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse - var claudeUsage *ClaudeUsage response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) @@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeResponse.Type == "message_start" { response.Id = claudeResponse.Message.Id response.Model = claudeResponse.Message.Model - claudeUsage = &claudeResponse.Message.Usage + //claudeUsage = &claudeResponse.Message.Usage choice.Delta.SetContentString("") choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { @@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* }) } } else { - return nil, nil + return nil } } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta != nil { @@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if finishReason != "null" { choice.FinishReason = &finishReason } - claudeUsage = &claudeResponse.Usage + //claudeUsage = &claudeResponse.Usage } else if claudeResponse.Type == "message_stop" { - return nil, nil + return nil } else { - return nil, nil + return nil } } - if claudeUsage == nil { - claudeUsage = &ClaudeUsage{} - } if len(tools) > 0 { choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... choice.Delta.ToolCalls = tools } response.Choices = append(response.Choices, choice) - return &response, claudeUsage + return &response } func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { @@ -437,48 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope return &fullTextResponse } +type ClaudeResponseInfo struct { + ResponseId string + Created int64 + Model string + ResponseText strings.Builder + Usage *dto.Usage +} + +func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { + if oaiResponse == nil { + return false + } + if requestMode == RequestModeCompletion { + claudeInfo.ResponseText.WriteString(claudeResponse.Completion) + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + claudeInfo.ResponseId = claudeResponse.Message.Id + claudeInfo.Model = claudeResponse.Message.Model + claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + } else if claudeResponse.Type == "content_block_delta" { + claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text) + } else if claudeResponse.Type == "message_delta" { + claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { + } else { + return false + } + } + oaiResponse.Id = claudeInfo.ResponseId + oaiResponse.Created = claudeInfo.Created + oaiResponse.Model = claudeInfo.Model + return true +} + func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - var usage *dto.Usage - usage = &dto.Usage{} - responseText := "" - createdTime := common.GetTimestamp() + claudeInfo := &ClaudeResponseInfo{ + ResponseId: responseId, + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + ResponseText: strings.Builder{}, + Usage: &dto.Usage{}, + } helper.StreamScannerHandler(c, resp, info, func(data string) bool { var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) + err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return true } - response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) - if response == nil { + response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + + if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { return true } - if requestMode == RequestModeCompletion { - responseText += claudeResponse.Completion - responseId = response.Id - } else { - if claudeResponse.Type == "message_start" { - // message_start, 获取usage - responseId = claudeResponse.Message.Id - info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeUsage.InputTokens - } else if claudeResponse.Type == "content_block_delta" { - responseText += claudeResponse.Delta.Text - } else if claudeResponse.Type == "message_delta" { - usage.CompletionTokens = claudeUsage.OutputTokens - usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens - } else if claudeResponse.Type == "content_block_start" { - } else { - return true - } - } - //response.Id = responseId - response.Id = responseId - response.Created = createdTime - response.Model = info.UpstreamModelName err = helper.ObjectData(c, response) if err != nil { @@ -488,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. }) if requestMode == RequestModeCompletion { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) } else { - if usage.PromptTokens == 0 { - usage.PromptTokens = info.PromptTokens + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 } - if usage.CompletionTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) + if claudeInfo.Usage.CompletionTokens == 0 { + claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } if info.ShouldIncludeUsage { - response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) + response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { common.SysError("send final response failed: " + err.Error()) } } helper.Done(c) - //resp.Body.Close() - return nil, usage + return nil, claudeInfo.Usage } func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {