From 1b53ffcac71242840a54ef102977e94e204aefc1 Mon Sep 17 00:00:00 2001
From: erio
Date: Sun, 12 Apr 2026 00:02:26 +0800
Subject: [PATCH] feat(gateway): add web search emulation for Anthropic API Key
accounts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Inject web search capability for Claude Console (API Key) accounts that
don't natively support Anthropic's web_search tool. When a pure
web_search request is detected, the gateway calls the Brave Search or Tavily
API directly and constructs an Anthropic-protocol-compliant SSE/JSON
response without forwarding to upstream.
Backend:
- New `pkg/websearch/` SDK: Brave and Tavily provider implementations
with io.LimitReader, proxy support, and Redis-based quota tracking
(Lua atomic INCR + TTL, DECR rollback on failure)
- Global config via `settings.web_search_emulation_config` (JSON) with
in-process cache + singleflight, input validation, API key merge on
save, and sanitized API responses
- Channel-level toggle via `channels.features_config` JSONB column
(DB migration 101)
- Account-level toggle via `accounts.extra.web_search_emulation`
- Request interception in `Forward()` with SSE streaming response
construction using json.Marshal (no manual string concatenation)
- Manager hot-reload: `RebuildWebSearchManager()` called on config save
and startup via `SetWebSearchRedisClient()`
- 70 unit tests covering providers, manager, config validation,
sanitization, tool detection, query extraction, and response building
Frontend:
- Settings → Gateway tab: Web Search Emulation config card with global
toggle, provider list (add/remove, API key, priority, quota, proxy)
- Channels → Anthropic tab: web search emulation toggle with global
state linkage (disabled when the global toggle is off)
- Account Create/Edit modals: web search emulation toggle for API Key
type with Toggle component
- Full i18n coverage (zh + en)
---
.../internal/handler/admin/channel_handler.go | 6 +
.../internal/handler/admin/setting_handler.go | 35 +
backend/internal/handler/dto/settings.go | 3 +
backend/internal/pkg/websearch/brave.go | 106 ++
backend/internal/pkg/websearch/brave_test.go | 119 +++
backend/internal/pkg/websearch/helpers.go | 14 +
.../internal/pkg/websearch/helpers_test.go | 25 +
backend/internal/pkg/websearch/manager.go | 273 ++++++
.../internal/pkg/websearch/manager_test.go | 149 +++
backend/internal/pkg/websearch/provider.go | 11 +
backend/internal/pkg/websearch/tavily.go | 107 ++
backend/internal/pkg/websearch/tavily_test.go | 63 ++
backend/internal/pkg/websearch/types.go | 30 +
backend/internal/repository/channel_repo.go | 90 +-
backend/internal/server/routes/admin.go | 3 +
backend/internal/service/account.go | 19 +-
.../service/account_websearch_test.go | 71 ++
backend/internal/service/channel.go | 15 +
backend/internal/service/channel_service.go | 389 +++++---
.../service/channel_websearch_test.go | 62 ++
backend/internal/service/domain_constants.go | 4 +
backend/internal/service/gateway_service.go | 5 +
.../service/gateway_websearch_emulation.go | 358 +++++++
.../gateway_websearch_emulation_test.go | 142 +++
backend/internal/service/setting_service.go | 10 +
backend/internal/service/settings_view.go | 3 +
backend/internal/service/websearch_config.go | 253 +++++
.../internal/service/websearch_config_test.go | 148 +++
.../101_add_channel_features_config.sql | 2 +
frontend/src/api/admin/channels.ts | 3 +
frontend/src/api/admin/settings.ts | 40 +-
.../components/account/CreateAccountModal.vue | 26 +
.../components/account/EditAccountModal.vue | 90 +-
frontend/src/i18n/locales/en.ts | 33 +-
frontend/src/i18n/locales/zh.ts | 33 +-
frontend/src/views/admin/ChannelsView.vue | 94 +-
frontend/src/views/admin/SettingsView.vue | 911 +++++++++++++++++-
37 files changed, 3507 insertions(+), 238 deletions(-)
create mode 100644 backend/internal/pkg/websearch/brave.go
create mode 100644 backend/internal/pkg/websearch/brave_test.go
create mode 100644 backend/internal/pkg/websearch/helpers.go
create mode 100644 backend/internal/pkg/websearch/helpers_test.go
create mode 100644 backend/internal/pkg/websearch/manager.go
create mode 100644 backend/internal/pkg/websearch/manager_test.go
create mode 100644 backend/internal/pkg/websearch/provider.go
create mode 100644 backend/internal/pkg/websearch/tavily.go
create mode 100644 backend/internal/pkg/websearch/tavily_test.go
create mode 100644 backend/internal/pkg/websearch/types.go
create mode 100644 backend/internal/service/account_websearch_test.go
create mode 100644 backend/internal/service/channel_websearch_test.go
create mode 100644 backend/internal/service/gateway_websearch_emulation.go
create mode 100644 backend/internal/service/gateway_websearch_emulation_test.go
create mode 100644 backend/internal/service/websearch_config.go
create mode 100644 backend/internal/service/websearch_config_test.go
create mode 100644 backend/migrations/101_add_channel_features_config.sql
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index d6022283..9cefc792 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -34,6 +34,7 @@ type createChannelRequest struct {
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
}
type updateChannelRequest struct {
@@ -46,6 +47,7 @@ type updateChannelRequest struct {
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
}
type channelModelPricingRequest struct {
@@ -81,6 +83,7 @@ type channelResponse struct {
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
@@ -126,6 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Status: ch.Status,
RestrictModels: ch.RestrictModels,
Features: ch.Features,
+ FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@@ -305,6 +309,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -338,6 +343,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
Features: req.Features,
+ FeaturesConfig: req.FeaturesConfig,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index ba751131..031b819a 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -175,6 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
@@ -1847,3 +1848,37 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
+
+// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
+// GET /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
+ cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.SanitizeWebSearchConfig(cfg))
+}
+
+// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
+// PUT /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
+ var cfg service.WebSearchEmulationConfig
+ if err := c.ShouldBindJSON(&cfg); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Re-read (with sanitized api keys) to return current state
+ updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.SanitizeWebSearchConfig(updated))
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index cbbe9216..0433d692 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -124,6 +124,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
+ // Web Search Emulation
+ WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go
new file mode 100644
index 00000000..5620ca8d
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave.go
@@ -0,0 +1,106 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+)
+
+const (
+ braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
+ braveMaxCount = 20
+ braveProviderName = "brave"
+)
+
+// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
+var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
+
+// BraveProvider implements web search via the Brave Search API.
+type BraveProvider struct {
+ apiKey string
+ httpClient *http.Client
+}
+
+// NewBraveProvider creates a Brave Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+func (b *BraveProvider) Name() string { return braveProviderName }
+
+func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+ count := req.MaxResults
+ if count <= 0 {
+ count = defaultMaxResults
+ }
+ if count > braveMaxCount {
+ count = braveMaxCount
+ }
+
+ u := *braveSearchURL // copy the pre-parsed URL
+ q := u.Query()
+ q.Set("q", req.Query)
+ q.Set("count", strconv.Itoa(count))
+ u.RawQuery = q.Encode()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("brave: build request: %w", err)
+ }
+ httpReq.Header.Set("X-Subscription-Token", b.apiKey)
+ httpReq.Header.Set("Accept", "application/json")
+
+ resp, err := b.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("brave: request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+ if err != nil {
+ return nil, fmt.Errorf("brave: read body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
+ }
+
+ var raw braveResponse
+ if err := json.Unmarshal(body, &raw); err != nil {
+ return nil, fmt.Errorf("brave: decode response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(raw.Web.Results))
+ for _, r := range raw.Web.Results {
+ results = append(results, SearchResult{
+ URL: r.URL,
+ Title: r.Title,
+ Snippet: r.Description,
+ PageAge: r.Age,
+ })
+ }
+
+ return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+// braveResponse is the minimal structure of the Brave Search API response.
+type braveResponse struct {
+ Web struct {
+ Results []braveResult `json:"results"`
+ } `json:"web"`
+}
+
+type braveResult struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Description string `json:"description"`
+ Age string `json:"age"`
+}
diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go
new file mode 100644
index 00000000..3fe35020
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave_test.go
@@ -0,0 +1,119 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBraveProvider_Name(t *testing.T) {
+ p := NewBraveProvider("key", nil)
+ require.Equal(t, "brave", p.Name())
+}
+
+func TestBraveProvider_Search_Success(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
+ require.Equal(t, "application/json", r.Header.Get("Accept"))
+ require.Equal(t, "golang", r.URL.Query().Get("q"))
+ require.Equal(t, "3", r.URL.Query().Get("count"))
+
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{
+ {URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
+ {URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
+ {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("test-key", srv.Client())
+ // Override the endpoint for testing
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 3)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go lang", resp.Results[0].Snippet)
+ require.Equal(t, "1 day", resp.Results[0].PageAge)
+}
+
+func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
+ var receivedCount string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedCount = r.URL.Query().Get("count")
+ resp := braveResponse{}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
+ require.Equal(t, "5", receivedCount)
+}
+
+func TestBraveProvider_Search_HTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(429)
+ w.Write([]byte("rate limited"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: status 429")
+}
+
+func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: decode response")
+}
+
+func TestBraveProvider_Search_EmptyResults(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Empty(t, resp.Results)
+}
diff --git a/backend/internal/pkg/websearch/helpers.go b/backend/internal/pkg/websearch/helpers.go
new file mode 100644
index 00000000..0d08b749
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers.go
@@ -0,0 +1,23 @@
+package websearch
+
+import "unicode/utf8"
+
+const (
+	maxResponseSize   = 1 << 20 // 1 MB
+	errorBodyTruncLen = 200
+)
+
+// truncateBody returns a truncated string of body for error messages.
+// The cut position backs up to a UTF-8 rune boundary so the result is
+// always valid UTF-8: upstream APIs return UTF-8 JSON, and a naive byte
+// cut at errorBodyTruncLen could split a multi-byte rune in half.
+func truncateBody(body []byte) string {
+	if len(body) <= errorBodyTruncLen {
+		return string(body)
+	}
+	cut := errorBodyTruncLen
+	for cut > 0 && !utf8.RuneStart(body[cut]) {
+		cut--
+	}
+	return string(body[:cut]) + "...(truncated)"
+}
diff --git a/backend/internal/pkg/websearch/helpers_test.go b/backend/internal/pkg/websearch/helpers_test.go
new file mode 100644
index 00000000..e3164329
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers_test.go
@@ -0,0 +1,25 @@
+package websearch
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTruncateBody_Short(t *testing.T) {
+ body := []byte("short body")
+ require.Equal(t, "short body", truncateBody(body))
+}
+
+func TestTruncateBody_Long(t *testing.T) {
+ body := []byte(strings.Repeat("x", 500))
+ result := truncateBody(body)
+ require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
+ require.True(t, strings.HasSuffix(result, "...(truncated)"))
+}
+
+func TestTruncateBody_ExactBoundary(t *testing.T) {
+ body := []byte(strings.Repeat("x", errorBodyTruncLen))
+ require.Equal(t, string(body), truncateBody(body))
+}
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
new file mode 100644
index 00000000..95da70e4
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager.go
@@ -0,0 +1,273 @@
+package websearch
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
+// Quota refresh interval constants.
+const (
+ QuotaRefreshDaily = "daily"
+ QuotaRefreshWeekly = "weekly"
+ QuotaRefreshMonthly = "monthly"
+)
+
+// ProviderConfig holds the configuration for a single search provider.
+type ProviderConfig struct {
+ Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
+ APIKey string `json:"api_key"` // secret
+ Priority int `json:"priority"` // lower = higher priority
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
+ ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
+}
+
+// Manager selects providers by priority and tracks quota via Redis.
+type Manager struct {
+ configs []ProviderConfig
+ redis *redis.Client
+
+ clientMu sync.Mutex
+ clientCache map[string]*http.Client
+}
+
+const (
+ quotaKeyPrefix = "websearch:quota:"
+ searchRequestTimeout = 30 * time.Second
+ quotaTTLBuffer = 24 * time.Hour
+ maxCachedClients = 100
+)
+
+// quotaIncrScript atomically increments the counter and sets TTL on first creation.
+// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
+// Returns the new counter value.
+var quotaIncrScript = redis.NewScript(`
+local val = redis.call('INCR', KEYS[1])
+if val == 1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+else
+ -- Defensive: ensure TTL exists even if a prior EXPIRE failed
+ local ttl = redis.call('TTL', KEYS[1])
+ if ttl == -1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+ end
+end
+return val
+`)
+
+// NewManager creates a Manager with the given provider configs and Redis client.
+func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
+ sorted := make([]ProviderConfig, len(configs))
+ copy(sorted, configs)
+ sort.Slice(sorted, func(i, j int) bool {
+ return sorted[i].Priority < sorted[j].Priority
+ })
+ return &Manager{
+ configs: sorted,
+ redis: redisClient,
+ clientCache: make(map[string]*http.Client),
+ }
+}
+
+// SearchWithBestProvider selects the highest-priority available provider,
+// reserves quota, executes the search, and rolls back quota on failure.
+func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+ if strings.TrimSpace(req.Query) == "" {
+ return nil, "", fmt.Errorf("websearch: empty search query")
+ }
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ allowed, incremented := m.tryReserveQuota(ctx, cfg)
+ if !allowed {
+ continue
+ }
+ resp, err := m.executeSearch(ctx, cfg, req)
+ if err != nil {
+ if incremented {
+ m.rollbackQuota(ctx, cfg)
+ }
+ slog.Warn("websearch: provider search failed",
+ "provider", cfg.Type, "error", err)
+ continue
+ }
+ return resp, cfg.Type, nil
+ }
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
+}
+
+func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
+ if cfg.APIKey == "" {
+ return false
+ }
+ if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
+ slog.Info("websearch: provider expired, skipping",
+ "provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
+ return false
+ }
+ return true
+}
+
+// tryReserveQuota atomically increments the counter via Lua script and checks limit.
+// Returns (allowed, incremented): allowed=true means the request may proceed;
+// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
+func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
+ if cfg.QuotaLimit <= 0 {
+ return true, false // unlimited, no INCR
+ }
+ if m.redis == nil {
+ slog.Warn("websearch: Redis unavailable, quota check skipped",
+ "provider", cfg.Type)
+ return true, false // allowed but not incremented
+ }
+ key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
+ ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
+
+ newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
+ if err != nil {
+ slog.Warn("websearch: quota Lua INCR failed, allowing request",
+ "provider", cfg.Type, "error", err)
+ return true, false // allowed but not incremented
+ }
+ if newVal > cfg.QuotaLimit {
+ if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
+ slog.Warn("websearch: quota over-limit DECR failed",
+ "provider", cfg.Type, "error", decrErr)
+ }
+ slog.Info("websearch: provider quota exhausted",
+ "provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
+ return false, false // rejected, already rolled back
+ }
+ return true, true // allowed and incremented
+}
+
+// rollbackQuota decrements the counter after a search failure.
+func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
+ if cfg.QuotaLimit <= 0 || m.redis == nil {
+ return
+ }
+ key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
+ if err := m.redis.Decr(ctx, key).Err(); err != nil {
+ slog.Warn("websearch: quota rollback DECR failed",
+ "provider", cfg.Type, "error", err)
+ }
+}
+
+func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
+ proxyURL := cfg.ProxyURL
+ if req.ProxyURL != "" {
+ proxyURL = req.ProxyURL
+ }
+ client := m.getOrCreateHTTPClient(proxyURL)
+ provider := m.buildProvider(cfg, client)
+ return provider.Search(ctx, req)
+}
+
+// GetUsage returns the current usage count for the given provider.
+func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
+ if m.redis == nil {
+ return 0, nil
+ }
+ key := quotaRedisKey(providerType, refreshInterval)
+ val, err := m.redis.Get(ctx, key).Int64()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ return val, err
+}
+
+// GetAllUsage returns usage for every configured provider.
+func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
+ result := make(map[string]int64, len(m.configs))
+ for _, cfg := range m.configs {
+ used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
+ result[cfg.Type] = used
+ }
+ return result
+}
+
+// --- HTTP client cache (bounded) ---
+
+func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
+ m.clientMu.Lock()
+ defer m.clientMu.Unlock()
+
+ if c, ok := m.clientCache[proxyURL]; ok {
+ return c
+ }
+ if len(m.clientCache) >= maxCachedClients {
+ m.clientCache = make(map[string]*http.Client) // evict all
+ }
+ c := newHTTPClient(proxyURL)
+ m.clientCache[proxyURL] = c
+ return c
+}
+
+func newHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{
+ TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
+ }
+ if proxyURL != "" {
+ if u, err := url.Parse(proxyURL); err == nil {
+ transport.Proxy = http.ProxyURL(u)
+ }
+ }
+ return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
+}
+
+// --- Provider factory ---
+
+func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
+ switch cfg.Type {
+ case braveProviderName:
+ return NewBraveProvider(cfg.APIKey, client)
+ case tavilyProviderName:
+ return NewTavilyProvider(cfg.APIKey, client)
+ default:
+ slog.Warn("websearch: unknown provider type, falling back to brave",
+ "type", cfg.Type)
+ return NewBraveProvider(cfg.APIKey, client)
+ }
+}
+
+// --- Redis key helpers ---
+
+func quotaRedisKey(providerType, refreshInterval string) string {
+ return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
+}
+
+func periodKey(refreshInterval string) string {
+ now := time.Now().UTC()
+ switch refreshInterval {
+ case QuotaRefreshDaily:
+ return now.Format("2006-01-02")
+ case QuotaRefreshWeekly:
+ year, week := now.ISOWeek()
+ return fmt.Sprintf("%d-W%02d", year, week)
+ default: // QuotaRefreshMonthly
+ return now.Format("2006-01")
+ }
+}
+
+func quotaTTL(refreshInterval string) time.Duration {
+ switch refreshInterval {
+ case QuotaRefreshDaily:
+ return 24*time.Hour + quotaTTLBuffer
+ case QuotaRefreshWeekly:
+ return 7*24*time.Hour + quotaTTLBuffer
+ default:
+ return 31*24*time.Hour + quotaTTLBuffer
+ }
+}
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
new file mode 100644
index 00000000..4387a2ee
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -0,0 +1,149 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewManager_SortsByPriority(t *testing.T) {
+ configs := []ProviderConfig{
+ {Type: "brave", APIKey: "k3", Priority: 30},
+ {Type: "tavily", APIKey: "k1", Priority: 10},
+ }
+ m := NewManager(configs, nil)
+ require.Equal(t, 10, m.configs[0].Priority)
+ require.Equal(t, 30, m.configs[1].Priority)
+}
+
+func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
+ require.ErrorContains(t, err, "empty search query")
+
+ _, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
+ require.ErrorContains(t, err, "empty search query")
+}
+
+func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
+ m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
+ past := time.Now().Add(-1 * time.Hour).Unix()
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", ExpiresAt: &past},
+ }, nil)
+ _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "no available provider")
+}
+
+func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
+ // Create two mock servers that return different results
+ srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srvBrave.Close()
+
+ // Override brave endpoint for test
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srvBrave.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k1", Priority: 1},
+ {Type: "tavily", APIKey: "k2", Priority: 2},
+ }, nil)
+ // Inject the test server's client
+ m.clientCache[srvBrave.URL] = srvBrave.Client()
+ m.clientCache[""] = srvBrave.Client()
+
+ resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Equal(t, "brave", providerName)
+ require.Len(t, resp.Results, 1)
+ require.Equal(t, "from brave", resp.Results[0].Snippet)
+}
+
+func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
+ // With nil Redis, quota check is skipped (always allowed)
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
+ }, nil) // nil Redis
+ m.clientCache[""] = srv.Client()
+
+ resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 1)
+}
+
+func TestManager_GetUsage_NilRedis(t *testing.T) {
+ m := NewManager(nil, nil)
+ used, err := m.GetUsage(context.Background(), "brave", "monthly")
+ require.NoError(t, err)
+ require.Equal(t, int64(0), used)
+}
+
+func TestManager_GetAllUsage_NilRedis(t *testing.T) {
+ m := NewManager([]ProviderConfig{
+ {Type: "brave", QuotaRefreshInterval: "monthly"},
+ }, nil)
+ usage := m.GetAllUsage(context.Background())
+ require.Equal(t, int64(0), usage["brave"])
+}
+
+// --- Key/TTL helpers ---
+
+func TestQuotaTTL_Daily(t *testing.T) {
+ require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
+}
+
+func TestQuotaTTL_Weekly(t *testing.T) {
+ require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
+}
+
+func TestQuotaTTL_Monthly(t *testing.T) {
+ require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
+}
+
+func TestPeriodKey_Daily(t *testing.T) {
+ key := periodKey(QuotaRefreshDaily)
+ require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
+}
+
+func TestPeriodKey_Weekly(t *testing.T) {
+ key := periodKey(QuotaRefreshWeekly)
+ require.Regexp(t, `^\d{4}-W\d{2}$`, key)
+}
+
+func TestPeriodKey_Monthly(t *testing.T) {
+ key := periodKey(QuotaRefreshMonthly)
+ require.Regexp(t, `^\d{4}-\d{2}$`, key)
+}
+
+func TestQuotaRedisKey_Format(t *testing.T) {
+ key := quotaRedisKey("brave", QuotaRefreshDaily)
+ require.Contains(t, key, "websearch:quota:brave:")
+}
diff --git a/backend/internal/pkg/websearch/provider.go b/backend/internal/pkg/websearch/provider.go
new file mode 100644
index 00000000..3424c056
--- /dev/null
+++ b/backend/internal/pkg/websearch/provider.go
@@ -0,0 +1,11 @@
+// Package websearch provides pluggable web-search backends (Brave, Tavily)
+// used by the gateway to emulate Anthropic's web_search tool for accounts
+// without native support.
+package websearch
+
+import "context"
+
+// Provider is the interface every search backend must implement.
+type Provider interface {
+	// Name returns the provider identifier ("brave" or "tavily").
+	Name() string
+	// Search executes a web search and returns results.
+	Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
+}
diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go
new file mode 100644
index 00000000..6ac09edf
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily.go
@@ -0,0 +1,107 @@
+package websearch
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+const (
+	// tavilySearchEndpoint is Tavily's fixed REST search endpoint.
+	tavilySearchEndpoint = "https://api.tavily.com/search"
+	tavilyProviderName   = "tavily"
+	// tavilySearchDepthBasic selects Tavily's cheaper "basic" search depth.
+	tavilySearchDepthBasic = "basic"
+)
+
+// TavilyProvider implements web search via the Tavily Search API.
+type TavilyProvider struct {
+	apiKey     string
+	httpClient *http.Client
+}
+
+// NewTavilyProvider creates a Tavily Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+// NOTE(review): the nil fallback is http.DefaultClient, which has no Timeout —
+// with a nil client only the request context bounds a slow upstream; confirm
+// all call sites pass a configured client.
+func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
+	if httpClient == nil {
+		httpClient = http.DefaultClient
+	}
+	return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+// Name returns the provider identifier "tavily".
+func (t *TavilyProvider) Name() string { return tavilyProviderName }
+
+// Search executes one Tavily search and maps the response into the
+// provider-agnostic SearchResponse. The API key travels in the JSON body
+// (api_key field), not in a header. Errors are wrapped with a "tavily:"
+// prefix so the manager can attribute failures to this provider.
+func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+	// Fall back to the package default when the caller gives no/invalid limit.
+	maxResults := req.MaxResults
+	if maxResults <= 0 {
+		maxResults = defaultMaxResults
+	}
+
+	payload := tavilyRequest{
+		APIKey:      t.apiKey,
+		Query:       req.Query,
+		MaxResults:  maxResults,
+		SearchDepth: tavilySearchDepthBasic,
+	}
+
+	bodyBytes, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("tavily: encode request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
+	if err != nil {
+		return nil, fmt.Errorf("tavily: build request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	resp, err := t.httpClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("tavily: request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Cap the read at maxResponseSize to bound memory on a misbehaving server.
+	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+	if err != nil {
+		return nil, fmt.Errorf("tavily: read body: %w", err)
+	}
+
+	// Non-200 responses surface a truncated body excerpt for diagnostics.
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
+	}
+
+	var raw tavilyResponse
+	if err := json.Unmarshal(body, &raw); err != nil {
+		return nil, fmt.Errorf("tavily: decode response: %w", err)
+	}
+
+	// Map wire results to the neutral SearchResult shape; Tavily's "content"
+	// becomes the snippet. (PageAge is left empty — Tavily does not supply it.)
+	results := make([]SearchResult, 0, len(raw.Results))
+	for _, r := range raw.Results {
+		results = append(results, SearchResult{
+			URL:     r.URL,
+			Title:   r.Title,
+			Snippet: r.Content,
+		})
+	}
+
+	return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+// tavilyRequest is the JSON body POSTed to /search. Note the API key is
+// carried in the body ("api_key"), not in an Authorization header.
+type tavilyRequest struct {
+	APIKey      string `json:"api_key"`
+	Query       string `json:"query"`
+	MaxResults  int    `json:"max_results"`
+	SearchDepth string `json:"search_depth"`
+}
+
+// tavilyResponse mirrors the subset of Tavily's response the gateway consumes.
+type tavilyResponse struct {
+	Results []tavilyResult `json:"results"`
+}
+
+// tavilyResult is one search hit. Score is decoded but not used by Search's
+// result mapping.
+type tavilyResult struct {
+	URL     string  `json:"url"`
+	Title   string  `json:"title"`
+	Content string  `json:"content"`
+	Score   float64 `json:"score"`
+}
diff --git a/backend/internal/pkg/websearch/tavily_test.go b/backend/internal/pkg/websearch/tavily_test.go
new file mode 100644
index 00000000..e1b6819a
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily_test.go
@@ -0,0 +1,63 @@
+// Tavily wire-format tests: request construction and response parsing only.
+// No live HTTP calls are made — the provider's Search path over HTTP is
+// covered by the manager tests with a stub server.
+
+package websearch
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestTavilyProvider_Name(t *testing.T) {
+	p := NewTavilyProvider("key", nil)
+	require.Equal(t, "tavily", p.Name())
+}
+
+func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
+	// Verify tavilyRequest struct fields map correctly
+	req := tavilyRequest{
+		APIKey:      "test-key",
+		Query:       "golang",
+		MaxResults:  3,
+		SearchDepth: tavilySearchDepthBasic,
+	}
+	data, err := json.Marshal(req)
+	require.NoError(t, err)
+
+	var parsed map[string]any
+	require.NoError(t, json.Unmarshal(data, &parsed))
+	require.Equal(t, "test-key", parsed["api_key"])
+	require.Equal(t, "golang", parsed["query"])
+	// JSON numbers decode into map[string]any as float64.
+	require.Equal(t, float64(3), parsed["max_results"])
+	require.Equal(t, "basic", parsed["search_depth"])
+}
+
+func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
+	rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
+	var resp tavilyResponse
+	require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
+	require.Len(t, resp.Results, 1)
+	require.Equal(t, "https://go.dev", resp.Results[0].URL)
+	require.Equal(t, "Go programming language", resp.Results[0].Content)
+	require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
+
+	// Verify mapping to SearchResult (content → Snippet, PageAge unset).
+	results := make([]SearchResult, 0, len(resp.Results))
+	for _, r := range resp.Results {
+		results = append(results, SearchResult{
+			URL: r.URL, Title: r.Title, Snippet: r.Content,
+		})
+	}
+	require.Equal(t, "Go programming language", results[0].Snippet)
+	require.Equal(t, "", results[0].PageAge)
+}
+
+func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
+	var resp tavilyResponse
+	require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
+	require.Empty(t, resp.Results)
+}
+
+func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
+	var resp tavilyResponse
+	require.Error(t, json.Unmarshal([]byte("not json"), &resp))
+}
diff --git a/backend/internal/pkg/websearch/types.go b/backend/internal/pkg/websearch/types.go
new file mode 100644
index 00000000..bb489690
--- /dev/null
+++ b/backend/internal/pkg/websearch/types.go
@@ -0,0 +1,30 @@
+package websearch
+
+// SearchResult represents a single web search result.
+type SearchResult struct {
+	URL     string `json:"url"`
+	Title   string `json:"title"`
+	Snippet string `json:"snippet"`
+	// PageAge is provider-dependent and may be empty (Tavily does not set it).
+	PageAge string `json:"page_age,omitempty"`
+}
+
+// SearchRequest describes a web search to perform.
+type SearchRequest struct {
+	Query      string
+	MaxResults int    // defaults to defaultMaxResults if <= 0
+	ProxyURL   string // optional HTTP proxy URL
+}
+
+// SearchResponse holds the results of a web search.
+type SearchResponse struct {
+	Results []SearchResult
+	Query   string // the query that was actually executed
+}
+
+// defaultMaxResults is applied when SearchRequest.MaxResults is unset (<= 0).
+const defaultMaxResults = 5
+
+// Provider type identifiers.
+const (
+	ProviderTypeBrave  = "brave"
+	ProviderTypeTavily = "tavily"
+)
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
index baad31f7..56b5cc71 100644
--- a/backend/internal/repository/channel_repo.go
+++ b/backend/internal/repository/channel_repo.go
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
err = tx.QueryRowContext(ctx,
- `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at, updated_at`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -73,11 +77,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
- var modelMappingJSON []byte
+ var modelMappingJSON, featuresConfigJSON []byte
err := r.db.QueryRowContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
FROM channels WHERE id = $1`, id,
- ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
@@ -85,6 +89,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -107,10 +112,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil {
return err
}
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
result, err := tx.ExecContext(ctx,
- `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
- WHERE id = $8`,
- channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
+ WHERE id = $9`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -187,9 +196,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
- `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
- FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
- whereClause, argIdx, argIdx+1,
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
+ FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
+ whereClause, channelListOrderBy(params), argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
@@ -203,11 +212,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -246,9 +256,34 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
return channels, paginationResult, nil
}
+// channelListOrderBy builds the ORDER BY clause for the channel list query.
+// sort_by is matched against a fixed whitelist (id / name / status /
+// created_at); an empty or unrecognized value falls back to "c.id ASC".
+// Because only whitelisted column names reach the fmt.Sprintf below, no
+// user-controlled SQL is ever interpolated. c.id is always appended as a
+// tie-breaker so pagination stays stable when the primary column has
+// duplicate values.
+func channelListOrderBy(params pagination.PaginationParams) string {
+	sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+	sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
+
+	var column string
+	switch sortBy {
+	case "":
+		column = "c.id"
+		sortOrder = "ASC"
+	case "id":
+		column = "c.id"
+	case "name":
+		column = "c.name"
+	case "status":
+		column = "c.status"
+	case "created_at":
+		column = "c.created_at"
+	default:
+		column = "c.id"
+		sortOrder = "ASC"
+	}
+
+	return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
+}
+
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
- `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -259,11 +294,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
- var modelMappingJSON []byte
- if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -431,6 +467,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m
}
+// marshalFeaturesConfig serializes a channel's features_config map for the
+// JSONB column. A nil/empty map is stored as "{}" rather than NULL so the
+// column always holds valid JSON.
+func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
+	if len(m) == 0 {
+		return []byte("{}"), nil
+	}
+	data, err := json.Marshal(m)
+	if err != nil {
+		return nil, fmt.Errorf("marshal features_config: %w", err)
+	}
+	return data, nil
+}
+
+// unmarshalFeaturesConfig decodes the features_config JSONB column.
+// Empty input and malformed JSON both yield nil: a corrupt config degrades
+// to "no features" instead of failing the row scan (same best-effort policy
+// as unmarshalModelMapping).
+func unmarshalFeaturesConfig(data []byte) map[string]any {
+	if len(data) == 0 {
+		return nil
+	}
+	var m map[string]any
+	if err := json.Unmarshal(data, &m); err != nil {
+		return nil
+	}
+	return m
+}
+
// GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index b921da95..7c4e6cb7 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -407,6 +407,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
+ // Web Search 模拟配置
+ adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
+ adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
}
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 512195e3..582b136c 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
-// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
+// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
@@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
-// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
+// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
-// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
+// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
-// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
+// IsWebSearchEmulationEnabled reports whether web search emulation is enabled
+// for an Anthropic API Key account.
+// Field: accounts.extra.web_search_emulation.
+// A missing field or a non-bool value is treated as false (disabled).
+func (a *Account) IsWebSearchEmulationEnabled() bool {
+	// Only Anthropic API Key accounts may opt in; nil accounts, other
+	// platforms/types, and nil Extra maps are always false.
+	if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
+		return false
+	}
+	enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
+	return ok && enabled
+}
+
+// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go
new file mode 100644
index 00000000..fe742ebf
--- /dev/null
+++ b/backend/internal/service/account_websearch_test.go
@@ -0,0 +1,71 @@
+// Tests for Account.IsWebSearchEmulationEnabled: the flag is honored only for
+// Anthropic API Key accounts with a bool true in extra; every other shape
+// (missing key, wrong type, nil Extra, nil account, other platform/type)
+// must read as disabled.
+
+package service
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
+	a := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeAPIKey,
+		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
+	}
+	require.True(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
+	a := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeAPIKey,
+		Extra:    map[string]any{featureKeyWebSearchEmulation: false},
+	}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
+	a := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeAPIKey,
+		Extra:    map[string]any{},
+	}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
+	// A string "true" must not be coerced to a bool.
+	a := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeAPIKey,
+		Extra:    map[string]any{featureKeyWebSearchEmulation: "true"},
+	}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
+	a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
+	var a *Account
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
+	a := &Account{
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeAPIKey,
+		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
+	}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
+
+func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
+	a := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeOAuth,
+		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
+	}
+	require.False(t, a.IsWebSearchEmulationEnabled())
+}
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index eac81444..baf5c839 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -49,6 +49,21 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射(按平台分组:platform → {src→dst})
ModelMapping map[string]map[string]string
+ // 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
+ FeaturesConfig map[string]any
+}
+
+// IsWebSearchEmulationEnabled reports whether this channel enables web search
+// emulation for the given platform.
+// Expected shape: FeaturesConfig["web_search_emulation"] = {platform: bool},
+// e.g. {"web_search_emulation": {"anthropic": true}}. Any missing key or
+// wrong type along the path is treated as disabled.
+func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
+	if c == nil || c.FeaturesConfig == nil {
+		return false
+	}
+	wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
+	if !ok {
+		return false
+	}
+	enabled, ok := wse[platform].(bool)
+	return ok && enabled
+}
// ChannelModelPricing 渠道模型定价条目
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index cdf94a4c..7b28662b 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
-// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
-// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
-// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
-// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
+// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
+// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
-// antigravity 平台同时服务 Claude 和 Gemini 模型。
-// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
+// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
@@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
+// storeErrorCache installs an empty cache with a short TTL so a DB failure
+// does not trigger a tight rebuild loop. The remaining TTL is made equal to
+// channelErrorTTL by back-dating loadedAt.
+func (s *ChannelService) storeErrorCache() {
+	errorCache := newEmptyChannelCache()
+	errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
+	s.cache.Store(errorCache)
+}
+
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
- // 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
- channels, err := s.repo.ListAll(dbCtx)
+ channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
- // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
- slog.Warn("failed to build channel cache", "error", err)
- errorCache := newEmptyChannelCache()
- errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
- s.cache.Store(errorCache)
- return nil, fmt.Errorf("list all channels: %w", err)
+ return nil, err
+ }
+
+ cache := populateChannelCache(channels, groupPlatforms)
+ s.cache.Store(cache)
+ return cache, nil
+}
+
+// fetchChannelData 从数据库加载渠道列表和分组平台映射。
+func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
+ channels, err := s.repo.ListAll(ctx)
+ if err != nil {
+ slog.Warn("failed to build channel cache", "error", err)
+ s.storeErrorCache()
+ return nil, nil, fmt.Errorf("list all channels: %w", err)
}
- // 收集所有 groupID,批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
+
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
- groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
+ groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
- errorCache := newEmptyChannelCache()
- errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
- s.cache.Store(errorCache)
- return nil, fmt.Errorf("get group platforms: %w", err)
+ s.storeErrorCache()
+ return nil, nil, fmt.Errorf("get group platforms: %w", err)
}
}
+ return channels, groupPlatforms, nil
+}
+// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
+func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
@@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
-
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
@@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
- // 通配符条目保持配置顺序(最先匹配到优先)
-
- s.cache.Store(cache)
- return cache, nil
+ return cache
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
-// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
-// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
+// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
- if groupPlatform == pricingPlatform {
- return true
- }
- if groupPlatform == PlatformAntigravity {
- return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
- }
- return false
+ return groupPlatform == pricingPlatform
}
-// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
+// matchingPlatforms 返回分组平台对应的可匹配平台列表。
+// 各平台严格独立,只返回自身。
func matchingPlatforms(groupPlatform string) []string {
- if groupPlatform == PlatformAntigravity {
- return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
- }
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
@@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
-// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
-// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
-// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
-// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
+// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
+// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
@@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return nil
}
-// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
+// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
@@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
-// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
-// 确保跨平台同名模型各自独立匹配。
+// 各平台严格独立,只在本平台内查找定价。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
@@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
- lk, _ := s.lookupGroupChannel(ctx, groupID)
+ lk, err := s.lookupGroupChannel(ctx, groupID)
+ if err != nil {
+ slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
+ }
if lk == nil {
return false
}
@@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
-// antigravity 分组依次尝试所有匹配平台的定价列表。
+// 只在本平台的定价列表中查找。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
@@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
+// validateChannelConfig validates a channel's pricing and mapping
+// configuration: model-conflict detection, interval validation, mapping
+// conflict detection, then billing-mode checks. Shared by Create and Update
+// to avoid duplicating the validation chain. The first failing check wins.
+func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
+	if err := validateNoConflictingModels(pricing); err != nil {
+		return err
+	}
+	if err := validatePricingIntervals(pricing); err != nil {
+		return err
+	}
+	if err := validateNoConflictingMappings(mapping); err != nil {
+		return err
+	}
+	return validatePricingBillingMode(pricing)
+}
+
+// validatePricingBillingMode checks billing-mode constraints for every
+// pricing entry: per-request/image modes must configure a price or intervals,
+// no price field may be negative, and every interval must set at least one
+// price field. Returns the first violation found.
+func validatePricingBillingMode(pricing []ChannelModelPricing) error {
+	for _, p := range pricing {
+		if err := checkBillingModeRequirements(p); err != nil {
+			return err
+		}
+		if err := checkPricesNotNegative(p); err != nil {
+			return err
+		}
+		if err := checkIntervalsHavePrices(p); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// checkBillingModeRequirements ensures per_request/image billing modes carry
+// a price source: either a per-request price or at least one interval.
+// Other billing modes are not constrained here.
+func checkBillingModeRequirements(p ChannelModelPricing) error {
+	if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
+		if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
+			return infraerrors.BadRequest(
+				"BILLING_MODE_MISSING_PRICE",
+				"per-request price or intervals required for per_request/image billing mode",
+			)
+		}
+	}
+	return nil
+}
+
+// checkPricesNotNegative rejects a pricing entry if any of its optional price
+// fields is set to a negative value. Nil (unset) fields are allowed; the
+// field name is echoed in the error for the API caller.
+func checkPricesNotNegative(p ChannelModelPricing) error {
+	checks := []struct {
+		field string
+		val   *float64
+	}{
+		{"input_price", p.InputPrice},
+		{"output_price", p.OutputPrice},
+		{"cache_write_price", p.CacheWritePrice},
+		{"cache_read_price", p.CacheReadPrice},
+		{"image_output_price", p.ImageOutputPrice},
+		{"per_request_price", p.PerRequestPrice},
+	}
+	for _, c := range checks {
+		if c.val != nil && *c.val < 0 {
+			return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
+		}
+	}
+	return nil
+}
+
+// checkIntervalsHavePrices requires every token interval on the entry to set
+// at least one price field; an interval with all five fields nil is rejected.
+// The error message includes the interval bounds ("∞" for an unbounded max)
+// and the affected models.
+func checkIntervalsHavePrices(p ChannelModelPricing) error {
+	for _, iv := range p.Intervals {
+		if iv.InputPrice == nil && iv.OutputPrice == nil &&
+			iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
+			iv.PerRequestPrice == nil {
+			return infraerrors.BadRequest(
+				"INTERVAL_MISSING_PRICE",
+				fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
+					iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
+			)
+		}
+	}
+	return nil
+}
+
+// formatMaxTokens renders an interval's upper bound for error messages,
+// using "∞" for an unbounded (nil) maximum.
+// The parameter is named maxTokens rather than max to avoid shadowing the
+// Go 1.21 predeclared builtin max.
+func formatMaxTokens(maxTokens *int) string {
+	if maxTokens == nil {
+		return "∞"
+	}
+	return fmt.Sprintf("%d", *maxTokens)
+}
+
// --- CRUD ---
// Create 创建渠道
@@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
- // 检查分组冲突
- if len(input.GroupIDs) > 0 {
- conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
- if err != nil {
- return nil, fmt.Errorf("check group conflicts: %w", err)
- }
- if len(conflicting) > 0 {
- return nil, ErrGroupAlreadyInChannel
- }
+ if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
+ return nil, err
}
channel := &Channel{
@@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Features: input.Features,
+ FeaturesConfig: input.FeaturesConfig,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
}
- if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validatePricingIntervals(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
+ if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
@@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
- if input.Name != "" && input.Name != channel.Name {
- exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
- if err != nil {
- return nil, fmt.Errorf("check channel exists: %w", err)
- }
- if exists {
- return nil, ErrChannelExists
- }
- channel.Name = input.Name
- }
-
- if input.Description != nil {
- channel.Description = *input.Description
- }
-
- if input.Status != "" {
- channel.Status = input.Status
- }
-
- if input.RestrictModels != nil {
- channel.RestrictModels = *input.RestrictModels
- }
- if input.Features != nil {
- channel.Features = *input.Features
- }
-
- // 检查分组冲突
- if input.GroupIDs != nil {
- conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
- if err != nil {
- return nil, fmt.Errorf("check group conflicts: %w", err)
- }
- if len(conflicting) > 0 {
- return nil, ErrGroupAlreadyInChannel
- }
- channel.GroupIDs = *input.GroupIDs
- }
-
- if input.ModelPricing != nil {
- channel.ModelPricing = *input.ModelPricing
- }
-
- if input.ModelMapping != nil {
- channel.ModelMapping = input.ModelMapping
- }
-
- if input.BillingModelSource != "" {
- channel.BillingModelSource = input.BillingModelSource
- }
-
- if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validatePricingIntervals(channel.ModelPricing); err != nil {
- return nil, err
- }
- if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
+ if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
- // 先获取旧分组,Update 后旧分组关联已删除,无法再查到
- var oldGroupIDs []int64
- if s.authCacheInvalidator != nil {
- var err2 error
- oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
- if err2 != nil {
- slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
- }
+ if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
+ return nil, err
}
+ oldGroupIDs := s.getOldGroupIDs(ctx, id)
+
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
-
- // 失效新旧分组的 auth 缓存
- if s.authCacheInvalidator != nil {
- seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
- for _, gid := range oldGroupIDs {
- if _, ok := seen[gid]; !ok {
- seen[gid] = struct{}{}
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
- for _, gid := range channel.GroupIDs {
- if _, ok := seen[gid]; !ok {
- seen[gid] = struct{}{}
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
- }
+ s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
// applyUpdateInput applies the fields of an update request onto the channel
// entity. Fields follow "nil / empty string means unchanged" semantics: only
// explicitly provided values are written.
// Returns ErrChannelExists when renaming to a name already used by another
// channel, or the error from checkGroupConflicts when reassigned groups are
// already bound elsewhere.
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
	// Renaming requires a uniqueness check against all other channels.
	if input.Name != "" && input.Name != channel.Name {
		exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
		if err != nil {
			return fmt.Errorf("check channel exists: %w", err)
		}
		if exists {
			return ErrChannelExists
		}
		channel.Name = input.Name
	}
	if input.Description != nil {
		channel.Description = *input.Description
	}
	if input.Status != "" {
		channel.Status = input.Status
	}
	if input.RestrictModels != nil {
		channel.RestrictModels = *input.RestrictModels
	}
	if input.Features != nil {
		channel.Features = *input.Features
	}
	// Group reassignment is rejected when any group belongs to another channel.
	if input.GroupIDs != nil {
		if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
			return err
		}
		channel.GroupIDs = *input.GroupIDs
	}
	if input.ModelPricing != nil {
		channel.ModelPricing = *input.ModelPricing
	}
	if input.ModelMapping != nil {
		channel.ModelMapping = input.ModelMapping
	}
	if input.BillingModelSource != "" {
		channel.BillingModelSource = input.BillingModelSource
	}
	if input.FeaturesConfig != nil {
		channel.FeaturesConfig = input.FeaturesConfig
	}
	return nil
}
+
+// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
+// channelID 为当前渠道 ID(Create 时传 0)。
+func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
+ if len(groupIDs) == 0 {
+ return nil
+ }
+ conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
+ if err != nil {
+ return fmt.Errorf("check group conflicts: %w", err)
+ }
+ if len(conflicting) > 0 {
+ return ErrGroupAlreadyInChannel
+ }
+ return nil
+}
+
+// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
+func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
+ if s.authCacheInvalidator == nil {
+ return nil
+ }
+ oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
+ if err != nil {
+ slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
+ }
+ return oldGroupIDs
+}
+
+// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
+func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
+ if s.authCacheInvalidator == nil {
+ return
+ }
+ seen := make(map[int64]struct{})
+ for _, ids := range groupIDSets {
+ for _, gid := range ids {
+ if _, ok := seen[gid]; ok {
+ continue
+ }
+ seen[gid] = struct{}{}
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
+ }
+ }
+}
+
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
- // 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
@@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
-
- if s.authCacheInvalidator != nil {
- for _, gid := range groupIDs {
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
- }
- }
+ s.invalidateAuthCacheForGroups(ctx, groupIDs)
return nil
}
@@ -847,6 +930,7 @@ type CreateChannelInput struct {
BillingModelSource string
RestrictModels bool
Features string
+ FeaturesConfig map[string]any
}
// UpdateChannelInput 更新渠道输入
@@ -860,4 +944,5 @@ type UpdateChannelInput struct {
BillingModelSource string
RestrictModels *bool
Features *string
+ FeaturesConfig map[string]any
}
diff --git a/backend/internal/service/channel_websearch_test.go b/backend/internal/service/channel_websearch_test.go
new file mode 100644
index 00000000..d3dbe45d
--- /dev/null
+++ b/backend/internal/service/channel_websearch_test.go
@@ -0,0 +1,62 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
// Verifies the toggle reports enabled when FeaturesConfig marks the platform true.
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
+
// Verifies enabling one platform does not enable others.
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}
+
// Verifies an explicit false value disables the feature.
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
+
// Verifies a nil FeaturesConfig defaults to disabled.
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
	c := &Channel{FeaturesConfig: nil}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
