From a4795737fe3ecb412f5c8b01e8a8eb86370cbc33 Mon Sep 17 00:00:00 2001
From: Yan <1964649083@qq.com>
Date: Mon, 23 Dec 2024 01:26:14 +0800
Subject: [PATCH] fix: mutil func call in gemini

---
 .gitignore                           |   1 +
 common/str.go                        |   4 +-
 relay/channel/gemini/relay-gemini.go | 185 ++++++++++++++++++-----------
 3 files changed, 114 insertions(+), 76 deletions(-)

diff --git a/.gitignore b/.gitignore
index ac995ac7..f99bae43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ build
 logs
 web/dist
 .env
+one-api
\ No newline at end of file
diff --git a/common/str.go b/common/str.go
index d42fd837..e3834694 100644
--- a/common/str.go
+++ b/common/str.go
@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
 	m := make(map[string]interface{})
 	err := json.Unmarshal([]byte(str), &m)
 	if err != nil {
-		return nil
+		return map[string]interface{}{
+			"result": str,
+		}
 	}
 	return m
 }
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index c6e91f2b..4c12812a 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
 		}
 	}
-
+	tool_call_ids := make(map[string]string)
 	//shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
@@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 				},
 			}
 			continue
+		} else if message.Role == "tool" {
+			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
+				geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+					Role: "user",
+				})
+			}
+			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
+			name := ""
+			if message.Name != nil {
+				name = *message.Name
+			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
+				name = val
+			}
+			functionResp := &FunctionResponse{
+				Name: name,
+				Response: common.StrToMap(message.StringContent()),
+			}
+			*parts = append(*parts, GeminiPart{
+				FunctionResponse: functionResp,
+			})
+			continue
 		}
 		var parts []GeminiPart
 		content := GeminiChatContent{
@@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 					},
 				}
 				parts = append(parts, toolCall)
+				tool_call_ids[call.ID] = call.Function.Name
 			}
 		}
 		if !isToolCall {
-			if message.Role == "tool" {
-				content.Role = "user"
-				name := ""
-				if message.Name != nil {
-					name = *message.Name
-				}
-				functionResp := &FunctionResponse{
-					Name: name,
-					Response: common.StrToMap(message.StringContent()),
-				}
-				parts = append(parts, GeminiPart{
-					FunctionResponse: functionResp,
-				})
-			} else {
-				openaiContent := message.ParseContent()
-				imageNum := 0
-				for _, part := range openaiContent {
-					if part.Type == dto.ContentTypeText {
-						parts = append(parts, GeminiPart{
-							Text: part.Text,
-						})
-					} else if part.Type == dto.ContentTypeImageURL {
-						imageNum += 1
+			openaiContent := message.ParseContent()
+			imageNum := 0
+			for _, part := range openaiContent {
+				if part.Type == dto.ContentTypeText {
+					parts = append(parts, GeminiPart{
+						Text: part.Text,
+					})
+				} else if part.Type == dto.ContentTypeImageURL {
+					imageNum += 1
 
-						if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
-							return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
-						}
-						// 判断是否是url
-						if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
-							// 是url,获取图片的类型和base64编码的数据
-							mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
-							parts = append(parts, GeminiPart{
-								InlineData: &GeminiInlineData{
-									MimeType: mimeType,
-									Data: data,
-								},
-							})
-						} else {
-							_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
-							if err != nil {
-								return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
-							}
-							parts = append(parts, GeminiPart{
-								InlineData: &GeminiInlineData{
-									MimeType: "image/" + format,
-									Data: base64String,
-								},
-							})
-						}
-					}
+					if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+						return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+					}
+					// 判断是否是url
+					if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
+						// 是url,获取图片的类型和base64编码的数据
+						mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+						parts = append(parts, GeminiPart{
+							InlineData: &GeminiInlineData{
+								MimeType: mimeType,
+								Data: data,
+							},
+						})
+					} else {
+						_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
+						if err != nil {
+							return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
+						}
+						parts = append(parts, GeminiPart{
+							InlineData: &GeminiInlineData{
+								MimeType: "image/" + format,
+								Data: base64String,
+							},
+						})
+					}
 				}
 			}
 		}
 		content.Parts = parts
 		// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -242,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
-func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
-	var toolCalls []dto.ToolCall
-
-	item := candidate.Content.Parts[0]
-	if item.FunctionCall == nil {
-		return toolCalls
-	}
+func getToolCall(item *GeminiPart) *dto.ToolCall {
 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
 	if err != nil {
-		//common.SysError("getToolCalls failed: " + err.Error())
-		return toolCalls
+		//common.SysError("getToolCall failed: " + err.Error())
+		return nil
 	}
-	toolCall := dto.ToolCall{
+	return &dto.ToolCall{
 		ID: fmt.Sprintf("call_%s", common.GetUUID()),
 		Type: "function",
 		Function: dto.FunctionCall{
@@ -262,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
 			Arguments: string(argsBytes),
 			Name: item.FunctionCall.FunctionName,
 		},
 	}
-	toolCalls = append(toolCalls, toolCall)
-	return toolCalls
 }
 
+// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
+// 	var toolCalls []dto.ToolCall
+// 	item := candidate.Content.Parts[index]
+// 	if item.FunctionCall == nil {
+// 		return toolCalls
+// 	}
+// 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+// 	if err != nil {
+// 		//common.SysError("getToolCalls failed: " + err.Error())
+// 		return toolCalls
+// 	}
+// 	toolCall := dto.ToolCall{
+// 		ID: fmt.Sprintf("call_%s", common.GetUUID()),
+// 		Type: "function",
+// 		Function: dto.FunctionCall{
+// 			Arguments: string(argsBytes),
+// 			Name: item.FunctionCall.FunctionName,
+// 		},
+// 	}
+// 	toolCalls = append(toolCalls, toolCall)
+// 	return toolCalls
+// }
+
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 	}
 	content, _ := json.Marshal("")
 	for i, candidate := range response.Candidates {
+		// jsonData, _ := json.MarshalIndent(candidate, "", " ")
+		// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
 		choice := dto.OpenAITextResponseChoice{
 			Index: i,
 			Message: dto.Message{
@@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 			FinishReason: constant.FinishReasonStop,
 		}
 		if len(candidate.Content.Parts) > 0 {
-			if candidate.Content.Parts[0].FunctionCall != nil {
-				choice.FinishReason = constant.FinishReasonToolCalls
-				choice.Message.SetToolCalls(getToolCalls(&candidate))
-			} else {
-				var texts []string
-				for _, part := range candidate.Content.Parts {
+			var texts []string
+			var tool_calls []dto.ToolCall
+			for _, part := range candidate.Content.Parts {
+				if part.FunctionCall != nil {
+					choice.FinishReason = constant.FinishReasonToolCalls
+					if call := getToolCall(&part); call != nil {
+						tool_calls = append(tool_calls, *call)
+					}
+				} else {
 					texts = append(texts, part.Text)
 				}
-				choice.Message.SetStringContent(strings.Join(texts, "\n"))
 			}
+			choice.Message.SetStringContent(strings.Join(texts, "\n"))
+			choice.Message.SetToolCalls(tool_calls)
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	var choice dto.ChatCompletionsStreamResponseChoice
 	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
 	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
-		respFirstParts := geminiResponse.Candidates[0].Content.Parts
-		if respFirstParts[0].FunctionCall != nil {
-			// function response
-			choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
-		} else {
-			// text response
-			var texts []string
-			for _, part := range respFirstParts {
+		var texts []string
+		var tool_calls []dto.ToolCall
+		for _, part := range geminiResponse.Candidates[0].Content.Parts {
+			if part.FunctionCall != nil {
+				if call := getToolCall(&part); call != nil {
+					tool_calls = append(tool_calls, *call)
+				}
+			} else {
 				texts = append(texts, part.Text)
 			}
+		}
+		if len(texts) > 0 {
 			choice.Delta.SetContentString(strings.Join(texts, "\n"))
 		}
+		if len(tool_calls) > 0 {
+			choice.Delta.ToolCalls = tool_calls
+		}
 	}
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
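
The hunks above all implement one idea: stop inspecting only `Parts[0]` of a Gemini candidate and instead walk every part, turning each `FunctionCall` into its own OpenAI-style tool call while plain-text parts are joined separately. Below is a minimal, self-contained Go sketch of that aggregation logic; it is not part of the patch and uses simplified stand-in types (`GeminiPart`, `ToolCall`, and a `getToolCall` with an extra id parameter) rather than the project's real `dto` definitions.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Simplified stand-ins for the project's GeminiPart / dto.ToolCall types;
// the real definitions live in relay/channel/gemini and relay/dto.
type FunctionCall struct {
	FunctionName string
	Arguments    map[string]interface{}
}

type GeminiPart struct {
	Text         string
	FunctionCall *FunctionCall
}

type ToolCall struct {
	ID        string
	Type      string
	Name      string
	Arguments string
}

// getToolCall converts one part's FunctionCall into one OpenAI-style tool call,
// returning nil for plain-text parts or when marshalling fails. The patched
// helper in the diff does the same, but fills ID with common.GetUUID().
func getToolCall(part *GeminiPart, id int) *ToolCall {
	if part.FunctionCall == nil {
		return nil
	}
	argsBytes, err := json.Marshal(part.FunctionCall.Arguments)
	if err != nil {
		return nil
	}
	return &ToolCall{
		ID:        fmt.Sprintf("call_%d", id),
		Type:      "function",
		Name:      part.FunctionCall.FunctionName,
		Arguments: string(argsBytes),
	}
}

func main() {
	// A candidate whose parts mix text with two function calls.
	parts := []GeminiPart{
		{Text: "Checking the weather and the time."},
		{FunctionCall: &FunctionCall{FunctionName: "get_weather", Arguments: map[string]interface{}{"city": "Paris"}}},
		{FunctionCall: &FunctionCall{FunctionName: "get_time", Arguments: map[string]interface{}{"tz": "CET"}}},
	}

	// Walk every part instead of looking only at parts[0]; collect all tool
	// calls and join the text parts separately, as the patched converters do.
	var texts []string
	var toolCalls []ToolCall
	for i := range parts {
		if call := getToolCall(&parts[i], i); call != nil {
			toolCalls = append(toolCalls, *call)
		} else {
			texts = append(texts, parts[i].Text)
		}
	}

	fmt.Println("content:", strings.Join(texts, "\n"))
	for _, tc := range toolCalls {
		fmt.Printf("tool call %s: %s(%s)\n", tc.ID, tc.Name, tc.Arguments)
	}
}
```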