fix: mutil func call in gemini

This commit is contained in:
Yan
2024-12-23 01:26:14 +08:00
parent eec8f523ce
commit a4795737fe
3 changed files with 114 additions and 76 deletions

1
.gitignore vendored
View File

@@ -8,3 +8,4 @@ build
logs logs
web/dist web/dist
.env .env
one-api

View File

@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := json.Unmarshal([]byte(str), &m)
if err != nil { if err != nil {
return nil return map[string]interface{}{
"result": str,
}
} }
return m return m
} }

View File

@@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
} }
} }
tool_call_ids := make(map[string]string)
//shouldAddDummyModelMessage := false //shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
@@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
continue continue
} else if message.Role == "tool" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
}
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
name := ""
if message.Name != nil {
name = *message.Name
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
functionResp := &FunctionResponse{
Name: name,
Response: common.StrToMap(message.StringContent()),
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
})
continue
} }
var parts []GeminiPart var parts []GeminiPart
content := GeminiChatContent{ content := GeminiChatContent{
@@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}, },
} }
parts = append(parts, toolCall) parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
} }
} }
if !isToolCall { if !isToolCall {
if message.Role == "tool" { openaiContent := message.ParseContent()
content.Role = "user" imageNum := 0
name := "" for _, part := range openaiContent {
if message.Name != nil { if part.Type == dto.ContentTypeText {
name = *message.Name parts = append(parts, GeminiPart{
} Text: part.Text,
functionResp := &FunctionResponse{ })
Name: name, } else if part.Type == dto.ContentTypeImageURL {
Response: common.StrToMap(message.StringContent()), imageNum += 1
}
parts = append(parts, GeminiPart{
FunctionResponse: functionResp,
})
} else {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
} }
// 判断是否是url // 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,
Data: data, Data: data,
}, },
}) })
} else { } else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil { if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
} }
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
} }
} }
} }
} }
content.Parts = parts content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model // there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -242,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
return "" return ""
} }
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall { func getToolCall(item *GeminiPart) *dto.ToolCall {
var toolCalls []dto.ToolCall
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments) argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil { if err != nil {
//common.SysError("getToolCalls failed: " + err.Error()) //common.SysError("getToolCall failed: " + err.Error())
return toolCalls return nil
} }
toolCall := dto.ToolCall{ return &dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()), ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function", Type: "function",
Function: dto.FunctionCall{ Function: dto.FunctionCall{
@@ -262,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
Name: item.FunctionCall.FunctionName, Name: item.FunctionCall.FunctionName,
}, },
} }
toolCalls = append(toolCalls, toolCall)
return toolCalls
} }
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
// var toolCalls []dto.ToolCall
// item := candidate.Content.Parts[index]
// if item.FunctionCall == nil {
// return toolCalls
// }
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
// if err != nil {
// //common.SysError("getToolCalls failed: " + err.Error())
// return toolCalls
// }
// toolCall := dto.ToolCall{
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
// Type: "function",
// Function: dto.FunctionCall{
// Arguments: string(argsBytes),
// Name: item.FunctionCall.FunctionName,
// },
// }
// toolCalls = append(toolCalls, toolCall)
// return toolCalls
// }
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
} }
content, _ := json.Marshal("") content, _ := json.Marshal("")
for i, candidate := range response.Candidates { for i, candidate := range response.Candidates {
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: i, Index: i,
Message: dto.Message{ Message: dto.Message{
@@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: constant.FinishReasonStop, FinishReason: constant.FinishReasonStop,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil { var texts []string
choice.FinishReason = constant.FinishReasonToolCalls var tool_calls []dto.ToolCall
choice.Message.SetToolCalls(getToolCalls(&candidate)) for _, part := range candidate.Content.Parts {
} else { if part.FunctionCall != nil {
var texts []string choice.FinishReason = constant.FinishReasonToolCalls
for _, part := range candidate.Content.Parts { if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text) texts = append(texts, part.Text)
} }
choice.Message.SetStringContent(strings.Join(texts, "\n"))
} }
choice.Message.SetStringContent(strings.Join(texts, "\n"))
choice.Message.SetToolCalls(tool_calls)
} }
fullTextResponse.Choices = append(fullTextResponse.Choices, choice) fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
} }
@@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText()) //choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirstParts := geminiResponse.Candidates[0].Content.Parts var texts []string
if respFirstParts[0].FunctionCall != nil { var tool_calls []dto.ToolCall
// function response for _, part := range geminiResponse.Candidates[0].Content.Parts {
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0]) if part.FunctionCall != nil {
} else { if call := getToolCall(&part); call != nil {
// text response tool_calls = append(tool_calls, *call)
var texts []string }
for _, part := range respFirstParts { } else {
texts = append(texts, part.Text) texts = append(texts, part.Text)
} }
}
if len(texts) > 0 {
choice.Delta.SetContentString(strings.Join(texts, "\n")) choice.Delta.SetContentString(strings.Join(texts, "\n"))
} }
if len(tool_calls) > 0 {
choice.Delta.ToolCalls = tool_calls
}
} }
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"