+
// Verifies the method is nil-receiver-safe and reports disabled.
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
	var c *Channel
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
+
// Verifies a malformed config value (non-map) is treated as disabled.
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: true, // not a map
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
+
// Verifies a non-bool platform value is treated as disabled.
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 68d7da3b..f43d388b 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -249,6 +249,10 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
+
+ // Web Search Emulation
+ // SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
+ SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 5d285fb6..77e9b8c8 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
+ // Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
+ if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
+ return s.handleWebSearchEmulation(ctx, c, account, parsed)
+ }
+
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model
diff --git a/backend/internal/service/gateway_websearch_emulation.go b/backend/internal/service/gateway_websearch_emulation.go
new file mode 100644
index 00000000..fbea96c0
--- /dev/null
+++ b/backend/internal/service/gateway_websearch_emulation.go
@@ -0,0 +1,358 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "github.com/tidwall/gjson"
+)
+
// Web search emulation constants
const (
	// Tool identifiers accepted as "web search" in a request's tools array.
	toolTypeWebSearchPrefix = "web_search" // matched as a prefix of the tool "type" field
	toolTypeGoogleSearch    = "google_search"
	toolNameWebSearch       = "web_search"
	toolNameGoogleSearch    = "google_search"
	toolNameWebSearch2025   = "web_search_20250305"

	// webSearchDefaultMaxResults caps how many results are requested per search.
	webSearchDefaultMaxResults = 5
	// defaultWebSearchModel is reported in synthetic responses when the
	// request carries no model field.
	defaultWebSearchModel = "claude-sonnet-4-6"
	// Prefixes for synthetic message and server-tool-use IDs.
	webSearchMsgIDPrefix     = "msg_ws_"
	webSearchToolUseIDPrefix = "srvtoolu_ws_"
	// tokenEstimateDivisor approximates output tokens as characters / 4.
	tokenEstimateDivisor = 4

	// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
	featureKeyWebSearchEmulation = "web_search_emulation"
)
+
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
// A nil value means the feature is not wired up and emulation must be skipped.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
+
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
// May be called at any time to hot-swap the active manager; pass nil to disable.
func SetWebSearchManager(m *websearch.Manager) {
	webSearchManagerPtr.Store(m)
}
+
// getWebSearchManager returns the currently wired manager, or nil when web
// search emulation is not configured.
func getWebSearchManager() *websearch.Manager {
	return webSearchManagerPtr.Load()
}
+
+// shouldEmulateWebSearch checks whether a request should be intercepted.
+//
+// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
+// Note: channel-level control is enforced via the account's extra field; the channel toggle
+// in the admin UI sets the account's flag for all accounts in that channel's groups.
+func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
+ if getWebSearchManager() == nil {
+ return false
+ }
+ if !isOnlyWebSearchToolInBody(body) {
+ return false
+ }
+ if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
+ return false
+ }
+ if !account.IsWebSearchEmulationEnabled() {
+ return false
+ }
+ return true
+}
+
+// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
+func isOnlyWebSearchToolInBody(body []byte) bool {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return false
+ }
+ arr := tools.Array()
+ if len(arr) != 1 {
+ return false
+ }
+ return isWebSearchToolJSON(arr[0])
+}
+
+func isWebSearchToolJSON(tool gjson.Result) bool {
+ toolType := tool.Get("type").String()
+ if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
+ return true
+ }
+ switch tool.Get("name").String() {
+ case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
+ return true
+ }
+ return false
+}
+
+// extractSearchQueryFromBody extracts the last user message text as the search query.
+func extractSearchQueryFromBody(body []byte) string {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return ""
+ }
+ arr := messages.Array()
+ if len(arr) == 0 {
+ return ""
+ }
+ lastMsg := arr[len(arr)-1]
+ if lastMsg.Get("role").String() != "user" {
+ return ""
+ }
+ return extractWebSearchTextFromContent(lastMsg.Get("content"))
+}
+
+func extractWebSearchTextFromContent(content gjson.Result) string {
+ if content.Type == gjson.String {
+ return content.String()
+ }
+ if content.IsArray() {
+ for _, block := range content.Array() {
+ if block.Get("type").String() == "text" {
+ if text := block.Get("text").String(); text != "" {
+ return text
+ }
+ }
+ }
+ }
+ return ""
+}
+
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response.
//
// The upstream account is never contacted: the result is synthesized locally
// and returned with an empty ClaudeUsage.
func (s *GatewayService) handleWebSearchEmulation(
	ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()

	// Release the serial queue lock immediately — we don't need upstream.
	if parsed.OnUpstreamAccepted != nil {
		parsed.OnUpstreamAccepted()
	}

	// The query is the text of the last user message.
	query := extractSearchQueryFromBody(parsed.Body)
	if query == "" {
		return nil, fmt.Errorf("web search emulation: no query found in messages")
	}

	slog.Info("web search emulation: executing search",
		"account_id", account.ID, "account_name", account.Name, "query", query)

	resp, providerName, err := doWebSearch(ctx, account, query)
	if err != nil {
		return nil, err
	}

	slog.Info("web search emulation: search completed",
		"provider", providerName, "results_count", len(resp.Results))

	// Fall back to a default model name when the request omitted one.
	model := parsed.Model
	if model == "" {
		model = defaultWebSearchModel
	}

	if parsed.Stream {
		return writeWebSearchStreamResponse(c, query, resp, model, startTime)
	}
	return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
}
+
+func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
+ proxyURL := resolveAccountProxyURL(account)
+ mgr := getWebSearchManager()
+ if mgr == nil {
+ return nil, "", fmt.Errorf("web search emulation: manager not initialized")
+ }
+ resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
+ Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
+ })
+ if err != nil {
+ slog.Error("web search emulation: search failed", "error", err)
+ return nil, "", fmt.Errorf("web search emulation: %w", err)
+ }
+ return resp, providerName, nil
+}
+
+func resolveAccountProxyURL(account *Account) string {
+ if account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+}
+
+// --- SSE streaming response ---
+
// writeWebSearchStreamResponse renders the emulated result as an Anthropic
// SSE stream: message_start, a server_tool_use block, a web_search_tool_result
// block, a text summary block, then message_delta/message_stop.
func writeWebSearchStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	// UUID strings are 36 chars, so the [:16] slice is always in range.
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]

	setSSEHeaders(c)
	// NOTE(review): headers and the 200 status are committed above, so a
	// marshal failure here leaves the client with an empty 200 stream.
	if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
		return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
	}
	writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
	writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
	textSummary := buildTextSummary(query, resp.Results)
	writeSSETextBlock(c.Writer, textSummary, 2)
	// Rough token estimate: ~4 characters per token.
	writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
	c.Writer.Flush()

	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
