Merge pull request #1418 from YanzheL/fix/1161-gemini-google-search-grounding

fix(gemini): preserve google search grounding tools
This commit is contained in:
Wesley Liddick
2026-04-08 14:19:57 +08:00
committed by GitHub
4 changed files with 221 additions and 18 deletions

View File

@@ -612,7 +612,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
@@ -685,7 +686,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
@@ -3184,12 +3186,17 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
return nil
}
hasWebSearch := false
funcDecls := make([]any, 0, len(arr))
for _, t := range arr {
tm, ok := t.(map[string]any)
if !ok {
continue
}
if isClaudeWebSearchToolMap(tm) {
hasWebSearch = true
continue
}
var name, desc string
var params any
@@ -3233,13 +3240,75 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
})
}
if len(funcDecls) == 0 {
out := make([]any, 0, 2)
if len(funcDecls) > 0 {
out = append(out, map[string]any{
"functionDeclarations": funcDecls,
})
}
if hasWebSearch {
out = append(out, map[string]any{
"googleSearch": map[string]any{},
})
}
if len(out) == 0 {
return nil
}
return []any{
map[string]any{
"functionDeclarations": funcDecls,
},
return out
}
func normalizeGeminiRequestForAIStudio(body []byte) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
tools, ok := payload["tools"].([]any)
if !ok || len(tools) == 0 {
return body
}
modified := false
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
googleSearch, ok := tool["googleSearch"]
if !ok {
continue
}
if _, exists := tool["google_search"]; exists {
continue
}
tool["google_search"] = googleSearch
delete(tool, "googleSearch")
modified = true
}
if !modified {
return body
}
normalized, err := json.Marshal(payload)
if err != nil {
return body
}
return normalized
}
func isClaudeWebSearchToolMap(tool map[string]any) bool {
toolType, _ := tool["type"].(string)
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
return true
}
name, _ := tool["name"].(string)
switch strings.TrimSpace(name) {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}

View File

@@ -164,6 +164,35 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}
}
func TestConvertClaudeToolsToGeminiTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
tools := []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "web_search_20250305",
"name": "web_search",
},
}
result := convertClaudeToolsToGeminiTools(tools)
require.Len(t, result, 2)
functionDecl, ok := result[0].(map[string]any)
require.True(t, ok)
funcDecls, ok := functionDecl["functionDeclarations"].([]any)
require.True(t, ok)
require.Len(t, funcDecls, 1)
searchDecl, ok := result[1].(map[string]any)
require.True(t, ok)
googleSearch, ok := searchDecl["googleSearch"].(map[string]any)
require.True(t, ok)
require.Empty(t, googleSearch)
}
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
@@ -232,6 +261,53 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
}
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
httpStub := &geminiCompatHTTPUpstreamStub{
response: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"x-request-id": []string{"gemini-req-2"}},
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
},
}
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
account := &Account{
ID: 1,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-key",
},
}
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"get_weather","description":"Get weather info","input_schema":{"type":"object"}},{"type":"web_search_20250305","name":"web_search"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, httpStub.lastReq)
postedBody, err := io.ReadAll(httpStub.lastReq.Body)
require.NoError(t, err)
var posted map[string]any
require.NoError(t, json.Unmarshal(postedBody, &posted))
tools, ok := posted["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, 2)
searchTool, ok := tools[1].(map[string]any)
require.True(t, ok)
_, hasSnake := searchTool["google_search"]
_, hasCamel := searchTool["googleSearch"]
require.True(t, hasSnake)
require.False(t, hasCamel)
_, hasFuncDecl := searchTool["functionDeclarations"]
require.False(t, hasFuncDecl)
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",