refactor: Update OpenAI request and message handling

- Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity.
- Introduced ParseToolCalls and SetToolCalls methods to handle ToolCalls more effectively, improving code readability and maintainability.
- Updated the ParseContent method to work with the new MediaContent type instead of MediaMessage, enhancing the structure of content parsing.
- Refactored Gemini relay functions to utilize the new ToolCalls handling methods, streamlining the integration with OpenAI and Gemini systems.
This commit is contained in:
CalciumIon
2024-12-22 16:20:30 +08:00
parent 794f6a6e34
commit 0c326556aa
3 changed files with 78 additions and 74 deletions

View File

@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
}
choice.SetStringContent(responseText)
if len(tools) > 0 {
choice.Message.ToolCalls = tools
choice.Message.SetToolCalls(tools)
}
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)

View File

@@ -108,50 +108,63 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
}
continue
} else if message.Role == "tool" {
message.Role = "model"
}
var parts []GeminiPart
content := GeminiChatContent{
Role: message.Role,
//Parts: []GeminiPart{
// {
// Text: message.StringContent(),
// },
//},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
isToolCall := false
if message.ToolCalls != nil {
isToolCall = true
for _, call := range message.ParseToolCalls() {
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: call.Function.Parameters,
},
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, toolCall)
}
}
if !isToolCall {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
Text: part.Text,
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
@@ -161,25 +174,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
//if content.Role == "system" {
// content.Role = "user"
// shouldAddDummyModelMessage = true
//}
geminiRequest.Contents = append(geminiRequest.Contents, content)
//
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
//if shouldAddDummyModelMessage {
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
// Role: "model",
// Parts: []GeminiPart{
// {
// Text: "Okay",
// },
// },
// })
// shouldAddDummyModelMessage = false
//}
}
return &geminiRequest, nil
}
@@ -278,7 +273,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
choice.Message.ToolCalls = getToolCalls(&candidate)
choice.Message.SetToolCalls(getToolCalls(&candidate))
} else {
var texts []string
for _, part := range candidate.Content.Parts {