fix: mutil func call in gemini
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,3 +8,4 @@ build
|
|||||||
logs
|
logs
|
||||||
web/dist
|
web/dist
|
||||||
.env
|
.env
|
||||||
|
one-api
|
||||||
@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
|
|||||||
m := make(map[string]interface{})
|
m := make(map[string]interface{})
|
||||||
err := json.Unmarshal([]byte(str), &m)
|
err := json.Unmarshal([]byte(str), &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return map[string]interface{}{
|
||||||
|
"result": str,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
tool_call_ids := make(map[string]string)
|
||||||
//shouldAddDummyModelMessage := false
|
//shouldAddDummyModelMessage := false
|
||||||
for _, message := range textRequest.Messages {
|
for _, message := range textRequest.Messages {
|
||||||
|
|
||||||
@@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
} else if message.Role == "tool" {
|
||||||
|
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
|
||||||
|
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||||
|
Role: "user",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
|
||||||
|
name := ""
|
||||||
|
if message.Name != nil {
|
||||||
|
name = *message.Name
|
||||||
|
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
|
||||||
|
name = val
|
||||||
|
}
|
||||||
|
functionResp := &FunctionResponse{
|
||||||
|
Name: name,
|
||||||
|
Response: common.StrToMap(message.StringContent()),
|
||||||
|
}
|
||||||
|
*parts = append(*parts, GeminiPart{
|
||||||
|
FunctionResponse: functionResp,
|
||||||
|
})
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
content := GeminiChatContent{
|
content := GeminiChatContent{
|
||||||
@@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
parts = append(parts, toolCall)
|
parts = append(parts, toolCall)
|
||||||
|
tool_call_ids[call.ID] = call.Function.Name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !isToolCall {
|
if !isToolCall {
|
||||||
if message.Role == "tool" {
|
openaiContent := message.ParseContent()
|
||||||
content.Role = "user"
|
imageNum := 0
|
||||||
name := ""
|
for _, part := range openaiContent {
|
||||||
if message.Name != nil {
|
if part.Type == dto.ContentTypeText {
|
||||||
name = *message.Name
|
parts = append(parts, GeminiPart{
|
||||||
}
|
Text: part.Text,
|
||||||
functionResp := &FunctionResponse{
|
})
|
||||||
Name: name,
|
} else if part.Type == dto.ContentTypeImageURL {
|
||||||
Response: common.StrToMap(message.StringContent()),
|
imageNum += 1
|
||||||
}
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
FunctionResponse: functionResp,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
openaiContent := message.ParseContent()
|
|
||||||
imageNum := 0
|
|
||||||
for _, part := range openaiContent {
|
|
||||||
if part.Type == dto.ContentTypeText {
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
Text: part.Text,
|
|
||||||
})
|
|
||||||
} else if part.Type == dto.ContentTypeImageURL {
|
|
||||||
imageNum += 1
|
|
||||||
|
|
||||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||||
}
|
}
|
||||||
// 判断是否是url
|
// 判断是否是url
|
||||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||||
// 是url,获取图片的类型和base64编码的数据
|
// 是url,获取图片的类型和base64编码的数据
|
||||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &GeminiInlineData{
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
Data: data,
|
Data: data,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||||
}
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: "image/" + format,
|
|
||||||
Data: base64String,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: "image/" + format,
|
||||||
|
Data: base64String,
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content.Parts = parts
|
content.Parts = parts
|
||||||
|
|
||||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||||
@@ -242,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||||
var toolCalls []dto.ToolCall
|
|
||||||
|
|
||||||
item := candidate.Content.Parts[0]
|
|
||||||
if item.FunctionCall == nil {
|
|
||||||
return toolCalls
|
|
||||||
}
|
|
||||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//common.SysError("getToolCalls failed: " + err.Error())
|
//common.SysError("getToolCall failed: " + err.Error())
|
||||||
return toolCalls
|
return nil
|
||||||
}
|
}
|
||||||
toolCall := dto.ToolCall{
|
return &dto.ToolCall{
|
||||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||||
Type: "function",
|
Type: "function",
|
||||||
Function: dto.FunctionCall{
|
Function: dto.FunctionCall{
|
||||||
@@ -262,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
|||||||
Name: item.FunctionCall.FunctionName,
|
Name: item.FunctionCall.FunctionName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
toolCalls = append(toolCalls, toolCall)
|
|
||||||
return toolCalls
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
||||||
|
// var toolCalls []dto.ToolCall
|
||||||
|
|
||||||
|
// item := candidate.Content.Parts[index]
|
||||||
|
// if item.FunctionCall == nil {
|
||||||
|
// return toolCalls
|
||||||
|
// }
|
||||||
|
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||||
|
// if err != nil {
|
||||||
|
// //common.SysError("getToolCalls failed: " + err.Error())
|
||||||
|
// return toolCalls
|
||||||
|
// }
|
||||||
|
// toolCall := dto.ToolCall{
|
||||||
|
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||||
|
// Type: "function",
|
||||||
|
// Function: dto.FunctionCall{
|
||||||
|
// Arguments: string(argsBytes),
|
||||||
|
// Name: item.FunctionCall.FunctionName,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// toolCalls = append(toolCalls, toolCall)
|
||||||
|
// return toolCalls
|
||||||
|
// }
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
@@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
}
|
}
|
||||||
content, _ := json.Marshal("")
|
content, _ := json.Marshal("")
|
||||||
for i, candidate := range response.Candidates {
|
for i, candidate := range response.Candidates {
|
||||||
|
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
|
||||||
|
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
|
||||||
choice := dto.OpenAITextResponseChoice{
|
choice := dto.OpenAITextResponseChoice{
|
||||||
Index: i,
|
Index: i,
|
||||||
Message: dto.Message{
|
Message: dto.Message{
|
||||||
@@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
|||||||
FinishReason: constant.FinishReasonStop,
|
FinishReason: constant.FinishReasonStop,
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
var texts []string
|
||||||
choice.FinishReason = constant.FinishReasonToolCalls
|
var tool_calls []dto.ToolCall
|
||||||
choice.Message.SetToolCalls(getToolCalls(&candidate))
|
for _, part := range candidate.Content.Parts {
|
||||||
} else {
|
if part.FunctionCall != nil {
|
||||||
var texts []string
|
choice.FinishReason = constant.FinishReasonToolCalls
|
||||||
for _, part := range candidate.Content.Parts {
|
if call := getToolCall(&part); call != nil {
|
||||||
|
tool_calls = append(tool_calls, *call)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
texts = append(texts, part.Text)
|
texts = append(texts, part.Text)
|
||||||
}
|
}
|
||||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
|
||||||
}
|
}
|
||||||
|
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||||
|
choice.Message.SetToolCalls(tool_calls)
|
||||||
}
|
}
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
@@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
var choice dto.ChatCompletionsStreamResponseChoice
|
var choice dto.ChatCompletionsStreamResponseChoice
|
||||||
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
||||||
respFirstParts := geminiResponse.Candidates[0].Content.Parts
|
var texts []string
|
||||||
if respFirstParts[0].FunctionCall != nil {
|
var tool_calls []dto.ToolCall
|
||||||
// function response
|
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||||
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
|
if part.FunctionCall != nil {
|
||||||
} else {
|
if call := getToolCall(&part); call != nil {
|
||||||
// text response
|
tool_calls = append(tool_calls, *call)
|
||||||
var texts []string
|
}
|
||||||
for _, part := range respFirstParts {
|
} else {
|
||||||
texts = append(texts, part.Text)
|
texts = append(texts, part.Text)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if len(texts) > 0 {
|
||||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||||
}
|
}
|
||||||
|
if len(tool_calls) > 0 {
|
||||||
|
choice.Delta.ToolCalls = tool_calls
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var response dto.ChatCompletionsStreamResponse
|
var response dto.ChatCompletionsStreamResponse
|
||||||
response.Object = "chat.completion.chunk"
|
response.Object = "chat.completion.chunk"
|
||||||
|
|||||||
Reference in New Issue
Block a user