refactor: Update OpenAI request and message handling
- Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity.
- Introduced ParseToolCalls and SetToolCalls methods to handle ToolCalls more effectively, improving code readability and maintainability.
- Updated the ParseContent method to work with the new MediaContent type instead of MediaMessage, enhancing the structure of content parsing.
- Refactored Gemini relay functions to utilize the new ToolCalls handling methods, streamlining the integration with OpenAI and Gemini systems.
This commit is contained in:
@@ -22,7 +22,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
@@ -89,11 +89,27 @@ type Message struct {
|
|||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content json.RawMessage `json:"content"`
|
Content json.RawMessage `json:"content"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
ToolCalls any `json:"tool_calls,omitempty"`
|
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type MediaMessage struct {
|
func (m Message) ParseToolCalls() []ToolCall {
|
||||||
|
if m.ToolCalls == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) SetToolCalls(toolCalls any) {
|
||||||
|
toolCallsJson, _ := json.Marshal(toolCalls)
|
||||||
|
m.ToolCalls = toolCallsJson
|
||||||
|
}
|
||||||
|
|
||||||
|
type MediaContent struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
ImageUrl any `json:"image_url,omitempty"`
|
ImageUrl any `json:"image_url,omitempty"`
|
||||||
@@ -137,11 +153,11 @@ func (m Message) IsStringContent() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Message) ParseContent() []MediaMessage {
|
func (m Message) ParseContent() []MediaContent {
|
||||||
var contentList []MediaMessage
|
var contentList []MediaContent
|
||||||
var stringContent string
|
var stringContent string
|
||||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||||
contentList = append(contentList, MediaMessage{
|
contentList = append(contentList, MediaContent{
|
||||||
Type: ContentTypeText,
|
Type: ContentTypeText,
|
||||||
Text: stringContent,
|
Text: stringContent,
|
||||||
})
|
})
|
||||||
@@ -157,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
|
|||||||
switch contentMap["type"] {
|
switch contentMap["type"] {
|
||||||
case ContentTypeText:
|
case ContentTypeText:
|
||||||
if subStr, ok := contentMap["text"].(string); ok {
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
contentList = append(contentList, MediaMessage{
|
contentList = append(contentList, MediaContent{
|
||||||
Type: ContentTypeText,
|
Type: ContentTypeText,
|
||||||
Text: subStr,
|
Text: subStr,
|
||||||
})
|
})
|
||||||
@@ -170,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
|
|||||||
} else {
|
} else {
|
||||||
subObj["detail"] = "high"
|
subObj["detail"] = "high"
|
||||||
}
|
}
|
||||||
contentList = append(contentList, MediaMessage{
|
contentList = append(contentList, MediaContent{
|
||||||
Type: ContentTypeImageURL,
|
Type: ContentTypeImageURL,
|
||||||
ImageUrl: MessageImageUrl{
|
ImageUrl: MessageImageUrl{
|
||||||
Url: subObj["url"].(string),
|
Url: subObj["url"].(string),
|
||||||
@@ -178,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else if url, ok := contentMap["image_url"].(string); ok {
|
} else if url, ok := contentMap["image_url"].(string); ok {
|
||||||
contentList = append(contentList, MediaMessage{
|
contentList = append(contentList, MediaContent{
|
||||||
Type: ContentTypeImageURL,
|
Type: ContentTypeImageURL,
|
||||||
ImageUrl: MessageImageUrl{
|
ImageUrl: MessageImageUrl{
|
||||||
Url: url,
|
Url: url,
|
||||||
@@ -188,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
|
|||||||
}
|
}
|
||||||
case ContentTypeInputAudio:
|
case ContentTypeInputAudio:
|
||||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||||
contentList = append(contentList, MediaMessage{
|
contentList = append(contentList, MediaContent{
|
||||||
Type: ContentTypeInputAudio,
|
Type: ContentTypeInputAudio,
|
||||||
InputAudio: MessageInputAudio{
|
InputAudio: MessageInputAudio{
|
||||||
Data: subObj["data"].(string),
|
Data: subObj["data"].(string),
|
||||||
|
|||||||
@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
|||||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||||
}
|
}
|
||||||
if message.ToolCalls != nil {
|
if message.ToolCalls != nil {
|
||||||
for _, tc := range message.ToolCalls.([]interface{}) {
|
for _, toolCall := range message.ParseToolCalls() {
|
||||||
toolCallJSON, _ := json.Marshal(tc)
|
|
||||||
var toolCall dto.ToolCall
|
|
||||||
err := json.Unmarshal(toolCallJSON, &toolCall)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inputObj := make(map[string]any)
|
inputObj := make(map[string]any)
|
||||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||||
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
|||||||
}
|
}
|
||||||
choice.SetStringContent(responseText)
|
choice.SetStringContent(responseText)
|
||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
choice.Message.ToolCalls = tools
|
choice.Message.SetToolCalls(tools)
|
||||||
}
|
}
|
||||||
fullTextResponse.Model = claudeResponse.Model
|
fullTextResponse.Model = claudeResponse.Model
|
||||||
choices = append(choices, choice)
|
choices = append(choices, choice)
|
||||||
|
|||||||
@@ -108,50 +108,63 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
} else if message.Role == "tool" {
|
||||||
|
message.Role = "model"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var parts []GeminiPart
|
||||||
content := GeminiChatContent{
|
content := GeminiChatContent{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
//Parts: []GeminiPart{
|
|
||||||
// {
|
|
||||||
// Text: message.StringContent(),
|
|
||||||
// },
|
|
||||||
//},
|
|
||||||
}
|
}
|
||||||
openaiContent := message.ParseContent()
|
isToolCall := false
|
||||||
var parts []GeminiPart
|
if message.ToolCalls != nil {
|
||||||
imageNum := 0
|
isToolCall = true
|
||||||
for _, part := range openaiContent {
|
for _, call := range message.ParseToolCalls() {
|
||||||
if part.Type == dto.ContentTypeText {
|
toolCall := GeminiPart{
|
||||||
parts = append(parts, GeminiPart{
|
FunctionCall: &FunctionCall{
|
||||||
Text: part.Text,
|
FunctionName: call.Function.Name,
|
||||||
})
|
Arguments: call.Function.Parameters,
|
||||||
} else if part.Type == dto.ContentTypeImageURL {
|
},
|
||||||
imageNum += 1
|
|
||||||
|
|
||||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
|
||||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
|
||||||
}
|
}
|
||||||
// 判断是否是url
|
parts = append(parts, toolCall)
|
||||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
}
|
||||||
// 是url,获取图片的类型和base64编码的数据
|
}
|
||||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
if !isToolCall {
|
||||||
|
openaiContent := message.ParseContent()
|
||||||
|
imageNum := 0
|
||||||
|
for _, part := range openaiContent {
|
||||||
|
if part.Type == dto.ContentTypeText {
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
Text: part.Text,
|
||||||
MimeType: mimeType,
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
} else {
|
} else if part.Type == dto.ContentTypeImageURL {
|
||||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
imageNum += 1
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||||
|
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||||
|
}
|
||||||
|
// 判断是否是url
|
||||||
|
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||||
|
// 是url,获取图片的类型和base64编码的数据
|
||||||
|
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||||
|
}
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: "image/" + format,
|
||||||
|
Data: base64String,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: "image/" + format,
|
|
||||||
Data: base64String,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -161,25 +174,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
if content.Role == "assistant" {
|
if content.Role == "assistant" {
|
||||||
content.Role = "model"
|
content.Role = "model"
|
||||||
}
|
}
|
||||||
// Converting system prompt to prompt from user for the same reason
|
|
||||||
//if content.Role == "system" {
|
|
||||||
// content.Role = "user"
|
|
||||||
// shouldAddDummyModelMessage = true
|
|
||||||
//}
|
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||||
//
|
|
||||||
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
|
||||||
//if shouldAddDummyModelMessage {
|
|
||||||
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
|
||||||
// Role: "model",
|
|
||||||
// Parts: []GeminiPart{
|
|
||||||
// {
|
|
||||||
// Text: "Okay",
|
|
||||||
// },
|
|
||||||
// },
|
|
||||||
// })
|
|
||||||
// shouldAddDummyModelMessage = false
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
return &geminiRequest, nil
|
return &geminiRequest, nil
|
||||||
}
|
}
|
||||||
@@ -278,7 +273,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||||
choice.FinishReason = constant.FinishReasonToolCalls
|
choice.FinishReason = constant.FinishReasonToolCalls
|
||||||
choice.Message.ToolCalls = getToolCalls(&candidate)
|
choice.Message.SetToolCalls(getToolCalls(&candidate))
|
||||||
} else {
|
} else {
|
||||||
var texts []string
|
var texts []string
|
||||||
for _, part := range candidate.Content.Parts {
|
for _, part := range candidate.Content.Parts {
|
||||||
|
|||||||
Reference in New Issue
Block a user