+
+func setSSEHeaders(c *gin.Context) {
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.WriteHeader(http.StatusOK)
+}
+
+func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
+ evt := map[string]any{
+ "type": "message_start",
+ "message": map[string]any{
+ "id": msgID, "type": "message", "role": "assistant", "model": model,
+ "content": []any{}, "stop_reason": nil, "stop_sequence": nil,
+ "usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
+ },
+ }
+ return flushSSEJSON(w, "message_start", evt)
+}
+
+func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
+ start := map[string]any{
+ "type": "content_block_start", "index": index,
+ "content_block": map[string]any{
+ "type": "server_tool_use", "id": toolUseID,
+ "name": toolNameWebSearch, "input": map[string]string{"query": query},
+ },
+ }
+ _ = flushSSEJSON(w, "content_block_start", start)
+ _ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
+}
+
+func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
+ start := map[string]any{
+ "type": "content_block_start", "index": index,
+ "content_block": map[string]any{
+ "type": "web_search_tool_result", "tool_use_id": toolUseID,
+ "content": buildSearchResultBlocks(results),
+ },
+ }
+ _ = flushSSEJSON(w, "content_block_start", start)
+ _ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
+}
+
+func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
+ _ = flushSSEJSON(w, "content_block_start", map[string]any{
+ "type": "content_block_start", "index": index,
+ "content_block": map[string]any{"type": "text", "text": ""},
+ })
+ _ = flushSSEJSON(w, "content_block_delta", map[string]any{
+ "type": "content_block_delta", "index": index,
+ "delta": map[string]string{"type": "text_delta", "text": text},
+ })
+ _ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
+}
+
+func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
+ _ = flushSSEJSON(w, "message_delta", map[string]any{
+ "type": "message_delta",
+ "delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
+ "usage": map[string]int{"output_tokens": outputTokens},
+ })
+ _ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
+}
+
// flushSSEJSON encodes data as JSON and writes it as one SSE event
// ("event: <name>\ndata: <json>\n\n"), flushing the writer when it supports
// http.Flusher. Returns an error on marshal OR write failure — the original
// silently discarded Fprintf errors (e.g. client disconnects).
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
	b, err := json.Marshal(data)
	if err != nil {
		slog.Error("web search emulation: failed to marshal SSE event",
			"event", event, "error", err)
		return err
	}
	if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}
