Merge pull request #1418 from YanzheL/fix/1161-gemini-google-search-grounding
fix(gemini): preserve google search grounding tools
This commit is contained in:
@@ -730,13 +730,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
})
|
||||
}
|
||||
|
||||
if len(funcDecls) == 0 {
|
||||
if !hasWebSearch {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
var declarations []GeminiToolDeclaration
|
||||
if len(funcDecls) > 0 {
|
||||
declarations = append(declarations, GeminiToolDeclaration{
|
||||
FunctionDeclarations: funcDecls,
|
||||
})
|
||||
}
|
||||
if hasWebSearch {
|
||||
declarations = append(declarations, GeminiToolDeclaration{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
@@ -744,10 +745,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
})
|
||||
}
|
||||
if len(declarations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []GeminiToolDeclaration{{
|
||||
FunctionDeclarations: funcDecls,
|
||||
}}
|
||||
return declarations
|
||||
}
|
||||
|
||||
@@ -263,6 +263,29 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
|
||||
tools := []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
},
|
||||
}
|
||||
|
||||
result := buildTools(tools)
|
||||
require.Len(t, result, 2)
|
||||
require.Len(t, result[0].FunctionDeclarations, 1)
|
||||
require.Equal(t, "get_weather", result[0].FunctionDeclarations[0].Name)
|
||||
require.NotNil(t, result[1].GoogleSearch)
|
||||
require.NotNil(t, result[1].GoogleSearch.EnhancedContent)
|
||||
require.NotNil(t, result[1].GoogleSearch.EnhancedContent.ImageSearch)
|
||||
require.Equal(t, 5, result[1].GoogleSearch.EnhancedContent.ImageSearch.MaxResultCount)
|
||||
}
|
||||
|
||||
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -400,3 +423,36 @@ func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions(t *testing.T) {
|
||||
claudeReq := &ClaudeRequest{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Messages: []ClaudeMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
|
||||
},
|
||||
},
|
||||
Tools: []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "web_search_20250305",
|
||||
Name: "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
var req V1InternalRequest
|
||||
require.NoError(t, json.Unmarshal(body, &req))
|
||||
require.Len(t, req.Request.Tools, 2)
|
||||
require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1)
|
||||
require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name)
|
||||
require.NotNil(t, req.Request.Tools[1].GoogleSearch)
|
||||
}
|
||||
|
||||
@@ -612,7 +612,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
||||
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -685,7 +686,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
||||
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -3184,12 +3186,17 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
return nil
|
||||
}
|
||||
|
||||
hasWebSearch := false
|
||||
funcDecls := make([]any, 0, len(arr))
|
||||
for _, t := range arr {
|
||||
tm, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isClaudeWebSearchToolMap(tm) {
|
||||
hasWebSearch = true
|
||||
continue
|
||||
}
|
||||
|
||||
var name, desc string
|
||||
var params any
|
||||
@@ -3233,13 +3240,75 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
})
|
||||
}
|
||||
|
||||
if len(funcDecls) == 0 {
|
||||
out := make([]any, 0, 2)
|
||||
if len(funcDecls) > 0 {
|
||||
out = append(out, map[string]any{
|
||||
"functionDeclarations": funcDecls,
|
||||
})
|
||||
}
|
||||
if hasWebSearch {
|
||||
out = append(out, map[string]any{
|
||||
"googleSearch": map[string]any{},
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return []any{
|
||||
map[string]any{
|
||||
"functionDeclarations": funcDecls,
|
||||
},
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeGeminiRequestForAIStudio(body []byte) []byte {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
tools, ok := payload["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, rawTool := range tools {
|
||||
tool, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
googleSearch, ok := tool["googleSearch"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := tool["google_search"]; exists {
|
||||
continue
|
||||
}
|
||||
tool["google_search"] = googleSearch
|
||||
delete(tool, "googleSearch")
|
||||
modified = true
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body
|
||||
}
|
||||
|
||||
normalized, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func isClaudeWebSearchToolMap(tool map[string]any) bool {
|
||||
toolType, _ := tool["type"].(string)
|
||||
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||
return true
|
||||
}
|
||||
|
||||
name, _ := tool["name"].(string)
|
||||
switch strings.TrimSpace(name) {
|
||||
case "web_search", "google_search", "web_search_20250305":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -164,6 +164,35 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToolsToGeminiTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "web_search_20250305",
|
||||
"name": "web_search",
|
||||
},
|
||||
}
|
||||
|
||||
result := convertClaudeToolsToGeminiTools(tools)
|
||||
require.Len(t, result, 2)
|
||||
|
||||
functionDecl, ok := result[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
funcDecls, ok := functionDecl["functionDeclarations"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, funcDecls, 1)
|
||||
|
||||
searchDecl, ok := result[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
googleSearch, ok := searchDecl["googleSearch"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Empty(t, googleSearch)
|
||||
}
|
||||
|
||||
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
@@ -232,6 +261,53 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
|
||||
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
httpStub := &geminiCompatHTTPUpstreamStub{
|
||||
response: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"x-request-id": []string{"gemini-req-2"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
|
||||
},
|
||||
}
|
||||
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "test-key",
|
||||
},
|
||||
}
|
||||
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"get_weather","description":"Get weather info","input_schema":{"type":"object"}},{"type":"web_search_20250305","name":"web_search"}]}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, httpStub.lastReq)
|
||||
|
||||
postedBody, err := io.ReadAll(httpStub.lastReq.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var posted map[string]any
|
||||
require.NoError(t, json.Unmarshal(postedBody, &posted))
|
||||
tools, ok := posted["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tools, 2)
|
||||
|
||||
searchTool, ok := tools[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasSnake := searchTool["google_search"]
|
||||
_, hasCamel := searchTool["googleSearch"]
|
||||
require.True(t, hasSnake)
|
||||
require.False(t, hasCamel)
|
||||
_, hasFuncDecl := searchTool["functionDeclarations"]
|
||||
require.False(t, hasFuncDecl)
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
Reference in New Issue
Block a user