✨ feat: Refactor Gemini tools handling to support JSON raw message format
This commit is contained in:
@@ -3,16 +3,52 @@ package dto
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GeminiChatRequest struct {
|
type GeminiChatRequest struct {
|
||||||
Contents []GeminiChatContent `json:"contents"`
|
Contents []GeminiChatContent `json:"contents"`
|
||||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
Tools json.RawMessage `json:"tools,omitempty"`
|
||||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||||
|
var tools []GeminiChatTool
|
||||||
|
if strings.HasSuffix(string(r.Tools), "[") {
|
||||||
|
// is array
|
||||||
|
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||||
|
common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||||
|
// is object
|
||||||
|
singleTool := GeminiChatTool{}
|
||||||
|
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||||
|
common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tools = []GeminiChatTool{singleTool}
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
r.Tools = json.RawMessage("[]")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the tools to JSON
|
||||||
|
data, err := common.Marshal(tools)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Tools = data
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiThinkingConfig struct {
|
type GeminiThinkingConfig struct {
|
||||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
|
|||||||
@@ -267,24 +267,23 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
tool.Function.Parameters = cleanedParams
|
tool.Function.Parameters = cleanedParams
|
||||||
functions = append(functions, tool.Function)
|
functions = append(functions, tool.Function)
|
||||||
}
|
}
|
||||||
|
geminiTools := geminiRequest.GetTools()
|
||||||
if codeExecution {
|
if codeExecution {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||||
CodeExecution: make(map[string]string),
|
CodeExecution: make(map[string]string),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if googleSearch {
|
if googleSearch {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||||
GoogleSearch: make(map[string]string),
|
GoogleSearch: make(map[string]string),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(functions) > 0 {
|
if len(functions) > 0 {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||||
FunctionDeclarations: functions,
|
FunctionDeclarations: functions,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
geminiRequest.SetTools(geminiTools)
|
||||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
|
||||||
// common.SysLog("tools_json: " + string(json_data))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||||
|
|||||||
@@ -569,9 +569,9 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换工具调用
|
// 转换工具调用
|
||||||
if len(geminiRequest.Tools) > 0 {
|
if len(geminiRequest.GetTools()) > 0 {
|
||||||
var tools []dto.ToolCallRequest
|
var tools []dto.ToolCallRequest
|
||||||
for _, tool := range geminiRequest.Tools {
|
for _, tool := range geminiRequest.GetTools() {
|
||||||
if tool.FunctionDeclarations != nil {
|
if tool.FunctionDeclarations != nil {
|
||||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||||
|
|||||||
Reference in New Issue
Block a user