+
+// --- Non-streaming JSON response ---
+
// writeWebSearchNonStreamResponse renders the emulated result as a single
// Anthropic messages JSON object containing a server_tool_use block, a
// web_search_tool_result block, and a text summary.
func writeWebSearchNonStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	// UUID strings are 36 chars, so the [:16] slice is always in range.
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	textSummary := buildTextSummary(query, resp.Results)

	msg := map[string]any{
		"id": msgID, "type": "message", "role": "assistant", "model": model,
		"content": []any{
			map[string]any{
				"type": "server_tool_use", "id": toolUseID,
				"name": toolNameWebSearch, "input": map[string]string{"query": query},
			},
			map[string]any{
				"type": "web_search_tool_result", "tool_use_id": toolUseID,
				"content": buildSearchResultBlocks(resp.Results),
			},
			map[string]any{"type": "text", "text": textSummary},
		},
		"stop_reason": "end_turn", "stop_sequence": nil,
		// Rough token estimate: ~4 characters per token.
		"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
	}

	body, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
	}
	c.Data(http.StatusOK, "application/json", body)

	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
+
+// --- Helpers ---
+
+func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
+ blocks := make([]map[string]string, 0, len(results))
+ for _, r := range results {
+ block := map[string]string{
+ "type": "web_search_result",
+ "url": r.URL,
+ "title": r.Title,
+ }
+ if r.Snippet != "" {
+ block["page_content"] = r.Snippet
+ }
+ if r.PageAge != "" {
+ block["page_age"] = r.PageAge
+ }
+ blocks = append(blocks, block)
+ }
+ return blocks
+}
+
+func buildTextSummary(query string, results []websearch.SearchResult) string {
+ if len(results) == 0 {
+ return "No search results found for: " + query
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
+ for i, r := range results {
+ fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
+ }
+ return sb.String()
+}
diff --git a/backend/internal/service/gateway_websearch_emulation_test.go b/backend/internal/service/gateway_websearch_emulation_test.go
new file mode 100644
index 00000000..b606c748
--- /dev/null
+++ b/backend/internal/service/gateway_websearch_emulation_test.go
@@ -0,0 +1,142 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+ "github.com/stretchr/testify/require"
+)
+
+// --- isOnlyWebSearchToolInBody ---
+
// Verifies a single tool with type "web_search" is detected.
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}
+
// Verifies the versioned type "web_search_20250305" matches via the prefix rule.
func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}
+
// Verifies type "google_search" is accepted as a web search tool.
func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}
+
// Verifies detection by tool "name" when no "type" is present.
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}
+
// Verifies the versioned name "web_search_20250305" is accepted.
func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}
+
// Verifies the name "google_search" is accepted.
func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}
+
// Verifies requests with additional tools are NOT intercepted.
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody(
		[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}
+
// Verifies a body without a tools field is rejected.
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}
+
// Verifies an empty tools array is rejected.
func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}
+
// Verifies a single non-web-search tool is rejected.
func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}
+
// Verifies a non-array tools value is rejected.
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}
+
+// --- extractSearchQueryFromBody ---
+
// Verifies string-form message content is used directly as the query.
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"what is golang"}]}`
	require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}
+
// Verifies the text block is extracted from array-form content.
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
	require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}
+
// Verifies only the LAST user message is used as the query.
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
	require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}
+
// Verifies no query is extracted when the final message is not from the user.
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
+
// Verifies an empty messages array yields no query.
func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}
+
// Verifies a body with no messages field yields no query.
func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}
+
// Verifies non-text and empty-text blocks are skipped in favor of the first
// non-empty text block.
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
	require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}
+
// Verifies array content without any text block yields no query.
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
+
+// --- buildSearchResultBlocks ---
+
// Verifies full and partial results map to web_search_result blocks, with
// page_age included only when present.
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
		{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
	}
	blocks := buildSearchResultBlocks(results)
	require.Len(t, blocks, 2)
	require.Equal(t, "web_search_result", blocks[0]["type"])
	require.Equal(t, "https://a.com", blocks[0]["url"])
	require.Equal(t, "snippet a", blocks[0]["page_content"])
	require.Equal(t, "2 days", blocks[0]["page_age"])
	// Second result has no PageAge
	require.Equal(t, "https://b.com", blocks[1]["url"])
	_, hasPageAge := blocks[1]["page_age"]
	require.False(t, hasPageAge)
}
+
// Verifies nil input produces an empty block list.
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
	blocks := buildSearchResultBlocks(nil)
	require.Empty(t, blocks)
}
+
// Verifies page_content is omitted when the snippet is empty.
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
	blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
	_, hasContent := blocks[0]["page_content"]
	require.False(t, hasContent)
}
+
+// --- buildTextSummary ---
+
// Verifies the summary includes the query, a numbered bolded title, and the URL.
func TestBuildTextSummary_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "desc a"},
	}
	summary := buildTextSummary("test query", results)
	require.Contains(t, summary, "test query")
	require.Contains(t, summary, "1. **A**")
	require.Contains(t, summary, "https://a.com")
}
+
// Verifies the empty-results fallback message is produced.
func TestBuildTextSummary_NoResults(t *testing.T) {
	summary := buildTextSummary("test", nil)
	require.Contains(t, summary, "No search results found for: test")
}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 48f25da0..3cfe5e56 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
+ "github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -106,6 +107,7 @@ type SettingService struct {
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
+ webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
}
// NewSettingService 创建系统设置服务实例
@@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
+ // Web search emulation: quick enabled check from the JSON config
+ if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
+ var wsCfg WebSearchEmulationConfig
+ if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
+ result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
+ }
+ }
+
return result
}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index de92b796..f5535bca 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -106,6 +106,9 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+
+ // Web Search Emulation (read-only quick check; full config via dedicated API)
+ WebSearchEmulationEnabled bool
}
type DefaultSubscriptionSetting struct {
diff --git a/backend/internal/service/websearch_config.go b/backend/internal/service/websearch_config.go
new file mode 100644
index 00000000..15ec1f9d
--- /dev/null
+++ b/backend/internal/service/websearch_config.go
@@ -0,0 +1,253 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "sync/atomic"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
+ "github.com/redis/go-redis/v9"
+ "golang.org/x/sync/singleflight"
+)
+
+// WebSearchEmulationConfig holds the global web search emulation configuration.
+// It is persisted as a single JSON value in the settings table; see
+// GetWebSearchEmulationConfig for the cached read path and
+// SaveWebSearchEmulationConfig for the validated write path.
+type WebSearchEmulationConfig struct {
+ Enabled bool `json:"enabled"`
+ Providers []WebSearchProviderConfig `json:"providers"`
+}
+
+// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
+// APIKey is effectively write-only at the API boundary: SanitizeWebSearchConfig
+// blanks it and sets APIKeyConfigured before responses leave the server.
+type WebSearchProviderConfig struct {
+ Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
+ APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
+ APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
+ Priority int `json:"priority"` // lower = higher priority
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
+ QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current period usage
+ ProxyID *int64 `json:"proxy_id"` // optional proxy association
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
+}
+
+// --- Validation ---
+
+// maxWebSearchProviders bounds the provider list so the stored JSON stays small.
+const maxWebSearchProviders = 10
+
+// validProviderTypes enumerates the provider implementations available in
+// pkg/websearch; any other type is rejected by validateWebSearchConfig.
+var validProviderTypes = map[string]bool{
+ websearch.ProviderTypeBrave: true,
+ websearch.ProviderTypeTavily: true,
+}
+
+// validQuotaIntervals lists the accepted quota refresh periods; the empty
+// string is accepted and treated as monthly.
+var validQuotaIntervals = map[string]bool{
+ websearch.QuotaRefreshDaily: true,
+ websearch.QuotaRefreshWeekly: true,
+ websearch.QuotaRefreshMonthly: true,
+ "": true, // defaults to monthly
+}
+
+// validateWebSearchConfig checks the structural invariants of an incoming
+// config: a bounded provider count, known provider types and quota refresh
+// intervals, non-negative quota limits, and at most one provider per type.
+// A nil config is treated as valid (nothing to check).
+func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
+ if cfg == nil {
+ return nil
+ }
+ if len(cfg.Providers) > maxWebSearchProviders {
+ return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
+ }
+ // Track provider types seen so far to reject duplicates.
+ seen := make(map[string]bool, len(cfg.Providers))
+ for i, p := range cfg.Providers {
+ if !validProviderTypes[p.Type] {
+ return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
+ }
+ if !validQuotaIntervals[p.QuotaRefreshInterval] {
+ return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval)
+ }
+ if p.QuotaLimit < 0 {
+ return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
+ }
+ if seen[p.Type] {
+ return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type)
+ }
+ seen[p.Type] = true
+ }
+ return nil
+}
+
+// --- In-process cache (same pattern as gateway forwarding settings) ---
+
+// sfKeyWebSearchConfig is the singleflight key that collapses concurrent DB loads.
+const sfKeyWebSearchConfig = "web_search_emulation_config"
+
+// cachedWebSearchEmulationConfig is the immutable entry held in
+// webSearchEmulationCache; every refresh stores a brand-new entry.
+type cachedWebSearchEmulationConfig struct {
+ config *WebSearchEmulationConfig
+ expiresAt int64 // unix nano
+}
+
+// Package-level cache state shared by all SettingService instances.
+var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
+var webSearchEmulationSF singleflight.Group
+
+const (
+ webSearchEmulationCacheTTL = 60 * time.Second // TTL after a successful load
+ webSearchEmulationErrorTTL = 5 * time.Second // shorter TTL after a failed load
+ webSearchEmulationDBTimeout = 5 * time.Second // upper bound for the settings read
+)
+
+// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
+// Fast path: serve the cached value while it is fresh. Slow path: collapse
+// concurrent misses into a single DB load. On error a usable zero-value
+// config is returned alongside the error so callers can degrade gracefully.
+//
+// NOTE(review): ctx is accepted for signature symmetry but the load path uses
+// its own timeout context (see loadWebSearchConfigFromDB) — confirm intended.
+func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
+ if cached := webSearchEmulationCache.Load(); cached != nil {
+ c := cached.(*cachedWebSearchEmulationConfig)
+ if time.Now().UnixNano() < c.expiresAt {
+ return c.config, nil
+ }
+ }
+ result, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
+ return s.loadWebSearchConfigFromDB()
+ })
+ if err != nil {
+ return &WebSearchEmulationConfig{}, err
+ }
+ return result.(*WebSearchEmulationConfig), nil
+}
+
+// loadWebSearchConfigFromDB reads the raw JSON setting, parses it, and
+// refreshes the in-process cache. It uses context.Background with its own
+// timeout so one caller's cancellation cannot abort the shared load.
+func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
+ dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
+ defer cancel()
+
+ raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
+ if err != nil {
+ // Cache an empty config with a short TTL so a failing DB is not hammered.
+ webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
+ config: &WebSearchEmulationConfig{},
+ expiresAt: time.Now().Add(webSearchEmulationErrorTTL).UnixNano(),
+ })
+ return &WebSearchEmulationConfig{}, err
+ }
+ cfg := parseWebSearchConfigJSON(raw)
+ webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
+ config: cfg,
+ expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
+ })
+ return cfg, nil
+}
+
+// parseWebSearchConfigJSON decodes the stored JSON value. It never returns
+// nil: an empty string yields a zero-value config, and malformed JSON is
+// logged and likewise degrades to the zero value.
+func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
+ cfg := &WebSearchEmulationConfig{}
+ if raw == "" {
+ return cfg
+ }
+ if err := json.Unmarshal([]byte(raw), cfg); err != nil {
+ slog.Warn("websearch: failed to parse config JSON", "error", err)
+ return &WebSearchEmulationConfig{}
+ }
+ return cfg
+}
+
+// SaveWebSearchEmulationConfig validates and persists the configuration.
+// Empty API keys in the input are preserved from the existing config, so the
+// admin UI can resubmit a sanitized (key-masked) config without wiping secrets.
+// On success the in-process cache is refreshed and the global websearch
+// manager is rebuilt.
+func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
+ if err := validateWebSearchConfig(cfg); err != nil {
+ return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
+ }
+ s.mergeExistingAPIKeys(ctx, cfg)
+
+ data, err := json.Marshal(cfg)
+ if err != nil {
+ return fmt.Errorf("websearch: marshal config: %w", err)
+ }
+ if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
+ return fmt.Errorf("websearch: save config: %w", err)
+ }
+ // Invalidate: forget singleflight first, then store new value
+ // NOTE(review): a singleflight load that started before Set could still
+ // Store a stale entry after this Store; it self-heals when the cache TTL
+ // expires — confirm the window is acceptable.
+ webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
+ webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
+ config: cfg,
+ expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
+ })
+
+ // Hot-reload: rebuild the global Manager with new config
+ s.RebuildWebSearchManager(ctx)
+ return nil
+}
+
+// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
+// Best-effort: if the stored config cannot be read the incoming values are
+// saved as-is. Keys are matched by provider type, which validation guarantees
+// is unique within a config.
+func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
+ existing, _ := s.getWebSearchEmulationConfigRaw(ctx)
+ if existing == nil || cfg == nil {
+ return
+ }
+ // Index the currently stored (non-empty) keys by provider type.
+ existingByType := make(map[string]string, len(existing.Providers))
+ for _, p := range existing.Providers {
+ if p.APIKey != "" {
+ existingByType[p.Type] = p.APIKey
+ }
+ }
+ // Fill only blanks; a non-empty incoming key always wins.
+ for i := range cfg.Providers {
+ if cfg.Providers[i].APIKey == "" {
+ if key, ok := existingByType[cfg.Providers[i].Type]; ok {
+ cfg.Providers[i].APIKey = key
+ }
+ }
+ }
+}
+
+// getWebSearchEmulationConfigRaw reads the config directly from the settings
+// repo, bypassing the in-process cache. The result contains plaintext API
+// keys and must never be returned to API clients unsanitized.
+func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
+ if err != nil {
+ return nil, err
+ }
+ return parseWebSearchConfigJSON(raw), nil
+}
+
+// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
+// It requires both the enabled flag and at least one configured provider;
+// any load error is treated as disabled (fail closed).
+func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
+ cfg, err := s.GetWebSearchEmulationConfig(ctx)
+ if err != nil {
+ return false
+ }
+ return cfg.Enabled && len(cfg.Providers) > 0
+}
+
+// SetWebSearchRedisClient injects the Redis client used for quota tracking.
+// Call after construction, before first use. Triggers initial Manager build.
+// NOTE(review): s.webSearchRedis is assigned without synchronization — safe
+// only if this runs once during startup before concurrent use; confirm.
+func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
+ s.webSearchRedis = redisClient
+ s.RebuildWebSearchManager(ctx)
+}
+
+// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
+// Called on startup and after SaveWebSearchEmulationConfig. When emulation is
+// disabled, has no providers, or the config cannot be loaded, the global
+// manager is cleared instead of rebuilt.
+func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
+ cfg, err := s.GetWebSearchEmulationConfig(ctx)
+ if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+ SetWebSearchManager(nil)
+ return
+ }
+ // Translate service-level provider configs into pkg/websearch configs.
+ // NOTE(review): ProxyID is not forwarded here — confirm proxy resolution
+ // happens elsewhere (e.g. inside the websearch providers).
+ providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
+ for _, p := range cfg.Providers {
+ providerConfigs = append(providerConfigs, websearch.ProviderConfig{
+ Type: p.Type,
+ APIKey: p.APIKey,
+ Priority: p.Priority,
+ QuotaLimit: p.QuotaLimit,
+ QuotaRefreshInterval: p.QuotaRefreshInterval,
+ ExpiresAt: p.ExpiresAt,
+ })
+ }
+ SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
+ slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
+}
+
+// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
+// APIKeyConfigured is set to reflect whether a secret is stored, and the
+// original config is left unmodified (providers are copied into a fresh slice).
+// Returns nil for a nil input.
+func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
+ if cfg == nil {
+ return nil
+ }
+ out := *cfg
+ out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
+ for i, p := range cfg.Providers {
+ out.Providers[i] = p
+ out.Providers[i].APIKeyConfigured = p.APIKey != ""
+ out.Providers[i].APIKey = "" // never return the secret
+ }
+ return &out
+}
diff --git a/backend/internal/service/websearch_config_test.go b/backend/internal/service/websearch_config_test.go
new file mode 100644
index 00000000..1a19dd9d
--- /dev/null
+++ b/backend/internal/service/websearch_config_test.go
@@ -0,0 +1,148 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// --- validateWebSearchConfig ---
+// One test per validation path: nil config, fully valid config, and each
+// rejection (provider count, type, interval, negative quota, duplicate type),
+// plus the accepted edge cases (empty interval, zero quota limit).
+
+func TestValidateWebSearchConfig_Nil(t *testing.T) {
+ require.NoError(t, validateWebSearchConfig(nil))
+}
+
+func TestValidateWebSearchConfig_Valid(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Enabled: true,
+ Providers: []WebSearchProviderConfig{
+ {Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
+ {Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
+ },
+ }
+ require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
+ // 11 providers exceeds maxWebSearchProviders (10).
+ cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
+ for i := range cfg.Providers {
+ cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
+ }
+ err := validateWebSearchConfig(cfg)
+ require.ErrorContains(t, err, "too many providers")
+}
+
+func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "bing"}},
+ }
+ require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
+}
+
+func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
+ }
+ require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
+}
+
+func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
+ }
+ require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
+}
+
+func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{
+ {Type: "brave", Priority: 1},
+ {Type: "brave", Priority: 2},
+ },
+ }
+ require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
+}
+
+// Empty interval is accepted (defaults to monthly downstream).
+func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
+ }
+ require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+// Zero quota limit is accepted (means unlimited).
+func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
+ }
+ require.NoError(t, validateWebSearchConfig(cfg))
+}
+
+// --- parseWebSearchConfigJSON ---
+// Covers the three parse outcomes: valid JSON, empty input, malformed input.
+// The latter two must both degrade to a usable zero-value config.
+
+func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
+ raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
+ cfg := parseWebSearchConfigJSON(raw)
+ require.True(t, cfg.Enabled)
+ require.Len(t, cfg.Providers, 1)
+ require.Equal(t, "brave", cfg.Providers[0].Type)
+}
+
+func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
+ cfg := parseWebSearchConfigJSON("")
+ require.False(t, cfg.Enabled)
+ require.Empty(t, cfg.Providers)
+}
+
+func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
+ cfg := parseWebSearchConfigJSON("not{json")
+ require.False(t, cfg.Enabled)
+ require.Empty(t, cfg.Providers)
+}
+
+// --- SanitizeWebSearchConfig ---
+// Verifies key masking, the APIKeyConfigured flag, nil handling, field
+// preservation, and that the input config is never mutated.
+
+func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Enabled: true,
+ Providers: []WebSearchProviderConfig{
+ {Type: "brave", APIKey: "sk-secret-xxx"},
+ },
+ }
+ out := SanitizeWebSearchConfig(cfg)
+ require.Equal(t, "", out.Providers[0].APIKey)
+ require.True(t, out.Providers[0].APIKeyConfigured)
+}
+
+func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
+ }
+ out := SanitizeWebSearchConfig(cfg)
+ require.Equal(t, "", out.Providers[0].APIKey)
+ require.False(t, out.Providers[0].APIKeyConfigured)
+}
+
+func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
+ require.Nil(t, SanitizeWebSearchConfig(nil))
+}
+
+func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Enabled: true,
+ Providers: []WebSearchProviderConfig{
+ {Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
+ },
+ }
+ out := SanitizeWebSearchConfig(cfg)
+ require.True(t, out.Enabled)
+ require.Equal(t, 10, out.Providers[0].Priority)
+ require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
+}
+
+// The sanitized copy must not alias the original's provider slice.
+func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
+ cfg := &WebSearchEmulationConfig{
+ Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
+ }
+ _ = SanitizeWebSearchConfig(cfg)
+ require.Equal(t, "secret", cfg.Providers[0].APIKey)
+}
diff --git a/backend/migrations/101_add_channel_features_config.sql b/backend/migrations/101_add_channel_features_config.sql
new file mode 100644
index 00000000..b054b085
--- /dev/null
+++ b/backend/migrations/101_add_channel_features_config.sql
@@ -0,0 +1,2 @@
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS features_config JSONB NOT NULL DEFAULT '{}';
+COMMENT ON COLUMN channels.features_config IS '渠道特性配置(如 web_search_emulation),JSON 对象格式';
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index b3455022..d49982aa 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -41,6 +41,7 @@ export interface Channel {
status: string
billing_model_source: string // "requested" | "upstream"
restrict_models: boolean
+ features_config?: Record
group_ids: number[]
model_pricing: ChannelModelPricing[]
model_mapping: Record> // platform → {src→dst}
@@ -56,6 +57,7 @@ export interface CreateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
+ features_config?: Record
}
export interface UpdateChannelRequest {
@@ -67,6 +69,7 @@ export interface UpdateChannelRequest {
model_mapping?: Record>
billing_model_source?: string
restrict_models?: boolean
+ features_config?: Record
}
interface PaginatedResponse {
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 504abe9c..7fc6c852 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -482,6 +482,42 @@ export async function updateBetaPolicySettings(
return data
}
+// --- Web Search Emulation Config ---
+
+export interface WebSearchProviderConfig {
+ type: 'brave' | 'tavily'
+ api_key: string
+ api_key_configured: boolean
+ priority: number
+ quota_limit: number
+ quota_refresh_interval: 'daily' | 'weekly' | 'monthly'
+ quota_used?: number
+ proxy_id: number | null
+ expires_at: number | null
+}
+
+export interface WebSearchEmulationConfig {
+ enabled: boolean
+ providers: WebSearchProviderConfig[]
+}
+
+export async function getWebSearchEmulationConfig(): Promise {
+ const { data } = await apiClient.get(
+ '/admin/settings/web-search-emulation'
+ )
+ return data
+}
+
+export async function updateWebSearchEmulationConfig(
+ config: WebSearchEmulationConfig
+): Promise {
+ const { data } = await apiClient.put(
+ '/admin/settings/web-search-emulation',
+ config
+ )
+ return data
+}
+
export const settingsAPI = {
getSettings,
updateSettings,
@@ -497,7 +533,9 @@ export const settingsAPI = {
getRectifierSettings,
updateRectifierSettings,
getBetaPolicySettings,
- updateBetaPolicySettings
+ updateBetaPolicySettings,
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig
}
export default settingsAPI
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 380201c4..d5d8ff89 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2325,6 +2325,22 @@
+
+
+
+
+
+
+ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
+
+
+
+
+
+
(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
+const webSearchEmulationEnabled = ref(false)
const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling
const allowOverages = ref(false) // For antigravity accounts: enable AI Credits overages
const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream
@@ -3307,6 +3325,7 @@ watch(
}
if (newPlatform !== 'anthropic') {
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
}
// Reset OAuth states
oauth.resetState()
@@ -3326,6 +3345,7 @@ watch(
}
if (platform !== 'anthropic' || category !== 'apikey') {
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
}
}
)
@@ -3690,6 +3710,7 @@ const resetForm = () => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
// Reset quota control state
windowCostEnabled.value = false
windowCostLimit.value = null
@@ -3777,6 +3798,11 @@ const buildAnthropicExtra = (base?: Record): Record 0 ? extra : undefined
}
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 8b72d6d1..a67366fc 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -1149,10 +1149,61 @@
-
-
+
+
+
+
+
+
+ {{ t('admin.accounts.anthropic.webSearchEmulationDesc') }}
+
+
+
+
+
+
+
+
-
{{ t('admin.accounts.quotaLimit') }}
+
{{ t('admin.accounts.quotaControl.title') }}
+
+ {{ t('admin.accounts.quotaControl.hint') }}
+
+
+
+
+
+
+
+
{{ t('admin.accounts.quotaControl.title') }}
{{ t('admin.accounts.quotaLimitHint') }}
@@ -1237,7 +1288,7 @@
-
+
(OPENAI_WS_MODE_OFF
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
const anthropicPassthroughEnabled = ref(false)
+const webSearchEmulationEnabled = ref(false)
const editQuotaLimit = ref(null)
const editQuotaDailyLimit = ref(null)
const editQuotaWeeklyLimit = ref(null)
@@ -2067,6 +2120,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
anthropicPassthroughEnabled.value = false
+ webSearchEmulationEnabled.value = false
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
@@ -2087,6 +2141,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
+ webSearchEmulationEnabled.value = extra?.web_search_emulation === true
}
// Load quota limit for apikey/bedrock accounts (bedrock quota is also loaded in its own branch above)
@@ -2522,8 +2577,13 @@ function loadQuotaControlSettings(account: Account) {
customBaseUrlEnabled.value = false
customBaseUrl.value = ''
- // Only applies to Anthropic OAuth/SetupToken accounts
- if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
+ // Remaining quota control settings only apply to Anthropic accounts
+ if (account.platform !== 'anthropic') {
+ return
+ }
+
+ // Window cost / session limit only apply to Anthropic OAuth/SetupToken accounts
+ if (account.type !== 'oauth' && account.type !== 'setup-token') {
return
}
@@ -2949,7 +3009,7 @@ const handleSubmit = async () => {
// For Anthropic OAuth/SetupToken accounts, handle quota control settings in extra
if (props.account.platform === 'anthropic' && (props.account.type === 'oauth' || props.account.type === 'setup-token')) {
- const currentExtra = (props.account.extra as Record) || {}
+ const currentExtra = (updatePayload.extra as Record) || (props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
// Window cost limit settings
@@ -3037,15 +3097,20 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra
}
- // For Anthropic API Key accounts, handle passthrough mode in extra
+ // For Anthropic API Key accounts, handle passthrough mode + web search emulation in extra
if (props.account.platform === 'anthropic' && props.account.type === 'apikey') {
- const currentExtra = (props.account.extra as Record) || {}
+ const currentExtra = (updatePayload.extra as Record) || (props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
if (anthropicPassthroughEnabled.value) {
newExtra.anthropic_passthrough = true
} else {
delete newExtra.anthropic_passthrough
}
+ if (webSearchEmulationEnabled.value) {
+ newExtra.web_search_emulation = true
+ } else {
+ delete newExtra.web_search_emulation
+ }
updatePayload.extra = newExtra
}
@@ -3089,20 +3154,27 @@ const handleSubmit = async () => {
const currentExtra = (updatePayload.extra as Record) ||
(props.account.extra as Record) || {}
const newExtra: Record = { ...currentExtra }
+ // Total quota
if (editQuotaLimit.value != null && editQuotaLimit.value > 0) {
newExtra.quota_limit = editQuotaLimit.value
} else {
delete newExtra.quota_limit
}
+ // Daily quota
if (editQuotaDailyLimit.value != null && editQuotaDailyLimit.value > 0) {
newExtra.quota_daily_limit = editQuotaDailyLimit.value
} else {
delete newExtra.quota_daily_limit
+ delete newExtra.quota_daily_used
+ delete newExtra.quota_daily_start
}
+ // Weekly quota
if (editQuotaWeeklyLimit.value != null && editQuotaWeeklyLimit.value > 0) {
newExtra.quota_weekly_limit = editQuotaWeeklyLimit.value
} else {
delete newExtra.quota_weekly_limit
+ delete newExtra.quota_weekly_used
+ delete newExtra.quota_weekly_start
}
// Quota reset mode config
if (editDailyResetMode.value === 'fixed') {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index d6c87d52..99f8d535 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1836,6 +1836,9 @@ export default {
defaultPerRequestPrice: 'Default per-request price (fallback when no tier matches)',
defaultImagePrice: 'Default image price (fallback when no tier matches)',
platformConfig: 'Platform Configuration',
+ webSearchEmulation: 'Web Search Emulation',
+ webSearchEmulationHint: '⚠️ When enabled, all accounts in this channel\'s Anthropic groups will intercept web_search requests. Use with caution.',
+ webSearchEmulationGlobalDisabled: 'Please enable the global switch first in Settings → Gateway → Web Search Emulation',
basicSettings: 'Basic Settings',
addPlatform: 'Add Platform',
noPlatforms: 'Click "Add Platform" to start configuring the channel',
@@ -2325,7 +2328,10 @@ export default {
anthropic: {
apiKeyPassthrough: 'Auto passthrough (auth only)',
apiKeyPassthroughDesc:
- 'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.'
+ 'Only applies to Anthropic API Key accounts. When enabled, messages/count_tokens are forwarded in passthrough mode with auth replacement only, while billing/concurrency/audit and safety filtering are preserved. Disable to roll back immediately.',
+ webSearchEmulation: 'Web Search Emulation',
+ webSearchEmulationDesc:
+ 'Enable web search emulation for this API Key account. When a pure web_search request is detected, the gateway calls a third-party search API and constructs the response locally.',
},
modelRestriction: 'Model Restriction (Optional)',
modelWhitelist: 'Model Whitelist',
@@ -4358,6 +4364,31 @@ export default {
cchSigning: 'CCH Signing',
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
},
+ webSearchEmulation: {
+ title: 'Web Search Emulation',
+ description: 'Inject web search capability for Anthropic API Key accounts that don\'t natively support it',
+ enabled: 'Enable Web Search Emulation',
+ enabledHint: 'Global switch. When disabled, web search emulation is inactive for all channels and accounts.',
+ providers: 'Search Providers',
+ addProvider: 'Add Provider',
+ providerType: 'Provider Type',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: 'Enter API Key',
+ apiKeyConfigured: 'Configured',
+ priority: 'Priority',
+ priorityHint: 'Lower number = higher priority',
+ quotaLimit: 'Quota Limit',
+ quotaLimitHint: '0 = unlimited',
+ quotaRefreshInterval: 'Refresh Interval',
+ quotaUsed: 'Used',
+ proxy: 'Proxy',
+ expiresAt: 'Expires At',
+ removeProvider: 'Remove',
+ daily: 'Daily',
+ weekly: 'Weekly',
+ monthly: 'Monthly',
+ noProviders: 'No search providers configured',
+ },
site: {
title: 'Site Settings',
description: 'Customize site branding',
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 2038970a..7ef7ead0 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -1915,6 +1915,9 @@ export default {
defaultPerRequestPrice: '默认单次价格(未命中层级时使用)',
defaultImagePrice: '默认图片价格(未命中层级时使用)',
platformConfig: '平台配置',
+ webSearchEmulation: 'Web Search 模拟',
+ webSearchEmulationHint: '⚠️ 开启后该渠道下所有 Anthropic 分组的账号将自动拦截 web_search 请求,请谨慎操作',
+ webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关',
basicSettings: '基础设置',
addPlatform: '添加平台',
noPlatforms: '点击"添加平台"开始配置渠道',
@@ -2472,7 +2475,10 @@ export default {
anthropic: {
apiKeyPassthrough: '自动透传(仅替换认证)',
apiKeyPassthroughDesc:
- '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。'
+ '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。',
+ webSearchEmulation: 'Web Search 模拟',
+ webSearchEmulationDesc:
+ '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。',
},
modelRestriction: '模型限制(可选)',
modelWhitelist: '模型白名单',
@@ -4520,6 +4526,31 @@ export default {
cchSigning: 'CCH 签名',
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
},
+ webSearchEmulation: {
+ title: 'Web Search 模拟',
+ description: '为不原生支持搜索的 Anthropic API Key 账号注入 web search 能力',
+ enabled: '启用 Web Search 模拟',
+ enabledHint: '全局开关。关闭后所有渠道和账号的 web search 模拟均不生效。',
+ providers: '搜索服务商',
+ addProvider: '添加服务商',
+ providerType: '服务商类型',
+ apiKey: 'API Key',
+ apiKeyPlaceholder: '输入 API Key',
+ apiKeyConfigured: '已配置',
+ priority: '优先级',
+ priorityHint: '数值越小优先级越高',
+ quotaLimit: '配额上限',
+ quotaLimitHint: '0 表示无限制',
+ quotaRefreshInterval: '刷新周期',
+ quotaUsed: '已使用',
+ proxy: '代理',
+ expiresAt: '过期时间',
+ removeProvider: '删除',
+ daily: '每日',
+ weekly: '每周',
+ monthly: '每月',
+ noProviders: '未配置搜索服务商',
+ },
site: {
title: '站点设置',
description: '自定义站点品牌',
diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue
index 5c2f153b..ce8d3c9c 100644
--- a/frontend/src/views/admin/ChannelsView.vue
+++ b/frontend/src/views/admin/ChannelsView.vue
@@ -306,6 +306,24 @@
+
+
+
+
+
+
+ {{ t('admin.channels.form.webSearchEmulationHint') }}
+
+
+ {{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }}
+
+
+
+
+
+
@@ -423,6 +441,7 @@
import { ref, reactive, computed, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
+import { extractApiErrorMessage } from '@/utils/apiError'
import { adminAPI } from '@/api/admin'
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
import type { PricingFormEntry } from '@/components/admin/channel/types'
@@ -446,6 +465,18 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize'
const { t } = useI18n()
const appStore = useAppStore()
+// Web Search global enabled state (loaded once on mount)
+const webSearchGlobalEnabled = ref(false)
+async function loadWebSearchGlobalState() {
+ try {
+ const cfg = await adminAPI.settings.getWebSearchEmulationConfig()
+ webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0
+ } catch (err: unknown) {
+ console.warn('Failed to load web search global state:', err)
+ webSearchGlobalEnabled.value = false
+ }
+}
+
// ── Platform Section type ──
interface PlatformSection {
platform: GroupPlatform
@@ -454,6 +485,7 @@ interface PlatformSection {
group_ids: number[]
model_mapping: Record
model_pricing: PricingFormEntry[]
+ web_search_emulation: boolean
}
// ── Table columns ──
@@ -565,7 +597,8 @@ function addPlatformSection(platform: GroupPlatform) {
collapsed: false,
group_ids: [],
model_mapping: {},
- model_pricing: []
+ model_pricing: [],
+ web_search_emulation: false,
})
}
@@ -679,10 +712,14 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
}
// ── Form ↔ API conversion ──
-function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } {
+function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } {
const group_ids: number[] = []
const model_pricing: ChannelModelPricing[] = []
const model_mapping: Record> = {}
+ // Preserve existing features_config fields not managed by the form
+ const featuresConfig: Record = editingChannel.value?.features_config
+ ? { ...editingChannel.value.features_config }
+ : {}
for (const section of form.platforms) {
if (!section.enabled) continue
@@ -711,7 +748,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[
}
}
- return { group_ids, model_pricing, model_mapping }
+ // Collect web_search_emulation (only anthropic platform supports it)
+ const wsEmulation: Record = {}
+ for (const section of form.platforms) {
+ if (!section.enabled) continue
+ if (section.web_search_emulation && section.platform === 'anthropic') {
+ wsEmulation[section.platform] = true
+ }
+ }
+ if (Object.keys(wsEmulation).length > 0) {
+ featuresConfig.web_search_emulation = wsEmulation
+ }
+
+ return { group_ids, model_pricing, model_mapping, features_config: featuresConfig }
}
function apiToForm(channel: Channel): PlatformSection[] {
@@ -755,13 +804,19 @@ function apiToForm(channel: Channel): PlatformSection[] {
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
+ // Read web_search_emulation from features_config
+ const fc = channel.features_config
+ const wsEmulation = fc?.web_search_emulation as Record | undefined
+ const webSearchEnabled = wsEmulation?.[platform] === true
+
sections.push({
platform,
enabled: true,
collapsed: false,
group_ids: groupIds,
model_mapping: { ...mapping },
- model_pricing: pricing
+ model_pricing: pricing,
+ web_search_emulation: webSearchEnabled,
})
}
@@ -786,10 +841,10 @@ async function loadChannels() {
if (ctrl.signal.aborted || abortController !== ctrl) return
channels.value = response.items || []
pagination.total = response.total
- } catch (error: any) {
- if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return
- appStore.showError(t('admin.channels.loadError', 'Failed to load channels'))
- console.error('Error loading channels:', error)
+ } catch (error: unknown) {
+ const e = error as { name?: string; code?: string }
+ if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels')))
} finally {
if (abortController === ctrl) {
loading.value = false
@@ -969,8 +1024,7 @@ async function handleSubmit() {
}
}
- const { group_ids, model_pricing, model_mapping } = formToAPI()
- console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
+ const { group_ids, model_pricing, model_mapping, features_config } = formToAPI()
submitting.value = true
try {
@@ -983,7 +1037,8 @@ async function handleSubmit() {
model_pricing,
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
- restrict_models: form.restrict_models
+ restrict_models: form.restrict_models,
+ features_config,
}
await adminAPI.channels.update(editingChannel.value.id, req)
appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated'))
@@ -995,19 +1050,18 @@ async function handleSubmit() {
model_pricing,
model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {},
billing_model_source: form.billing_model_source,
- restrict_models: form.restrict_models
+ restrict_models: form.restrict_models,
+ features_config,
}
await adminAPI.channels.create(req)
appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created'))
}
closeDialog()
loadChannels()
- } catch (error: any) {
- const msg = error.response?.data?.detail || (editingChannel.value
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, editingChannel.value
? t('admin.channels.updateError', 'Failed to update channel')
- : t('admin.channels.createError', 'Failed to create channel'))
- appStore.showError(msg)
- console.error('Error saving channel:', error)
+ : t('admin.channels.createError', 'Failed to create channel')))
} finally {
submitting.value = false
}
@@ -1045,9 +1099,8 @@ async function confirmDelete() {
showDeleteDialog.value = false
deletingChannel.value = null
loadChannels()
- } catch (error: any) {
- appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel'))
- console.error('Error deleting channel:', error)
+ } catch (error: unknown) {
+ appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel')))
}
}
@@ -1055,6 +1108,7 @@ async function confirmDelete() {
onMounted(() => {
loadChannels()
loadGroups()
+ loadWebSearchGlobalState()
})
onUnmounted(() => {
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index 20f9318c..6abc725a 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -630,6 +630,108 @@
{{ t('admin.settings.betaPolicy.errorMessageHint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.modelWhitelistHint') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.commonPatterns') }}:
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.fallbackActionHint') }}
+
+
+
+
+
+ {{ t('admin.settings.betaPolicy.errorMessageHint') }}
+
+
+
@@ -1022,7 +1124,327 @@
-
+
+
+
+
+
+ {{ t('admin.settings.oidc.title') }}
+
+
+ {{ t('admin.settings.oidc.description') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.enableHint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{
+ form.oidc_connect_client_secret_configured
+ ? t('admin.settings.oidc.clientSecretConfiguredHint')
+ : t('admin.settings.oidc.clientSecretHint')
+ }}
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.scopesHint') }}
+
+
+
+
+
+
+
+
+
+ {{ oidcRedirectUrlSuggestion }}
+
+
+
+ {{ t('admin.settings.oidc.redirectUrlHint') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.oidc.frontendRedirectUrlHint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -1274,8 +1696,122 @@
+
+
+
+
+
+
+ {{ t('admin.settings.gatewayForwarding.cchSigningHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.title') }}
+
+
+ {{ t('admin.settings.webSearchEmulation.description') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.enabledHint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.webSearchEmulation.noProviders') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ t('admin.settings.webSearchEmulation.priorityHint') }}
+
+
+
+
+
{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}
+
+ {{ t('admin.settings.webSearchEmulation.quotaUsed') }}: {{ provider.quota_used }} / {{ provider.quota_limit || '∞' }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -1353,6 +1889,48 @@
+
+
+
+ {{ t('admin.settings.site.tablePreferencesTitle') }}
+
+
+ {{ t('admin.settings.site.tablePreferencesDescription') }}
+
+
+
+
+
+
+ {{ t('admin.settings.site.tableDefaultPageSizeHint') }}
+
+
+
+
+
+
+ {{ t('admin.settings.site.tablePageSizeOptionsHint') }}
+
+
+
+
+