From 7e46d4217d45e10df5c07430aafcf11e17bf1221 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 13 Mar 2025 19:32:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E6=B5=81=E6=A8=A1=E5=BC=8F=E4=B8=8Bopenai=E6=B8=A0=E9=81=93?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E8=BD=AC=E4=B8=BAclaude=E6=A0=BC=E5=BC=8F?= =?UTF-8?q?=E8=AE=BF=E9=97=AE=20#862?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 2 +- dto/claude.go | 28 +++- relay/channel/adapter.go | 2 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/api_request.go | 4 + relay/channel/aws/adaptor.go | 2 +- relay/channel/baidu/adaptor.go | 2 +- relay/channel/baidu_v2/adaptor.go | 2 +- relay/channel/claude/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/adaptor.go | 2 +- relay/channel/cohere/adaptor.go | 2 +- relay/channel/deepseek/adaptor.go | 2 +- relay/channel/dify/adaptor.go | 2 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/jina/adaptor.go | 2 +- relay/channel/mistral/adaptor.go | 2 +- relay/channel/mokaai/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 24 +++- relay/channel/openai/helper.go | 188 +++++++++++++++++++++++++++ relay/channel/openai/relay-openai.go | 95 +------------- relay/channel/openrouter/adaptor.go | 2 +- relay/channel/palm/adaptor.go | 2 +- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/siliconflow/adaptor.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/vertex/adaptor.go | 2 +- relay/channel/volcengine/adaptor.go | 2 +- relay/channel/xunfei/adaptor.go | 2 +- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/claude_handler.go | 5 +- relay/common/relay_info.go | 17 ++- relay/helper/common.go | 16 +++ relay/relay-text.go | 5 +- service/convert.go | 119 +++++++++++------ 37 files changed, 390 insertions(+), 165 deletions(-) create mode 100644 relay/channel/openai/helper.go diff --git a/controller/channel-test.go b/controller/channel-test.go index 39af95e1..8ecbde3f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -107,7 +107,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRequest(c, info, request) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { return err, nil } diff --git a/dto/claude.go b/dto/claude.go index 60f638f6..f7354230 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -13,7 +13,7 @@ type ClaudeMediaMessage struct { Source *ClaudeMessageSource `json:"source,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"` StopReason *string `json:"stop_reason,omitempty"` - PartialJson string `json:"partial_json,omitempty"` + PartialJson *string `json:"partial_json,omitempty"` Role string `json:"role,omitempty"` Thinking string `json:"thinking,omitempty"` Signature string `json:"signature,omitempty"` @@ -37,6 +37,32 @@ func (c *ClaudeMediaMessage) GetText() string { return *c.Text } +func (c *ClaudeMediaMessage) IsStringContent() bool { + var content string + return json.Unmarshal(c.Content, &content) == nil +} + +func (c *ClaudeMediaMessage) GetStringContent() string { + var content string + if err := json.Unmarshal(c.Content, &content); err == nil { + return content + } + return "" +} + +func (c *ClaudeMediaMessage) SetContent(content any) { + jsonContent, _ := json.Marshal(content) + c.Content = jsonContent +} + +func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage { + var mediaContent []ClaudeMediaMessage + if err := json.Unmarshal(c.Content, &mediaContent); err == nil { + return mediaContent + } + return make([]ClaudeMediaMessage, 0) +} + type ClaudeMessageSource struct { Type string `json:"type"` MediaType string `json:"media_type"` diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 9f449b54..e097dbe6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -13,7 +13,7 @@ type Adaptor interface { Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error - ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) + ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 9d3ee99f..e28278e1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index a60bc6f1..8b2ca889 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/websocket" "io" "net/http" + common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" "one-api/service" @@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } + if common2.DebugEnabled { + println("fullRequestURL:", fullRequestURL) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index e735ee2b..94edda33 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 105f2a9b..eecb0bac 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -110,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 855ed717..9645bbf5 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index a5c475fa..6d65d6d4 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -64,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 74b73454..8607f77d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -335,7 +335,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse case "input_json_delta": tools = append(tools, dto.ToolCallResponse{ Function: dto.FunctionResponse{ - Arguments: claudeResponse.Delta.PartialJson, + Arguments: *claudeResponse.Delta.PartialJson, }, }) case "signature_delta": diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index b21e25f3..3d5a5a8a 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 7675d546..53a357ad 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return requestOpenAI2Cohere(*request), nil } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index ad01b8f4..64d92a48 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -50,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 96aff447..003b5f83 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -70,7 +70,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a629968b..c5a547ba 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -95,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index bcfc8dea..a65e820e 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 80547346..4857209f 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 151072cb..304351fd 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -57,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 4190dd3f..2101bf70 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -49,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 196343e8..d8bc808e 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -21,6 +21,7 @@ import ( "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/service" "strings" ) @@ -29,10 +30,20 @@ type Adaptor struct { ResponseFormat string } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { + if !strings.HasPrefix(request.Model, "claude") { + return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) + } + aiRequest, err := service.ClaudeToOpenAIRequest(*request) + if err != nil { + return nil, err + } + if info.SupportStreamOptions { + aiRequest.StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + } + return a.ConvertOpenAIRequest(c, info, aiRequest) } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -40,6 +51,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayFormat == relaycommon.RelayFormatClaude { + return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + } if info.RelayMode == constant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") @@ -115,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go new file mode 100644 index 00000000..a6d0aed8 --- /dev/null +++ b/relay/channel/openai/helper.go @@ -0,0 +1,188 @@ +package openai + +import ( + "encoding/json" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/relay/helper" + "one-api/service" + "strings" + + "github.com/gin-gonic/gin" +) + +// 辅助函数 +func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { + info.SendResponseCount++ + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + return sendStreamData(c, info, data, forceFormat, thinkToContent) + case relaycommon.RelayFormatClaude: + return handleClaudeFormat(c, data, info) + } + return nil +} + +func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { + return err + } + + claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) + for _, resp := range claudeResponses { + helper.ClaudeData(c, *resp) + } + return nil +} + +func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error { + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { + return err + } + + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > *toolCount { + *toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } + } + return nil +} + +func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { + streamResp := "[" + strings.Join(streamItems, ",") + "]" + + switch relayMode { + case relayconstant.RelayModeChatCompletions: + return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount) + case relayconstant.RelayModeCompletions: + return processCompletions(streamResp, streamItems, responseTextBuilder) + } + return nil +} + +func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { + var streamResponses []dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil { + common.SysError("error processing stream response: " + err.Error()) + } + } + return nil + } + + // 批量处理所有响应 + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > *toolCount { + *toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } + } + } + return nil +} + +func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error { + var streamResponses []dto.CompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.CompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { + continue + } + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + return nil + } + + // 批量处理所有响应 + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + return nil +} + +func handleLastResponse(lastStreamData string, responseId *string, createAt *int64, + systemFingerprint *string, model *string, usage **dto.Usage, + containStreamUsage *bool, info *relaycommon.RelayInfo, + shouldSendLastResp *bool) error { + + var lastStreamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil { + return err + } + + *responseId = lastStreamResponse.Id + *createAt = lastStreamResponse.Created + *systemFingerprint = lastStreamResponse.GetSystemFingerprint() + *model = lastStreamResponse.Model + + if service.ValidUsage(lastStreamResponse.Usage) { + *containStreamUsage = true + *usage = lastStreamResponse.Usage + if !info.ShouldIncludeUsage { + *shouldSendLastResp = false + } + } + + return nil +} + +func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string, + responseId string, createAt int64, model string, systemFingerprint string, + usage *dto.Usage, containStreamUsage bool) { + + switch info.RelayFormat { + case relaycommon.RelayFormatOpenAI: + if info.ShouldIncludeUsage && !containStreamUsage { + response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) + response.SetSystemFingerprint(systemFingerprint) + helper.ObjectData(c, response) + } + helper.Done(c) + + case relaycommon.RelayFormatClaude: + var streamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + + if !containStreamUsage { + streamResponse.Usage = usage + } + + claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info) + for _, resp := range claudeResponses { + helper.ClaudeData(c, *resp) + } + } +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ffd36d3c..2d1ad53e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -12,7 +12,6 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "os" @@ -137,10 +136,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel helper.StreamScannerHandler(c, resp, info, func(data string) bool { if lastStreamData != "" { - err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) + err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) if err != nil { - common.LogError(c, "streaming error: "+err.Error()) + common.SysError("error handling stream format: " + err.Error()) } + info.SetFirstResponseTime() } lastStreamData = data streamItems = append(streamItems, data) @@ -172,83 +172,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } - // 计算token - streamResp := "[" + strings.Join(streamItems, ",") + "]" - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - var streamResponses []dto.ChatCompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - //if service.ValidUsage(streamResponse.Usage) { - // usage = streamResponse.Usage - //} - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - - // handle both reasoning_content and reasoning - responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) - - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - } else { - for _, streamResponse := range streamResponses { - //if service.ValidUsage(streamResponse.Usage) { - // usage = streamResponse.Usage - // containStreamUsage = true - //} - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - case relayconstant.RelayModeCompletions: - var streamResponses []dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - } - } else { - for _, streamResponse := range streamResponses { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - } + // 处理token计算 + if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { + common.SysError("error processing tokens: " + err.Error()) } if !containStreamUsage { @@ -262,15 +188,8 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } - if info.ShouldIncludeUsage && !containStreamUsage { - response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) - response.SetSystemFingerprint(systemFingerprint) - helper.ObjectData(c, response) - } + handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) - helper.Done(c) - - //resp.Body.Close() return nil, usage } diff --git a/relay/channel/openrouter/adaptor.go b/relay/channel/openrouter/adaptor.go index aef5afeb..f2909b6b 100644 --- a/relay/channel/openrouter/adaptor.go +++ b/relay/channel/openrouter/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 69ef5001..f0220f4f 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index de84406c..32f00047 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 754a1f00..1b319e2a 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 28a02aae..f2b51ee9 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -58,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 2f348e46..e09845eb 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -122,7 +122,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index f423d587..5e5e276b 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -56,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index d66f3732..9521bb47 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -44,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index aa612f0c..04369001 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -48,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 7a23e212..ba24814c 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 97de772b..fb68a88a 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -114,13 +114,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } jsonData, err := json.Marshal(convertedRequest) + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } if err != nil { return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonData) - //log.Printf("requestBody: %s", requestBody) - statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3b5ef795..5075d07d 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,16 @@ type ThinkingContentInfo struct { SendLastThinkingContent bool } +const ( + LastMessageTypeText = "text" + LastMessageTypeTools = "tools" +) + +type ClaudeConvertInfo struct { + LastMessagesType string + Index int +} + const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" @@ -64,8 +74,9 @@ type RelayInfo struct { UserEmail string UserQuota int RelayFormat string - ResponseTimes int64 + SendResponseCount int ThinkingContentInfo + ClaudeConvertInfo } // 定义支持流式选项的通道类型 @@ -93,6 +104,9 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { info := GenRelayInfo(c) info.RelayFormat = RelayFormatClaude info.ShouldIncludeUsage = false + info.ClaudeConvertInfo = ClaudeConvertInfo{ + LastMessagesType: LastMessageTypeText, + } return info } @@ -172,7 +186,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) { } func (info *RelayInfo) SetFirstResponseTime() { - info.ResponseTimes++ if info.isFirstResponse { info.FirstResponseTime = time.Now() info.isFirstResponse = false diff --git a/relay/helper/common.go b/relay/helper/common.go index 6af55a86..13fc85ab 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -19,6 +19,22 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { + jsonData, err := json.Marshal(resp) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + } else { + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)}) + } + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + return errors.New("streaming error: flusher not found") + } + return nil +} + func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) diff --git a/relay/relay-text.go b/relay/relay-text.go index a0a97617..a61718fc 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -160,7 +160,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } requestBody = bytes.NewBuffer(body) } else { - convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) + convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } @@ -168,6 +168,9 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { if err != nil { return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } requestBody = bytes.NewBuffer(jsonData) } diff --git a/service/convert.go b/service/convert.go index c4916df2..dbaae654 100644 --- a/service/convert.go +++ b/service/convert.go @@ -44,24 +44,26 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR openAIMessages := make([]dto.Message, 0) // Add system message if present - if claudeRequest.IsStringSystem() { - openAIMessage := dto.Message{ - Role: "system", - } - openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) - openAIMessages = append(openAIMessages, openAIMessage) - } else { - systems := claudeRequest.ParseSystem() - if len(systems) > 0 { - systemStr := "" + if claudeRequest.System != nil { + if claudeRequest.IsStringSystem() { openAIMessage := dto.Message{ Role: "system", } - for _, system := range systems { - systemStr += system.Type - } - openAIMessage.SetStringContent(systemStr) + openAIMessage.SetStringContent(claudeRequest.GetStringSystem()) openAIMessages = append(openAIMessages, openAIMessage) + } else { + systems := claudeRequest.ParseSystem() + if len(systems) > 0 { + systemStr := "" + openAIMessage := dto.Message{ + Role: "system", + } + for _, system := range systems { + systemStr += system.Type + } + openAIMessage.SetStringContent(systemStr) + openAIMessages = append(openAIMessages, openAIMessage) + } } } for _, claudeMessage := range claudeRequest.Messages { @@ -100,7 +102,8 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR mediaMessages = append(mediaMessages, mediaMessage) case "tool_use": toolCall := dto.ToolCallRequest{ - ID: mediaMsg.Id, + ID: mediaMsg.Id, + Type: "function", Function: dto.FunctionRequest{ Name: mediaMsg.Name, Arguments: toJSONString(mediaMsg.Input), @@ -111,20 +114,33 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR // Add tool result as a separate message oaiToolMessage := dto.Message{ Role: "tool", + Name: &mediaMsg.Name, ToolCallId: mediaMsg.ToolUseId, } - oaiToolMessage.Content = mediaMsg.Content + //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text) + if mediaMsg.IsStringContent() { + oaiToolMessage.SetStringContent(mediaMsg.GetStringContent()) + } else { + mediaContents := mediaMsg.ParseMediaContent() + if len(mediaContents) > 0 && mediaContents[0].Text != nil { + oaiToolMessage.SetStringContent(*mediaContents[0].Text) + } + } + openAIMessages = append(openAIMessages, oaiToolMessage) } } - openAIMessage.SetMediaContent(mediaMessages) + if len(mediaMessages) > 0 { + openAIMessage.SetMediaContent(mediaMessages) + } if len(toolCalls) > 0 { openAIMessage.SetToolCalls(toolCalls) } } - - openAIMessages = append(openAIMessages, openAIMessage) + if len(openAIMessage.ParseContent()) > 0 { + openAIMessages = append(openAIMessages, openAIMessage) + } } openAIRequest.Messages = openAIMessages @@ -154,22 +170,35 @@ func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.O } } +func generateStopBlock(index int) *dto.ClaudeResponse { + return &dto.ClaudeResponse{ + Type: "content_block_stop", + Index: common.GetPointer[int](index), + } +} + func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse { var claudeResponses []*dto.ClaudeResponse - if info.ResponseTimes == 1 { - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "message_start", - Message: &dto.ClaudeMediaMessage{ - Id: openAIResponse.Id, - Model: openAIResponse.Model, - Type: "message", - Role: "assistant", - Usage: &dto.ClaudeUsage{ - InputTokens: info.PromptTokens, - OutputTokens: 0, - }, + if info.SendResponseCount == 1 { + msg := &dto.ClaudeMediaMessage{ + Id: openAIResponse.Id, + Model: openAIResponse.Model, + Type: "message", + Role: "assistant", + Usage: &dto.ClaudeUsage{ + InputTokens: info.PromptTokens, + OutputTokens: 0, }, + } + msg.SetContent(make([]any, 0)) + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Type: "message_start", + Message: msg, }) + claudeResponses = append(claudeResponses) + //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + // Type: "ping", + //}) if openAIResponse.IsToolCall() { resp := &dto.ClaudeResponse{ Type: "content_block_start", @@ -192,23 +221,18 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon resp.SetIndex(0) claudeResponses = append(claudeResponses, resp) } - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "ping", - }) return claudeResponses } if len(openAIResponse.Choices) == 0 { // no choices // TODO: handle this case + return claudeResponses } else { chosenChoice := openAIResponse.Choices[0] if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { // should be done - claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "content_block_stop", - Index: common.GetPointer[int](0), - }) + claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) if openAIResponse.Usage != nil { claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ Type: "message_delta", @@ -229,18 +253,35 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon claudeResponse.SetIndex(0) claudeResponse.Type = "content_block_delta" if len(chosenChoice.Delta.ToolCalls) > 0 { + if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText { + claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) + info.ClaudeConvertInfo.Index++ + claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ + Index: &info.ClaudeConvertInfo.Index, + Type: "content_block_start", + ContentBlock: &dto.ClaudeMediaMessage{ + Id: openAIResponse.GetFirstToolCall().ID, + Type: "tool_use", + Name: openAIResponse.GetFirstToolCall().Function.Name, + Input: map[string]interface{}{}, + }, + }) + } + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools // tools delta claudeResponse.Delta = &dto.ClaudeMediaMessage{ Type: "input_json_delta", - PartialJson: chosenChoice.Delta.ToolCalls[0].Function.Arguments, + PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments, } } else { + info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText // text delta claudeResponse.Delta = &dto.ClaudeMediaMessage{ Type: "text_delta", Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()), } } + claudeResponse.Index = &info.ClaudeConvertInfo.Index claudeResponses = append(claudeResponses, &claudeResponse) } }