From daa7a135053cfdb28f0783714958bcd2770fe5ee Mon Sep 17 00:00:00 2001
From: CaIon
Date: Fri, 8 Aug 2025 16:45:37 +0800
Subject: [PATCH] feat: improve format conversion, fix issues when using the
 gemini and openai channels with claude code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/gin.go                        |  3 ++
 dto/claude.go                        | 12 +++++++
 dto/openai_response.go               | 20 ++++++++++++
 relay/channel/gemini/relay-gemini.go | 47 ++++++++++++++++++++++------
 relay/channel/openai/adaptor.go      | 16 ++++++++++
 service/convert.go                   | 14 ++++++---
 6 files changed, 98 insertions(+), 14 deletions(-)

diff --git a/common/gin.go b/common/gin.go
index 8c67bb4d..15765970 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -31,6 +31,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	if err != nil {
 		return err
 	}
+	//if DebugEnabled {
+	//	println("UnmarshalBodyReusable request body:", string(requestBody))
+	//}
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = Unmarshal(requestBody, &v)
diff --git a/dto/claude.go b/dto/claude.go
index 7b5f348e..58a09217 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -199,6 +199,18 @@ type ClaudeRequest struct {
 	Thinking *Thinking `json:"thinking,omitempty"`
 }
 
+func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
+	for _, message := range c.Messages {
+		content, _ := message.ParseContent()
+		for _, mediaMessage := range content {
+			if mediaMessage.Id == toolCallId {
+				return mediaMessage.Name
+			}
+		}
+	}
+	return ""
+}
+
 // AddTool 添加工具到请求中
 func (c *ClaudeRequest) AddTool(tool any) {
 	if c.Tools == nil {
diff --git a/dto/openai_response.go b/dto/openai_response.go
index b050cd03..c2669fd4 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -143,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
 	Usage *Usage `json:"usage"`
 }
 
+func (c *ChatCompletionsStreamResponse) IsFinished() bool {
+	if len(c.Choices) == 0 {
+		return false
+	}
+	return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
+}
+
 func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
 	if len(c.Choices) == 0 {
 		return false
@@ -157,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
 	return nil
 }
 
+func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
+	if !c.IsToolCall() {
+		return
+	}
+	for choiceIdx := range c.Choices {
+		for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
+		}
+	}
+}
+
 func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
 	choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
 	copy(choices, c.Choices)
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 698a972c..25a2c412 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -835,6 +835,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
 				call.SetIndex(len(choice.Delta.ToolCalls))
 				choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 			}
+
 		} else if part.Thought {
 			isThought = true
 			texts = append(texts, part.Text)
@@ -895,6 +896,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 	responseText := strings.Builder{}
 	var usage = &dto.Usage{}
 	var imageCount int
+	finishReason := constant.FinishReasonStop
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse dto.GeminiChatResponse
@@ -936,9 +938,21 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 
 		if info.SendResponseCount == 0 {
 			// send first response
-			err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
-			if err != nil {
-				common.LogError(c, err.Error())
+			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
+			if response.IsToolCall() {
+				emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
+				emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
+				emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
+				finishReason = constant.FinishReasonToolCalls
+				err = handleStream(c, info, emptyResponse)
+				if err != nil {
+					common.LogError(c, err.Error())
+				}
+
+				response.ClearToolCalls()
+				if response.IsFinished() {
+					response.Choices[0].FinishReason = nil
+				}
 			}
 		}
 
@@ -947,7 +961,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 			common.LogError(c, err.Error())
 		}
 		if isStop {
-			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
+			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
 		}
 		return true
 	})
@@ -1026,13 +1040,26 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	}
 	fullTextResponse.Usage = usage
-	jsonResponse, err := json.Marshal(fullTextResponse)
-	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatOpenAI:
+		responseBody, err = common.Marshal(fullTextResponse)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		}
+	case relaycommon.RelayFormatClaude:
+		claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
+		claudeRespStr, err := common.Marshal(claudeResp)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		}
+		responseBody = claudeRespStr
+	case relaycommon.RelayFormatGemini:
+		break
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	c.Writer.Write(jsonResponse)
+
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
 	return &usage, nil
 }
 
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 8f81ab8c..ed2bea57 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -63,10 +63,26 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	//if !strings.Contains(request.Model, "claude") {
 	//	return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
 	//}
+	//if common.DebugEnabled {
+	//	bodyBytes := []byte(common.GetJsonString(request))
+	//	err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+	//	if err != nil {
+	//		println(fmt.Sprintf("failed to save request body to file: %v", err))
+	//	}
+	//}
 	aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
 	if err != nil {
 		return nil, err
 	}
+	//if common.DebugEnabled {
+	//	println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
+	//	// Save request body to file for debugging
+	//	bodyBytes := []byte(common.GetJsonString(aiRequest))
+	//	err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+	//	if err != nil {
+	//		println(fmt.Sprintf("failed to save request body to file: %v", err))
+	//	}
+	//}
 	if info.SupportStreamOptions && info.IsStream {
 		aiRequest.StreamOptions = &dto.StreamOptions{
 			IncludeUsage: true,
diff --git a/service/convert.go b/service/convert.go
index 925feae9..a6cfce6c 100644
--- a/service/convert.go
+++ b/service/convert.go
@@ -153,9 +153,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 				toolCalls = append(toolCalls, toolCall)
 			case "tool_result":
 				// Add tool result as a separate message
+				toolName := mediaMsg.Name
+				if toolName == "" {
+					toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId)
+				}
 				oaiToolMessage := dto.Message{
 					Role:       "tool",
-					Name:       &mediaMsg.Name,
+					Name:       &toolName,
 					ToolCallId: mediaMsg.ToolUseId,
 				}
 				//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
@@ -218,12 +222,14 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 	//	Type: "ping",
 	//})
 	if openAIResponse.IsToolCall() {
+		info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
 		resp := &dto.ClaudeResponse{
 			Type: "content_block_start",
 			ContentBlock: &dto.ClaudeMediaMessage{
-				Id:   openAIResponse.GetFirstToolCall().ID,
-				Type: "tool_use",
-				Name: openAIResponse.GetFirstToolCall().Function.Name,
+				Id:    openAIResponse.GetFirstToolCall().ID,
+				Type:  "tool_use",
+				Name:  openAIResponse.GetFirstToolCall().Function.Name,
+				Input: map[string]interface{}{},
 			},
 		}
 		resp.SetIndex(0)