✨ feat: Refactor Gemini tools handling to support JSON raw message format
This commit is contained in:
@@ -3,16 +3,52 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||
var tools []GeminiChatTool
|
||||
if strings.HasSuffix(string(r.Tools), "[") {
|
||||
// is array
|
||||
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||
common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||
// is object
|
||||
singleTool := GeminiChatTool{}
|
||||
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||
common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
tools = []GeminiChatTool{singleTool}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||
if len(tools) == 0 {
|
||||
r.Tools = json.RawMessage("[]")
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal the tools to JSON
|
||||
data, err := common.Marshal(tools)
|
||||
if err != nil {
|
||||
common.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||
return
|
||||
}
|
||||
r.Tools = data
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
|
||||
@@ -267,24 +267,23 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
tool.Function.Parameters = cleanedParams
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
geminiTools := geminiRequest.GetTools()
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
geminiRequest.SetTools(geminiTools)
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
|
||||
@@ -569,9 +569,9 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
|
||||
}
|
||||
|
||||
// 转换工具调用
|
||||
if len(geminiRequest.Tools) > 0 {
|
||||
if len(geminiRequest.GetTools()) > 0 {
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.Tools {
|
||||
for _, tool := range geminiRequest.GetTools() {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
|
||||
Reference in New Issue
Block a user