refactor: Update OpenAI request and message handling

- Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity.
- Introduced ParseToolCalls and SetToolCalls methods to handle ToolCalls more effectively, improving code readability and maintainability.
- Updated the ParseContent method to work with the new MediaContent type instead of MediaMessage, enhancing the structure of content parsing.
- Refactored Gemini relay functions to utilize the new ToolCalls handling methods, streamlining the integration with OpenAI and Gemini systems.
This commit is contained in:
CalciumIon
2024-12-22 16:20:30 +08:00
parent 794f6a6e34
commit 0c326556aa
3 changed files with 78 additions and 74 deletions

View File

@@ -89,11 +89,27 @@ type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content json.RawMessage `json:"content"` Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"` ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`
} }
type MediaMessage struct { func (m Message) ParseToolCalls() []ToolCall {
if m.ToolCalls == nil {
return nil
}
var toolCalls []ToolCall
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
return toolCalls
}
func (m Message) SetToolCalls(toolCalls any) {
toolCallsJson, _ := json.Marshal(toolCalls)
m.ToolCalls = toolCallsJson
}
type MediaContent struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"` ImageUrl any `json:"image_url,omitempty"`
@@ -137,11 +153,11 @@ func (m Message) IsStringContent() bool {
return false return false
} }
func (m Message) ParseContent() []MediaMessage { func (m Message) ParseContent() []MediaContent {
var contentList []MediaMessage var contentList []MediaContent
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaContent{
Type: ContentTypeText, Type: ContentTypeText,
Text: stringContent, Text: stringContent,
}) })
@@ -157,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
switch contentMap["type"] { switch contentMap["type"] {
case ContentTypeText: case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok { if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaContent{
Type: ContentTypeText, Type: ContentTypeText,
Text: subStr, Text: subStr,
}) })
@@ -170,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
} else { } else {
subObj["detail"] = "high" subObj["detail"] = "high"
} }
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{ ImageUrl: MessageImageUrl{
Url: subObj["url"].(string), Url: subObj["url"].(string),
@@ -178,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
}, },
}) })
} else if url, ok := contentMap["image_url"].(string); ok { } else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{ ImageUrl: MessageImageUrl{
Url: url, Url: url,
@@ -188,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
} }
case ContentTypeInputAudio: case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok { if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio, Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{ InputAudio: MessageInputAudio{
Data: subObj["data"].(string), Data: subObj["data"].(string),

View File

@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
} }
if message.ToolCalls != nil { if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) { for _, toolCall := range message.ParseToolCalls() {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
inputObj := make(map[string]any) inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
} }
choice.SetStringContent(responseText) choice.SetStringContent(responseText)
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.ToolCalls = tools choice.Message.SetToolCalls(tools)
} }
fullTextResponse.Model = claudeResponse.Model fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice) choices = append(choices, choice)

View File

@@ -108,17 +108,29 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
continue continue
} else if message.Role == "tool" {
message.Role = "model"
} }
var parts []GeminiPart
content := GeminiChatContent{ content := GeminiChatContent{
Role: message.Role, Role: message.Role,
//Parts: []GeminiPart{
// {
// Text: message.StringContent(),
// },
//},
} }
isToolCall := false
if message.ToolCalls != nil {
isToolCall = true
for _, call := range message.ParseToolCalls() {
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: call.Function.Parameters,
},
}
parts = append(parts, toolCall)
}
}
if !isToolCall {
openaiContent := message.ParseContent() openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0 imageNum := 0
for _, part := range openaiContent { for _, part := range openaiContent {
if part.Type == dto.ContentTypeText { if part.Type == dto.ContentTypeText {
@@ -155,31 +167,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
} }
} }
} }
}
content.Parts = parts content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model // there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" { if content.Role == "assistant" {
content.Role = "model" content.Role = "model"
} }
// Converting system prompt to prompt from user for the same reason
//if content.Role == "system" {
// content.Role = "user"
// shouldAddDummyModelMessage = true
//}
geminiRequest.Contents = append(geminiRequest.Contents, content) geminiRequest.Contents = append(geminiRequest.Contents, content)
//
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
//if shouldAddDummyModelMessage {
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
// Role: "model",
// Parts: []GeminiPart{
// {
// Text: "Okay",
// },
// },
// })
// shouldAddDummyModelMessage = false
//}
} }
return &geminiRequest, nil return &geminiRequest, nil
} }
@@ -278,7 +273,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil { if candidate.Content.Parts[0].FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
choice.Message.ToolCalls = getToolCalls(&candidate) choice.Message.SetToolCalls(getToolCalls(&candidate))
} else { } else {
var texts []string var texts []string
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {