diff --git a/common/str.go b/common/str.go index e3834694..d42fd837 100644 --- a/common/str.go +++ b/common/str.go @@ -35,9 +35,7 @@ func StrToMap(str string) map[string]interface{} { m := make(map[string]interface{}) err := json.Unmarshal([]byte(str), &m) if err != nil { - return map[string]interface{}{ - "result": str, - } + return nil } return m } diff --git a/constant/finish_reason.go b/constant/finish_reason.go index 8d6289a6..5a752a5f 100644 --- a/constant/finish_reason.go +++ b/constant/finish_reason.go @@ -1,6 +1,9 @@ package constant var ( - FinishReasonStop = "stop" - FinishReasonToolCalls = "tool_calls" + FinishReasonStop = "stop" + FinishReasonToolCalls = "tool_calls" + FinishReasonLength = "length" + FinishReasonFunctionCall = "function_call" + FinishReasonContentFilter = "content_filter" ) diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 727446c5..6ab002ef 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -4,7 +4,7 @@ type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` - Tools []GeminiChatTools `json:"tools,omitempty"` + Tools []GeminiChatTool `json:"tools,omitempty"` SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"` } @@ -18,16 +18,39 @@ type FunctionCall struct { Arguments any `json:"args"` } +type GeminiFunctionResponseContent struct { + Name string `json:"name"` + Content any `json:"content"` +} + type FunctionResponse struct { - Name string `json:"name"` - Response any `json:"response"` + Name string `json:"name"` + Response GeminiFunctionResponseContent `json:"response"` +} + +type GeminiPartExecutableCode struct { + Language string `json:"language,omitempty"` + Code string `json:"code,omitempty"` +} + +type GeminiPartCodeExecutionResult struct { + Outcome string `json:"outcome,omitempty"` + Output string `json:"output,omitempty"` +} + +type GeminiFileData struct { + MimeType string `json:"mimeType,omitempty"` + FileUri string `json:"fileUri,omitempty"` } type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` - FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` + FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` + FileData *GeminiFileData `json:"fileData,omitempty"` + ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"` + CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"` } type GeminiChatContent struct { @@ -40,9 +63,11 @@ type GeminiChatSafetySettings struct { Threshold string `json:"threshold"` } -type GeminiChatTools struct { - GoogleSearch any `json:"googleSearch,omitempty"` - FunctionDeclarations any `json:"functionDeclarations,omitempty"` +type GeminiChatTool struct { + GoogleSearch any `json:"googleSearch,omitempty"` + GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"` + CodeExecution any `json:"codeExecution,omitempty"` + FunctionDeclarations any `json:"functionDeclarations,omitempty"` } type GeminiChatGenerationConfig struct { @@ -54,11 +79,12 @@ type GeminiChatGenerationConfig struct { StopSequences []string `json:"stopSequences,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseSchema any `json:"responseSchema,omitempty"` + Seed int64 `json:"seed,omitempty"` } type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` - FinishReason string `json:"finishReason"` + FinishReason *string `json:"finishReason"` Index int64 `json:"index"` SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 4c12812a..eeab7927 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -18,6 +18,7 @@ import ( // Setting safety to the lowest possible values since Gemini is already powerless enough func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) { + geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), SafetySettings: []GeminiChatSafetySettings{ @@ -46,16 +47,24 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque Temperature: textRequest.Temperature, TopP: textRequest.TopP, MaxOutputTokens: textRequest.MaxTokens, + Seed: int64(textRequest.Seed), }, } + + // openaiContent.FuncToToolCalls() if textRequest.Tools != nil { functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) googleSearch := false + codeExecution := false for _, tool := range textRequest.Tools { if tool.Function.Name == "googleSearch" { googleSearch = true continue } + if tool.Function.Name == "codeExecution" { + codeExecution = true + continue + } if tool.Function.Parameters != nil { params, ok := tool.Function.Parameters.(map[string]interface{}) if ok { @@ -68,25 +77,32 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } functions = append(functions, tool.Function) } - if len(functions) > 0 { - geminiRequest.Tools = []GeminiChatTools{ - { - FunctionDeclarations: functions, - }, - } + if codeExecution { + geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + CodeExecution: make(map[string]string), + }) } if googleSearch { - geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{ + geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ GoogleSearch: make(map[string]string), }) } + if len(functions) > 0 { + geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ + FunctionDeclarations: functions, + }) + } + // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools)) + // json_data, _ := json.Marshal(geminiRequest.Tools) + // common.SysLog("tools_json: " + string(json_data)) } else if textRequest.Functions != nil { - geminiRequest.Tools = []GeminiChatTools{ + geminiRequest.Tools = []GeminiChatTool{ { FunctionDeclarations: textRequest.Functions, }, } } + if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { geminiRequest.GenerationConfig.ResponseMimeType = "application/json" @@ -96,20 +112,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } } tool_call_ids := make(map[string]string) + var system_content []string //shouldAddDummyModelMessage := false for _, message := range textRequest.Messages { - if message.Role == "system" { - geminiRequest.SystemInstructions = &GeminiChatContent{ - Parts: []GeminiPart{ - { - Text: message.StringContent(), - }, - }, - } + system_content = append(system_content, message.StringContent()) continue - } else if message.Role == "tool" { - if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" { + } else if message.Role == "tool" || message.Role == "function" { + if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" { geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ Role: "user", }) @@ -121,9 +131,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } else if val, exists := tool_call_ids[message.ToolCallId]; exists { name = val } + content := common.StrToMap(message.StringContent()) functionResp := &FunctionResponse{ - Name: name, - Response: common.StrToMap(message.StringContent()), + Name: name, + Response: GeminiFunctionResponseContent{ + Name: name, + Content: content, + }, + } + if content == nil { + functionResp.Response.Content = message.StringContent() } *parts = append(*parts, GeminiPart{ FunctionResponse: functionResp, @@ -134,57 +151,65 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque content := GeminiChatContent{ Role: message.Role, } - isToolCall := false + // isToolCall := false if message.ToolCalls != nil { - message.Role = "model" - isToolCall = true + // message.Role = "model" + // isToolCall = true for _, call := range message.ParseToolCalls() { + args := map[string]interface{}{} + if call.Function.Arguments != "" { + if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil { + return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments) + } + } toolCall := GeminiPart{ FunctionCall: &FunctionCall{ FunctionName: call.Function.Name, - Arguments: call.Function.Parameters, + Arguments: args, }, } parts = append(parts, toolCall) tool_call_ids[call.ID] = call.Function.Name } } - if !isToolCall { - openaiContent := message.ParseContent() - imageNum := 0 - for _, part := range openaiContent { - if part.Type == dto.ContentTypeText { - parts = append(parts, GeminiPart{ - Text: part.Text, - }) - } else if part.Type == dto.ContentTypeImageURL { - imageNum += 1 - if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { - return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) - } - // 判断是否是url - if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { - // 是url,获取图片的类型和base64编码的数据 - mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, - }, - }) - } else { - _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) - if err != nil { - return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) - } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: "image/" + format, - Data: base64String, - }, - }) + openaiContent := message.ParseContent() + imageNum := 0 + for _, part := range openaiContent { + if part.Type == dto.ContentTypeText { + if part.Text == "" { + continue + } + parts = append(parts, GeminiPart{ + Text: part.Text, + }) + } else if part.Type == dto.ContentTypeImageURL { + imageNum += 1 + + if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum { + return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum) + } + // 判断是否是url + if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { + // 是url,获取图片的类型和base64编码的数据 + mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } else { + _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) + if err != nil { + return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) } + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: "image/" + format, + Data: base64String, + }, + }) } } } @@ -197,6 +222,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque } geminiRequest.Contents = append(geminiRequest.Contents, content) } + + if len(system_content) > 0 { + geminiRequest.SystemInstructions = &GeminiChatContent{ + Parts: []GeminiPart{ + { + Text: strings.Join(system_content, "\n"), + }, + }, + } + } + return &geminiRequest, nil } @@ -240,15 +276,15 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac return v } -func (g *GeminiChatResponse) GetResponseText() string { - if g == nil { - return "" - } - if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { - return g.Candidates[0].Content.Parts[0].Text - } - return "" -} +// func (g *GeminiChatResponse) GetResponseText() string { +// if g == nil { +// return "" +// } +// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { +// return g.Candidates[0].Content.Parts[0].Text +// } +// return "" +// } func getToolCall(item *GeminiPart) *dto.ToolCall { argsBytes, err := json.Marshal(item.FunctionCall.Arguments) @@ -298,11 +334,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), } content, _ := json.Marshal("") - for i, candidate := range response.Candidates { - // jsonData, _ := json.MarshalIndent(candidate, "", " ") - // common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData))) + is_tool_call := false + for _, candidate := range response.Candidates { choice := dto.OpenAITextResponseChoice{ - Index: i, + Index: int(candidate.Index), Message: dto.Message{ Role: "assistant", Content: content, @@ -319,48 +354,107 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp tool_calls = append(tool_calls, *call) } } else { - texts = append(texts, part.Text) + if part.ExecutableCode != nil { + texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```") + } else if part.CodeExecutionResult != nil { + texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```") + } else { + // 过滤掉空行 + if part.Text != "\n" { + texts = append(texts, part.Text) + } + } } } + if len(tool_calls) > 0 { + choice.Message.SetToolCalls(tool_calls) + is_tool_call = true + } + // 过滤掉空行 + choice.Message.SetStringContent(strings.Join(texts, "\n")) - choice.Message.SetToolCalls(tool_calls) + } + if candidate.FinishReason != nil { + switch *candidate.FinishReason { + case "STOP": + choice.FinishReason = constant.FinishReasonStop + case "MAX_TOKENS": + choice.FinishReason = constant.FinishReasonLength + default: + choice.FinishReason = constant.FinishReasonContentFilter + } + } + if is_tool_call { + choice.FinishReason = constant.FinishReasonToolCalls + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { - var choice dto.ChatCompletionsStreamResponseChoice - //choice.Delta.SetContentString(geminiResponse.GetResponseText()) - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) { + choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) + is_stop := false + for _, candidate := range geminiResponse.Candidates { + if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { + is_stop = true + candidate.FinishReason = nil + } + choice := dto.ChatCompletionsStreamResponseChoice{ + Index: int(candidate.Index), + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ + Role: "assistant", + }, + } var texts []string - var tool_calls []dto.ToolCall - for _, part := range geminiResponse.Candidates[0].Content.Parts { - if part.FunctionCall != nil { - if call := getToolCall(&part); call != nil { - tool_calls = append(tool_calls, *call) - } - } else { - texts = append(texts, part.Text) + isTools := false + if candidate.FinishReason != nil { + // p := GeminiConvertFinishReason(*candidate.FinishReason) + switch *candidate.FinishReason { + case "STOP": + choice.FinishReason = &constant.FinishReasonStop + case "MAX_TOKENS": + choice.FinishReason = &constant.FinishReasonLength + default: + choice.FinishReason = &constant.FinishReasonContentFilter } } - if len(texts) > 0 { - choice.Delta.SetContentString(strings.Join(texts, "\n")) + for _, part := range candidate.Content.Parts { + if part.FunctionCall != nil { + isTools = true + if call := getToolCall(&part); call != nil { + choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) + } + } else { + if part.ExecutableCode != nil { + texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n") + } else if part.CodeExecutionResult != nil { + texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n") + } else { + if part.Text != "\n" { + texts = append(texts, part.Text) + } + } + } } - if len(tool_calls) > 0 { - choice.Delta.ToolCalls = tool_calls + choice.Delta.SetContentString(strings.Join(texts, "\n")) + if isTools { + choice.FinishReason = &constant.FinishReasonToolCalls } + choices = append(choices, choice) } + var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "gemini" - response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} - return &response + response.Choices = choices + return &response, is_stop } func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseText := "" + // responseText := "" id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() var usage = &dto.Usage{} @@ -384,14 +478,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom continue } - response := streamResponseGeminiChat2OpenAI(&geminiResponse) - if response == nil { - continue - } + response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse) response.Id = id response.Created = createAt response.Model = info.UpstreamModelName - responseText += response.Choices[0].Delta.GetContentString() + // responseText += response.Choices[0].Delta.GetContentString() if geminiResponse.UsageMetadata.TotalTokenCount != 0 { usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount @@ -400,12 +491,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom if err != nil { common.LogError(c, err.Error()) } + if is_stop { + response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) + service.ObjectData(c, response) + } } - response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop) - service.ObjectData(c, response) + var response *dto.ChatCompletionsStreamResponse usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens if info.ShouldIncludeUsage { response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)