feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
149
backend/internal/pkg/websearch/manager_test.go
Normal file
149
backend/internal/pkg/websearch/manager_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewManager_SortsByPriority(t *testing.T) {
|
||||
configs := []ProviderConfig{
|
||||
{Type: "brave", APIKey: "k3", Priority: 30},
|
||||
{Type: "tavily", APIKey: "k1", Priority: 10},
|
||||
}
|
||||
m := NewManager(configs, nil)
|
||||
require.Equal(t, 10, m.configs[0].Priority)
|
||||
require.Equal(t, 30, m.configs[1].Priority)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
|
||||
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
|
||||
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", ExpiresAt: &past},
|
||||
}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
|
||||
// Create two mock servers that return different results
|
||||
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srvBrave.Close()
|
||||
|
||||
// Override brave endpoint for test
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k1", Priority: 1},
|
||||
{Type: "tavily", APIKey: "k2", Priority: 2},
|
||||
}, nil)
|
||||
// Inject the test server's client
|
||||
m.clientCache[srvBrave.URL] = srvBrave.Client()
|
||||
m.clientCache[""] = srvBrave.Client()
|
||||
|
||||
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "brave", providerName)
|
||||
require.Len(t, resp.Results, 1)
|
||||
require.Equal(t, "from brave", resp.Results[0].Snippet)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
|
||||
// With nil Redis, quota check is skipped (always allowed)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
|
||||
}, nil) // nil Redis
|
||||
m.clientCache[""] = srv.Client()
|
||||
|
||||
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Results, 1)
|
||||
}
|
||||
|
||||
func TestManager_GetUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager(nil, nil)
|
||||
used, err := m.GetUsage(context.Background(), "brave", "monthly")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), used)
|
||||
}
|
||||
|
||||
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", QuotaRefreshInterval: "monthly"},
|
||||
}, nil)
|
||||
usage := m.GetAllUsage(context.Background())
|
||||
require.Equal(t, int64(0), usage["brave"])
|
||||
}
|
||||
|
||||
// --- Key/TTL helpers ---
|
||||
|
||||
func TestQuotaTTL_Daily(t *testing.T) {
|
||||
require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Weekly(t *testing.T) {
|
||||
require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Monthly(t *testing.T) {
|
||||
require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
|
||||
}
|
||||
|
||||
func TestPeriodKey_Daily(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshDaily)
|
||||
require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Weekly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshWeekly)
|
||||
require.Regexp(t, `^\d{4}-W\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Monthly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshMonthly)
|
||||
require.Regexp(t, `^\d{4}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestQuotaRedisKey_Format(t *testing.T) {
|
||||
key := quotaRedisKey("brave", QuotaRefreshDaily)
|
||||
require.Contains(t, key, "websearch:quota:brave:")
|
||||
}
|
||||
Reference in New Issue
Block a user