From 1d0ef89ce989107ac774600e83e3b232f1a04d55 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 22 Dec 2024 23:53:25 +0800 Subject: [PATCH] feat: Add FunctionResponse type and enhance GeminiPart structure - Introduced a new `FunctionResponse` type to encapsulate function call responses, improving the clarity of data handling. - Updated the `GeminiPart` struct to include the new `FunctionResponse` field, allowing for better representation of function call results in Gemini requests. - Modified the `CovertGemini2OpenAI` function so that a message carrying tool calls has its role set to `model`, and a `tool`-role message is converted into a `user`-role content whose part is a `FunctionResponse` (name taken from the message, response parsed from its string content); the finished Gemini request is also logged via `common.LogJson`. --- relay/channel/gemini/dto.go | 12 +++- relay/channel/gemini/relay-gemini.go | 85 ++++++++++++++++------------ 2 files changed, 59 insertions(+), 38 deletions(-) diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 16027b41..727446c5 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -18,10 +18,16 @@ type FunctionCall struct { Arguments any `json:"args"` } +type FunctionResponse struct { + Name string `json:"name"` + Response any `json:"response"` +} + type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` - FunctionCall *FunctionCall `json:"functionCall,omitempty"` + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` + FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` } type GeminiChatContent struct { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index f2aa24c9..7c557550 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -2,6 +2,7 @@ package gemini import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ 
-108,16 +109,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque }, } continue - } else if message.Role == "tool" { - message.Role = "model" } - var parts []GeminiPart content := GeminiChatContent{ Role: message.Role, } isToolCall := false if message.ToolCalls != nil { + message.Role = "model" isToolCall = true for _, call := range message.ParseToolCalls() { toolCall := GeminiPart{ @@ -130,40 +129,55 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } } if !isToolCall { - openaiContent := message.ParseContent() - imageNum := 0 - for _, part := range openaiContent { - if part.Type == dto.ContentTypeText { - parts = append(parts, GeminiPart{ - Text: part.Text, - }) - } else if part.Type == dto.ContentTypeImageURL { - imageNum += 1 + if message.Role == "tool" { + content.Role = "user" + name := "" + if message.Name != nil { + name = *message.Name + } + functionResp := &FunctionResponse{ + Name: name, + Response: common.StrToMap(message.StringContent()), + } + parts = append(parts, GeminiPart{ + FunctionResponse: functionResp, + }) + } else { + openaiContent := message.ParseContent() + imageNum := 0 + for _, part := range openaiContent { + if part.Type == dto.ContentTypeText { + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == dto.ContentTypeImageURL { + imageNum += 1 - if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { - return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) - } - // 判断是否是url - if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { - // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, - }, - }) - } else { - _, format, base64String, err := 
service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) - if err != nil { - return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) + if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { + return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) + } + // 判断是否是url + if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { + // 是url,获取图片的类型和base64编码的数据 + mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } else { + _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) + if err != nil { + return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) + } + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: "image/" + format, + Data: base64String, + }, + }) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: "image/" + format, - Data: base64String, - }, - }) } } } @@ -176,6 +190,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } geminiRequest.Contents = append(geminiRequest.Contents, content) } + common.LogJson(context.Background(), "gemini_request", geminiRequest) return &geminiRequest, nil }