feat: Refactor Gemini tools handling to support JSON raw message format

This commit is contained in:
CaIon
2025-08-11 19:48:04 +08:00
parent 03cfc05afd
commit d3170310ff
3 changed files with 44 additions and 9 deletions

View File

@@ -3,16 +3,52 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"one-api/common" "one-api/common"
"strings"
) )
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"` Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"` Tools json.RawMessage `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
} }
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
var tools []GeminiChatTool
if strings.HasSuffix(string(r.Tools), "[") {
// is array
if err := common.Unmarshal(r.Tools, &tools); err != nil {
common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
return nil
}
} else if strings.HasPrefix(string(r.Tools), "{") {
// is object
singleTool := GeminiChatTool{}
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
return nil
}
tools = []GeminiChatTool{singleTool}
}
return tools
}
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
if len(tools) == 0 {
r.Tools = json.RawMessage("[]")
return
}
// Marshal the tools to JSON
data, err := common.Marshal(tools)
if err != nil {
common.LogError(nil, "error_marshalling_tools: "+err.Error())
return
}
r.Tools = data
}
type GeminiThinkingConfig struct { type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"` IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"` ThinkingBudget *int `json:"thinkingBudget,omitempty"`

View File

@@ -267,24 +267,23 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
tool.Function.Parameters = cleanedParams tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function) functions = append(functions, tool.Function)
} }
geminiTools := geminiRequest.GetTools()
if codeExecution { if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ geminiTools = append(geminiTools, dto.GeminiChatTool{
CodeExecution: make(map[string]string), CodeExecution: make(map[string]string),
}) })
} }
if googleSearch { if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ geminiTools = append(geminiTools, dto.GeminiChatTool{
GoogleSearch: make(map[string]string), GoogleSearch: make(map[string]string),
}) })
} }
if len(functions) > 0 { if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{ geminiTools = append(geminiTools, dto.GeminiChatTool{
FunctionDeclarations: functions, FunctionDeclarations: functions,
}) })
} }
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools)) geminiRequest.SetTools(geminiTools)
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} }
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") { if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {

View File

@@ -569,9 +569,9 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
} }
// 转换工具调用 // 转换工具调用
if len(geminiRequest.Tools) > 0 { if len(geminiRequest.GetTools()) > 0 {
var tools []dto.ToolCallRequest var tools []dto.ToolCallRequest
for _, tool := range geminiRequest.Tools { for _, tool := range geminiRequest.GetTools() {
if tool.FunctionDeclarations != nil { if tool.FunctionDeclarations != nil {
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest // 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest) functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)