feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
106
backend/internal/pkg/websearch/brave.go
Normal file
106
backend/internal/pkg/websearch/brave.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
|
||||
braveMaxCount = 20
|
||||
braveProviderName = "brave"
|
||||
)
|
||||
|
||||
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
|
||||
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
|
||||
|
||||
// BraveProvider implements web search via the Brave Search API.
|
||||
type BraveProvider struct {
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewBraveProvider creates a Brave Search provider.
|
||||
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
|
||||
}
|
||||
|
||||
func (b *BraveProvider) Name() string { return braveProviderName }
|
||||
|
||||
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||
count := req.MaxResults
|
||||
if count <= 0 {
|
||||
count = defaultMaxResults
|
||||
}
|
||||
if count > braveMaxCount {
|
||||
count = braveMaxCount
|
||||
}
|
||||
|
||||
u := *braveSearchURL // copy the pre-parsed URL
|
||||
q := u.Query()
|
||||
q.Set("q", req.Query)
|
||||
q.Set("count", strconv.Itoa(count))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := b.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: read body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||
}
|
||||
|
||||
var raw braveResponse
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, fmt.Errorf("brave: decode response: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(raw.Web.Results))
|
||||
for _, r := range raw.Web.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
Snippet: r.Description,
|
||||
PageAge: r.Age,
|
||||
})
|
||||
}
|
||||
|
||||
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||
}
|
||||
|
||||
// braveResponse is the minimal structure of the Brave Search API response.
|
||||
type braveResponse struct {
|
||||
Web struct {
|
||||
Results []braveResult `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
type braveResult struct {
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Age string `json:"age"`
|
||||
}
|
||||
119
backend/internal/pkg/websearch/brave_test.go
Normal file
119
backend/internal/pkg/websearch/brave_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBraveProvider_Name(t *testing.T) {
|
||||
p := NewBraveProvider("key", nil)
|
||||
require.Equal(t, "brave", p.Name())
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
|
||||
require.Equal(t, "application/json", r.Header.Get("Accept"))
|
||||
require.Equal(t, "golang", r.URL.Query().Get("q"))
|
||||
require.Equal(t, "3", r.URL.Query().Get("count"))
|
||||
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{
|
||||
{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
|
||||
{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
|
||||
{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("test-key", srv.Client())
|
||||
// Override the endpoint for testing
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Results, 3)
|
||||
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||
require.Equal(t, "Go lang", resp.Results[0].Snippet)
|
||||
require.Equal(t, "1 day", resp.Results[0].PageAge)
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
|
||||
var receivedCount string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedCount = r.URL.Query().Get("count")
|
||||
resp := braveResponse{}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
|
||||
require.Equal(t, "5", receivedCount)
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte("rate limited"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "brave: status 429")
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("not json"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "brave: decode response")
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Results)
|
||||
}
|
||||
14
backend/internal/pkg/websearch/helpers.go
Normal file
14
backend/internal/pkg/websearch/helpers.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package websearch
|
||||
|
||||
const (
|
||||
maxResponseSize = 1 << 20 // 1 MB
|
||||
errorBodyTruncLen = 200
|
||||
)
|
||||
|
||||
// truncateBody returns a truncated string of body for error messages.
|
||||
func truncateBody(body []byte) string {
|
||||
if len(body) <= errorBodyTruncLen {
|
||||
return string(body)
|
||||
}
|
||||
return string(body[:errorBodyTruncLen]) + "...(truncated)"
|
||||
}
|
||||
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTruncateBody_Short(t *testing.T) {
|
||||
body := []byte("short body")
|
||||
require.Equal(t, "short body", truncateBody(body))
|
||||
}
|
||||
|
||||
func TestTruncateBody_Long(t *testing.T) {
|
||||
body := []byte(strings.Repeat("x", 500))
|
||||
result := truncateBody(body)
|
||||
require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
|
||||
require.True(t, strings.HasSuffix(result, "...(truncated)"))
|
||||
}
|
||||
|
||||
func TestTruncateBody_ExactBoundary(t *testing.T) {
|
||||
body := []byte(strings.Repeat("x", errorBodyTruncLen))
|
||||
require.Equal(t, string(body), truncateBody(body))
|
||||
}
|
||||
273
backend/internal/pkg/websearch/manager.go
Normal file
273
backend/internal/pkg/websearch/manager.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Quota refresh interval constants.
|
||||
const (
|
||||
QuotaRefreshDaily = "daily"
|
||||
QuotaRefreshWeekly = "weekly"
|
||||
QuotaRefreshMonthly = "monthly"
|
||||
)
|
||||
|
||||
// ProviderConfig holds the configuration for a single search provider.
|
||||
type ProviderConfig struct {
|
||||
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
|
||||
APIKey string `json:"api_key"` // secret
|
||||
Priority int `json:"priority"` // lower = higher priority
|
||||
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||
QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
|
||||
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
|
||||
}
|
||||
|
||||
// Manager selects providers by priority and tracks quota via Redis.
|
||||
type Manager struct {
|
||||
configs []ProviderConfig
|
||||
redis *redis.Client
|
||||
|
||||
clientMu sync.Mutex
|
||||
clientCache map[string]*http.Client
|
||||
}
|
||||
|
||||
const (
|
||||
quotaKeyPrefix = "websearch:quota:"
|
||||
searchRequestTimeout = 30 * time.Second
|
||||
quotaTTLBuffer = 24 * time.Hour
|
||||
maxCachedClients = 100
|
||||
)
|
||||
|
||||
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
|
||||
// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
|
||||
// Returns the new counter value.
|
||||
var quotaIncrScript = redis.NewScript(`
|
||||
local val = redis.call('INCR', KEYS[1])
|
||||
if val == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||
else
|
||||
-- Defensive: ensure TTL exists even if a prior EXPIRE failed
|
||||
local ttl = redis.call('TTL', KEYS[1])
|
||||
if ttl == -1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||
end
|
||||
end
|
||||
return val
|
||||
`)
|
||||
|
||||
// NewManager creates a Manager with the given provider configs and Redis client.
|
||||
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
|
||||
sorted := make([]ProviderConfig, len(configs))
|
||||
copy(sorted, configs)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
return &Manager{
|
||||
configs: sorted,
|
||||
redis: redisClient,
|
||||
clientCache: make(map[string]*http.Client),
|
||||
}
|
||||
}
|
||||
|
||||
// SearchWithBestProvider selects the highest-priority available provider,
|
||||
// reserves quota, executes the search, and rolls back quota on failure.
|
||||
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return nil, "", fmt.Errorf("websearch: empty search query")
|
||||
}
|
||||
for _, cfg := range m.configs {
|
||||
if !m.isProviderAvailable(cfg) {
|
||||
continue
|
||||
}
|
||||
allowed, incremented := m.tryReserveQuota(ctx, cfg)
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
resp, err := m.executeSearch(ctx, cfg, req)
|
||||
if err != nil {
|
||||
if incremented {
|
||||
m.rollbackQuota(ctx, cfg)
|
||||
}
|
||||
slog.Warn("websearch: provider search failed",
|
||||
"provider", cfg.Type, "error", err)
|
||||
continue
|
||||
}
|
||||
return resp, cfg.Type, nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
|
||||
}
|
||||
|
||||
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
|
||||
if cfg.APIKey == "" {
|
||||
return false
|
||||
}
|
||||
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
|
||||
slog.Info("websearch: provider expired, skipping",
|
||||
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tryReserveQuota atomically increments the counter via Lua script and checks limit.
|
||||
// Returns (allowed, incremented): allowed=true means the request may proceed;
|
||||
// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
|
||||
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
|
||||
if cfg.QuotaLimit <= 0 {
|
||||
return true, false // unlimited, no INCR
|
||||
}
|
||||
if m.redis == nil {
|
||||
slog.Warn("websearch: Redis unavailable, quota check skipped",
|
||||
"provider", cfg.Type)
|
||||
return true, false // allowed but not incremented
|
||||
}
|
||||
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
|
||||
ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
|
||||
|
||||
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
|
||||
if err != nil {
|
||||
slog.Warn("websearch: quota Lua INCR failed, allowing request",
|
||||
"provider", cfg.Type, "error", err)
|
||||
return true, false // allowed but not incremented
|
||||
}
|
||||
if newVal > cfg.QuotaLimit {
|
||||
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
|
||||
slog.Warn("websearch: quota over-limit DECR failed",
|
||||
"provider", cfg.Type, "error", decrErr)
|
||||
}
|
||||
slog.Info("websearch: provider quota exhausted",
|
||||
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
|
||||
return false, false // rejected, already rolled back
|
||||
}
|
||||
return true, true // allowed and incremented
|
||||
}
|
||||
|
||||
// rollbackQuota decrements the counter after a search failure.
|
||||
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
|
||||
if cfg.QuotaLimit <= 0 || m.redis == nil {
|
||||
return
|
||||
}
|
||||
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
|
||||
if err := m.redis.Decr(ctx, key).Err(); err != nil {
|
||||
slog.Warn("websearch: quota rollback DECR failed",
|
||||
"provider", cfg.Type, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
|
||||
proxyURL := cfg.ProxyURL
|
||||
if req.ProxyURL != "" {
|
||||
proxyURL = req.ProxyURL
|
||||
}
|
||||
client := m.getOrCreateHTTPClient(proxyURL)
|
||||
provider := m.buildProvider(cfg, client)
|
||||
return provider.Search(ctx, req)
|
||||
}
|
||||
|
||||
// GetUsage returns the current usage count for the given provider.
|
||||
func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
|
||||
if m.redis == nil {
|
||||
return 0, nil
|
||||
}
|
||||
key := quotaRedisKey(providerType, refreshInterval)
|
||||
val, err := m.redis.Get(ctx, key).Int64()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
// GetAllUsage returns usage for every configured provider.
|
||||
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
|
||||
result := make(map[string]int64, len(m.configs))
|
||||
for _, cfg := range m.configs {
|
||||
used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
|
||||
result[cfg.Type] = used
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- HTTP client cache (bounded) ---
|
||||
|
||||
func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
|
||||
m.clientMu.Lock()
|
||||
defer m.clientMu.Unlock()
|
||||
|
||||
if c, ok := m.clientCache[proxyURL]; ok {
|
||||
return c
|
||||
}
|
||||
if len(m.clientCache) >= maxCachedClients {
|
||||
m.clientCache = make(map[string]*http.Client) // evict all
|
||||
}
|
||||
c := newHTTPClient(proxyURL)
|
||||
m.clientCache[proxyURL] = c
|
||||
return c
|
||||
}
|
||||
|
||||
func newHTTPClient(proxyURL string) *http.Client {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
if proxyURL != "" {
|
||||
if u, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(u)
|
||||
}
|
||||
}
|
||||
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
|
||||
}
|
||||
|
||||
// --- Provider factory ---
|
||||
|
||||
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
|
||||
switch cfg.Type {
|
||||
case braveProviderName:
|
||||
return NewBraveProvider(cfg.APIKey, client)
|
||||
case tavilyProviderName:
|
||||
return NewTavilyProvider(cfg.APIKey, client)
|
||||
default:
|
||||
slog.Warn("websearch: unknown provider type, falling back to brave",
|
||||
"type", cfg.Type)
|
||||
return NewBraveProvider(cfg.APIKey, client)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Redis key helpers ---
|
||||
|
||||
func quotaRedisKey(providerType, refreshInterval string) string {
|
||||
return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
|
||||
}
|
||||
|
||||
func periodKey(refreshInterval string) string {
|
||||
now := time.Now().UTC()
|
||||
switch refreshInterval {
|
||||
case QuotaRefreshDaily:
|
||||
return now.Format("2006-01-02")
|
||||
case QuotaRefreshWeekly:
|
||||
year, week := now.ISOWeek()
|
||||
return fmt.Sprintf("%d-W%02d", year, week)
|
||||
default: // QuotaRefreshMonthly
|
||||
return now.Format("2006-01")
|
||||
}
|
||||
}
|
||||
|
||||
func quotaTTL(refreshInterval string) time.Duration {
|
||||
switch refreshInterval {
|
||||
case QuotaRefreshDaily:
|
||||
return 24*time.Hour + quotaTTLBuffer
|
||||
case QuotaRefreshWeekly:
|
||||
return 7*24*time.Hour + quotaTTLBuffer
|
||||
default:
|
||||
return 31*24*time.Hour + quotaTTLBuffer
|
||||
}
|
||||
}
|
||||
149
backend/internal/pkg/websearch/manager_test.go
Normal file
149
backend/internal/pkg/websearch/manager_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewManager_SortsByPriority(t *testing.T) {
|
||||
configs := []ProviderConfig{
|
||||
{Type: "brave", APIKey: "k3", Priority: 30},
|
||||
{Type: "tavily", APIKey: "k1", Priority: 10},
|
||||
}
|
||||
m := NewManager(configs, nil)
|
||||
require.Equal(t, 10, m.configs[0].Priority)
|
||||
require.Equal(t, 30, m.configs[1].Priority)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
|
||||
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
|
||||
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", ExpiresAt: &past},
|
||||
}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
|
||||
// Create two mock servers that return different results
|
||||
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srvBrave.Close()
|
||||
|
||||
// Override brave endpoint for test
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k1", Priority: 1},
|
||||
{Type: "tavily", APIKey: "k2", Priority: 2},
|
||||
}, nil)
|
||||
// Inject the test server's client
|
||||
m.clientCache[srvBrave.URL] = srvBrave.Client()
|
||||
m.clientCache[""] = srvBrave.Client()
|
||||
|
||||
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "brave", providerName)
|
||||
require.Len(t, resp.Results, 1)
|
||||
require.Equal(t, "from brave", resp.Results[0].Snippet)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
|
||||
// With nil Redis, quota check is skipped (always allowed)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
|
||||
}, nil) // nil Redis
|
||||
m.clientCache[""] = srv.Client()
|
||||
|
||||
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Results, 1)
|
||||
}
|
||||
|
||||
func TestManager_GetUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager(nil, nil)
|
||||
used, err := m.GetUsage(context.Background(), "brave", "monthly")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), used)
|
||||
}
|
||||
|
||||
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", QuotaRefreshInterval: "monthly"},
|
||||
}, nil)
|
||||
usage := m.GetAllUsage(context.Background())
|
||||
require.Equal(t, int64(0), usage["brave"])
|
||||
}
|
||||
|
||||
// --- Key/TTL helpers ---
|
||||
|
||||
func TestQuotaTTL_Daily(t *testing.T) {
|
||||
require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Weekly(t *testing.T) {
|
||||
require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Monthly(t *testing.T) {
|
||||
require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
|
||||
}
|
||||
|
||||
func TestPeriodKey_Daily(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshDaily)
|
||||
require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Weekly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshWeekly)
|
||||
require.Regexp(t, `^\d{4}-W\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Monthly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshMonthly)
|
||||
require.Regexp(t, `^\d{4}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestQuotaRedisKey_Format(t *testing.T) {
|
||||
key := quotaRedisKey("brave", QuotaRefreshDaily)
|
||||
require.Contains(t, key, "websearch:quota:brave:")
|
||||
}
|
||||
11
backend/internal/pkg/websearch/provider.go
Normal file
11
backend/internal/pkg/websearch/provider.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package websearch
|
||||
|
||||
import "context"
|
||||
|
||||
// Provider is the interface every search backend must implement.
|
||||
type Provider interface {
|
||||
// Name returns the provider identifier ("brave" or "tavily").
|
||||
Name() string
|
||||
// Search executes a web search and returns results.
|
||||
Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
|
||||
}
|
||||
107
backend/internal/pkg/websearch/tavily.go
Normal file
107
backend/internal/pkg/websearch/tavily.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
tavilySearchEndpoint = "https://api.tavily.com/search"
|
||||
tavilyProviderName = "tavily"
|
||||
tavilySearchDepthBasic = "basic"
|
||||
)
|
||||
|
||||
// TavilyProvider implements web search via the Tavily Search API.
|
||||
type TavilyProvider struct {
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewTavilyProvider creates a Tavily Search provider.
|
||||
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
|
||||
}
|
||||
|
||||
func (t *TavilyProvider) Name() string { return tavilyProviderName }
|
||||
|
||||
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||
maxResults := req.MaxResults
|
||||
if maxResults <= 0 {
|
||||
maxResults = defaultMaxResults
|
||||
}
|
||||
|
||||
payload := tavilyRequest{
|
||||
APIKey: t.apiKey,
|
||||
Query: req.Query,
|
||||
MaxResults: maxResults,
|
||||
SearchDepth: tavilySearchDepthBasic,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: encode request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := t.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: read body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||
}
|
||||
|
||||
var raw tavilyResponse
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, fmt.Errorf("tavily: decode response: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(raw.Results))
|
||||
for _, r := range raw.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
Snippet: r.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||
}
|
||||
|
||||
type tavilyRequest struct {
|
||||
APIKey string `json:"api_key"`
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results"`
|
||||
SearchDepth string `json:"search_depth"`
|
||||
}
|
||||
|
||||
type tavilyResponse struct {
|
||||
Results []tavilyResult `json:"results"`
|
||||
}
|
||||
|
||||
type tavilyResult struct {
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTavilyProvider_Name(t *testing.T) {
|
||||
p := NewTavilyProvider("key", nil)
|
||||
require.Equal(t, "tavily", p.Name())
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
|
||||
// Verify tavilyRequest struct fields map correctly
|
||||
req := tavilyRequest{
|
||||
APIKey: "test-key",
|
||||
Query: "golang",
|
||||
MaxResults: 3,
|
||||
SearchDepth: tavilySearchDepthBasic,
|
||||
}
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &parsed))
|
||||
require.Equal(t, "test-key", parsed["api_key"])
|
||||
require.Equal(t, "golang", parsed["query"])
|
||||
require.Equal(t, float64(3), parsed["max_results"])
|
||||
require.Equal(t, "basic", parsed["search_depth"])
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
|
||||
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
|
||||
var resp tavilyResponse
|
||||
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
|
||||
require.Len(t, resp.Results, 1)
|
||||
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||
require.Equal(t, "Go programming language", resp.Results[0].Content)
|
||||
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
|
||||
|
||||
// Verify mapping to SearchResult
|
||||
results := make([]SearchResult, 0, len(resp.Results))
|
||||
for _, r := range resp.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL, Title: r.Title, Snippet: r.Content,
|
||||
})
|
||||
}
|
||||
require.Equal(t, "Go programming language", results[0].Snippet)
|
||||
require.Equal(t, "", results[0].PageAge)
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
|
||||
var resp tavilyResponse
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
|
||||
require.Empty(t, resp.Results)
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
|
||||
var resp tavilyResponse
|
||||
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
|
||||
}
|
||||
30
backend/internal/pkg/websearch/types.go
Normal file
30
backend/internal/pkg/websearch/types.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package websearch
|
||||
|
||||
// SearchResult represents a single web search result.
|
||||
type SearchResult struct {
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title"`
|
||||
Snippet string `json:"snippet"`
|
||||
PageAge string `json:"page_age,omitempty"`
|
||||
}
|
||||
|
||||
// SearchRequest describes a web search to perform.
|
||||
type SearchRequest struct {
|
||||
Query string
|
||||
MaxResults int // defaults to defaultMaxResults if <= 0
|
||||
ProxyURL string // optional HTTP proxy URL
|
||||
}
|
||||
|
||||
// SearchResponse holds the results of a web search.
|
||||
type SearchResponse struct {
|
||||
Results []SearchResult
|
||||
Query string // the query that was actually executed
|
||||
}
|
||||
|
||||
const defaultMaxResults = 5
|
||||
|
||||
// Provider type identifiers.
|
||||
const (
|
||||
ProviderTypeBrave = "brave"
|
||||
ProviderTypeTavily = "tavily"
|
||||
)
|
||||
Reference in New Issue
Block a user