fix: claude to openai tools use

This commit is contained in:
1808837298@qq.com
2025-03-12 19:29:15 +08:00
parent 39d95172e8
commit 229738cda9
2 changed files with 83 additions and 73 deletions

View File

@@ -144,11 +144,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
defer stream.Close() defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage claudeInfo := &claude.ClaudeResponseInfo{
var id string ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
var model string Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &relaymodel.Usage{},
}
isFirst := true isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events() event, ok := <-stream.Events()
if !ok { if !ok {
@@ -161,33 +164,19 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
isFirst = false isFirst = false
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
} }
claudeResp := new(claude.ClaudeResponse) claudeResponse := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return false return false
} }
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp) response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
if response == nil { if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
return true return true
} }
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
@@ -203,8 +192,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return false return false
} }
}) })
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage) response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
@@ -217,5 +214,5 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
} }
return nil, &usage return nil, claudeInfo.Usage
} }

View File

@@ -1,6 +1,7 @@
package claude package claude
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -290,9 +291,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
return &claudeRequest, nil return &claudeRequest, nil
} }
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
@@ -308,7 +308,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Type == "message_start" { if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage //claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("") choice.Delta.SetContentString("")
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
@@ -325,7 +325,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
}) })
} }
} else { } else {
return nil, nil return nil
} }
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil { if claudeResponse.Delta != nil {
@@ -352,23 +352,20 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if finishReason != "null" { if finishReason != "null" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
claudeUsage = &claudeResponse.Usage //claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" { } else if claudeResponse.Type == "message_stop" {
return nil, nil return nil
} else { } else {
return nil, nil return nil
} }
} }
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
if len(tools) > 0 { if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools choice.Delta.ToolCalls = tools
} }
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
return &response, claudeUsage return &response
} }
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@@ -437,48 +434,65 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse return &fullTextResponse
} }
type ClaudeResponseInfo struct {
ResponseId string
Created int64
Model string
ResponseText strings.Builder
Usage *dto.Usage
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if oaiResponse == nil {
return false
}
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Text)
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
}
oaiResponse.Id = claudeInfo.ResponseId
oaiResponse.Created = claudeInfo.Created
oaiResponse.Model = claudeInfo.Model
return true
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage claudeInfo := &ClaudeResponseInfo{
usage = &dto.Usage{} ResponseId: responseId,
responseText := "" Created: common.GetTimestamp(),
createdTime := common.GetTimestamp() Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse) err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return true return true
} }
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
@@ -488,25 +502,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}) })
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else { } else {
if usage.PromptTokens == 0 { if claudeInfo.Usage.PromptTokens == 0 {
usage.PromptTokens = info.PromptTokens //上游出错
} }
if usage.CompletionTokens == 0 { if claudeInfo.Usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
} }
} }
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
helper.Done(c) helper.Done(c)
//resp.Body.Close() return nil, claudeInfo.Usage
return nil, usage
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {