From 0c326556aa2ee4e6f64242de133c612eab3b2fc1 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 22 Dec 2024 16:20:30 +0800 Subject: [PATCH] refactor: Update OpenAI request and message handling - Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity. - Introduced ParseToolCalls and SetToolCalls methods to handle ToolCalls more effectively, improving code readability and maintainability. - Updated the ParseContent method to work with the new MediaContent type instead of MediaMessage, enhancing the structure of content parsing. - Refactored Gemini relay functions to utilize the new ToolCalls handling methods, streamlining the integration with OpenAI and Gemini systems. --- dto/openai_request.go | 36 ++++++--- relay/channel/claude/relay-claude.go | 11 +-- relay/channel/gemini/relay-gemini.go | 105 +++++++++++++-------------- 3 files changed, 78 insertions(+), 74 deletions(-) diff --git a/dto/openai_request.go b/dto/openai_request.go index b1eebe17..752311c7 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -22,7 +22,7 @@ type GeneralOpenAIRequest struct { StreamOptions *StreamOptions `json:"stream_options,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` @@ -89,11 +89,27 @@ type Message struct { Role string `json:"role"` Content json.RawMessage `json:"content"` Name *string `json:"name,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` + ToolCalls json.RawMessage `json:"tool_calls,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"` } -type MediaMessage struct { +func (m Message) ParseToolCalls() []ToolCall { + if m.ToolCalls == nil { + return nil + } + var toolCalls []ToolCall + if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil { + return toolCalls + } + return toolCalls +} + +func (m Message) SetToolCalls(toolCalls any) { + toolCallsJson, _ := json.Marshal(toolCalls) + m.ToolCalls = toolCallsJson +} + +type MediaContent struct { Type string `json:"type"` Text string `json:"text"` ImageUrl any `json:"image_url,omitempty"` @@ -137,11 +153,11 @@ func (m Message) IsStringContent() bool { return false } -func (m Message) ParseContent() []MediaMessage { - var contentList []MediaMessage +func (m Message) ParseContent() []MediaContent { + var contentList []MediaContent var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { - contentList = append(contentList, MediaMessage{ + contentList = append(contentList, MediaContent{ Type: ContentTypeText, Text: stringContent, }) @@ -157,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage { switch contentMap["type"] { case ContentTypeText: if subStr, ok := contentMap["text"].(string); ok { - contentList = append(contentList, MediaMessage{ + contentList = append(contentList, MediaContent{ Type: ContentTypeText, Text: subStr, }) @@ -170,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage { } else { subObj["detail"] = "high" } - contentList = append(contentList, MediaMessage{ + contentList = append(contentList, MediaContent{ Type: ContentTypeImageURL, ImageUrl: MessageImageUrl{ Url: subObj["url"].(string), @@ -178,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage { }, }) } else if url, ok := contentMap["image_url"].(string); ok { - contentList = append(contentList, MediaMessage{ + contentList = append(contentList, MediaContent{ Type: ContentTypeImageURL, ImageUrl: MessageImageUrl{ Url: url, @@ -188,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage { } case ContentTypeInputAudio: if subObj, ok := contentMap["input_audio"].(map[string]any); ok { - contentList = append(contentList, MediaMessage{ + contentList = append(contentList, MediaContent{ Type: ContentTypeInputAudio, InputAudio: MessageInputAudio{ Data: subObj["data"].(string), diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 4c7f188b..0cddf8a6 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) } if message.ToolCalls != nil { - for _, tc := range message.ToolCalls.([]interface{}) { - toolCallJSON, _ := json.Marshal(tc) - var toolCall dto.ToolCall - err := json.Unmarshal(toolCallJSON, &toolCall) - if err != nil { - common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc)) - continue - } + for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) @@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope } choice.SetStringContent(responseText) if len(tools) > 0 { - choice.Message.ToolCalls = tools + choice.Message.SetToolCalls(tools) } fullTextResponse.Model = claudeResponse.Model choices = append(choices, choice) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 2041d8e7..f2aa24c9 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -108,50 +108,63 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque }, } continue + } else if message.Role == "tool" { + message.Role = "model" } + + var parts []GeminiPart content := GeminiChatContent{ Role: message.Role, - //Parts: []GeminiPart{ - // { - // Text: message.StringContent(), - // }, - //}, } - openaiContent := message.ParseContent() - var parts []GeminiPart - imageNum := 0 - for _, part := range openaiContent { - if part.Type == dto.ContentTypeText { - parts = append(parts, GeminiPart{ - Text: part.Text, - }) - } else if part.Type == dto.ContentTypeImageURL { - imageNum += 1 - - if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { - return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) + isToolCall := false + if message.ToolCalls != nil { + isToolCall = true + for _, call := range message.ParseToolCalls() { + toolCall := GeminiPart{ + FunctionCall: &FunctionCall{ + FunctionName: call.Function.Name, + Arguments: call.Function.Parameters, + }, } - // 判断是否是url - if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { - // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + parts = append(parts, toolCall) + } + } + if !isToolCall { + openaiContent := message.ParseContent() + imageNum := 0 + for _, part := range openaiContent { + if part.Type == dto.ContentTypeText { parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, - }, + Text: part.Text, }) - } else { - _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) - if err != nil { - return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) + } else if part.Type == dto.ContentTypeImageURL { + imageNum += 1 + + if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { + return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) + } + // 判断是否是url + if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { + // 是url,获取图片的类型和base64编码的数据 + mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } else { + _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) + if err != nil { + return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) + } + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: "image/" + format, + Data: base64String, + }, + }) } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: "image/" + format, - Data: base64String, - }, - }) } } } @@ -161,25 +174,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque if content.Role == "assistant" { content.Role = "model" } - // Converting system prompt to prompt from user for the same reason - //if content.Role == "system" { - // content.Role = "user" - // shouldAddDummyModelMessage = true - //} geminiRequest.Contents = append(geminiRequest.Contents, content) - // - //// If a system message is the last message, we need to add a dummy model message to make gemini happy - //if shouldAddDummyModelMessage { - // geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ - // Role: "model", - // Parts: []GeminiPart{ - // { - // Text: "Okay", - // }, - // }, - // }) - // shouldAddDummyModelMessage = false - //} } return &geminiRequest, nil } @@ -278,7 +273,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp if len(candidate.Content.Parts) > 0 { if candidate.Content.Parts[0].FunctionCall != nil { choice.FinishReason = constant.FinishReasonToolCalls - choice.Message.ToolCalls = getToolCalls(&candidate) + choice.Message.SetToolCalls(getToolCalls(&candidate)) } else { var texts []string for _, part := range candidate.Content.Parts {