fix(gemini): normalize ai studio google search tools
This commit is contained in:
@@ -612,7 +612,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
fullURL += "?alt=sse"
|
fullURL += "?alt=sse"
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||||
|
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -685,7 +686,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
fullURL += "?alt=sse"
|
fullURL += "?alt=sse"
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
|
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
|
||||||
|
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -3240,6 +3242,46 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeGeminiRequestForAIStudio(body []byte) []byte {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
tools, ok := payload["tools"].([]any)
|
||||||
|
if !ok || len(tools) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for _, rawTool := range tools {
|
||||||
|
tool, ok := rawTool.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
googleSearch, ok := tool["googleSearch"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := tool["google_search"]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tool["google_search"] = googleSearch
|
||||||
|
delete(tool, "googleSearch")
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
func isClaudeWebSearchToolMap(tool map[string]any) bool {
|
func isClaudeWebSearchToolMap(tool map[string]any) bool {
|
||||||
toolType, _ := tool["type"].(string)
|
toolType, _ := tool["type"].(string)
|
||||||
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
|
||||||
|
|||||||
@@ -261,6 +261,53 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
|
|||||||
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
|
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
httpStub := &geminiCompatHTTPUpstreamStub{
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"x-request-id": []string{"gemini-req-2"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
|
||||||
|
account := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "test-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"get_weather","description":"Get weather info","input_schema":{"type":"object"}},{"type":"web_search_20250305","name":"web_search"}]}`)
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, httpStub.lastReq)
|
||||||
|
|
||||||
|
postedBody, err := io.ReadAll(httpStub.lastReq.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var posted map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(postedBody, &posted))
|
||||||
|
tools, ok := posted["tools"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, tools, 2)
|
||||||
|
|
||||||
|
searchTool, ok := tools[1].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, hasSnake := searchTool["google_search"]
|
||||||
|
_, hasCamel := searchTool["googleSearch"]
|
||||||
|
require.True(t, hasSnake)
|
||||||
|
require.False(t, hasCamel)
|
||||||
|
_, hasFuncDecl := searchTool["functionDeclarations"]
|
||||||
|
require.False(t, hasFuncDecl)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||||
claudeReq := map[string]any{
|
claudeReq := map[string]any{
|
||||||
"model": "claude-haiku-4-5-20251001",
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
|||||||
Reference in New Issue
Block a user