feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
@@ -34,6 +34,7 @@ type createChannelRequest struct {
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
}
|
||||
|
||||
type updateChannelRequest struct {
|
||||
@@ -46,6 +47,7 @@ type updateChannelRequest struct {
|
||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||
RestrictModels *bool `json:"restrict_models"`
|
||||
Features *string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
}
|
||||
|
||||
type channelModelPricingRequest struct {
|
||||
@@ -81,6 +83,7 @@ type channelResponse struct {
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
Features string `json:"features"`
|
||||
FeaturesConfig map[string]any `json:"features_config"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
@@ -126,6 +129,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
Status: ch.Status,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
Features: ch.Features,
|
||||
FeaturesConfig: ch.FeaturesConfig,
|
||||
GroupIDs: ch.GroupIDs,
|
||||
ModelMapping: ch.ModelMapping,
|
||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
@@ -305,6 +309,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -338,6 +343,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
BillingModelSource: req.BillingModelSource,
|
||||
RestrictModels: req.RestrictModels,
|
||||
Features: req.Features,
|
||||
FeaturesConfig: req.FeaturesConfig,
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
|
||||
@@ -175,6 +175,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||
EnableCCHSigning: settings.EnableCCHSigning,
|
||||
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||
PaymentEnabled: paymentCfg.Enabled,
|
||||
PaymentMinAmount: paymentCfg.MinAmount,
|
||||
PaymentMaxAmount: paymentCfg.MaxAmount,
|
||||
@@ -1847,3 +1848,37 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
|
||||
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
||||
})
|
||||
}
|
||||
|
||||
// GetWebSearchEmulationConfig returns the current web search emulation
// configuration for the admin UI.
// GET /api/v1/admin/settings/web-search-emulation
//
// Provider API keys are sanitized via service.SanitizeWebSearchConfig before
// the config is written to the response, so raw secrets never leave the server.
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
	cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Never expose raw provider API keys in the API response.
	response.Success(c, service.SanitizeWebSearchConfig(cfg))
}
|
||||
|
||||
// UpdateWebSearchEmulationConfig replaces the web search emulation
// configuration.
// PUT /api/v1/admin/settings/web-search-emulation
//
// The handler binds the JSON body, persists it via the setting service, then
// re-reads the stored config and returns it sanitized — so the client always
// sees the canonical post-save state (e.g. after any API-key merge done by
// SaveWebSearchEmulationConfig) rather than echoing its own request body.
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
	var cfg service.WebSearchEmulationConfig
	if err := c.ShouldBindJSON(&cfg); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Re-read (with sanitized api keys) to return current state
	updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, service.SanitizeWebSearchConfig(updated))
}
|
||||
|
||||
@@ -124,6 +124,9 @@ type SystemSettings struct {
|
||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||
|
||||
// Web Search Emulation
|
||||
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||
|
||||
// Payment configuration
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
PaymentMinAmount float64 `json:"payment_min_amount"`
|
||||
|
||||
106
backend/internal/pkg/websearch/brave.go
Normal file
106
backend/internal/pkg/websearch/brave.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
|
||||
braveMaxCount = 20
|
||||
braveProviderName = "brave"
|
||||
)
|
||||
|
||||
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
|
||||
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
|
||||
|
||||
// BraveProvider implements web search via the Brave Search API.
|
||||
type BraveProvider struct {
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewBraveProvider creates a Brave Search provider.
|
||||
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
|
||||
}
|
||||
|
||||
func (b *BraveProvider) Name() string { return braveProviderName }
|
||||
|
||||
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||
count := req.MaxResults
|
||||
if count <= 0 {
|
||||
count = defaultMaxResults
|
||||
}
|
||||
if count > braveMaxCount {
|
||||
count = braveMaxCount
|
||||
}
|
||||
|
||||
u := *braveSearchURL // copy the pre-parsed URL
|
||||
q := u.Query()
|
||||
q.Set("q", req.Query)
|
||||
q.Set("count", strconv.Itoa(count))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := b.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("brave: read body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||
}
|
||||
|
||||
var raw braveResponse
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, fmt.Errorf("brave: decode response: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(raw.Web.Results))
|
||||
for _, r := range raw.Web.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
Snippet: r.Description,
|
||||
PageAge: r.Age,
|
||||
})
|
||||
}
|
||||
|
||||
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||
}
|
||||
|
||||
// braveResponse is the minimal structure of the Brave Search API response.
// Only the "web.results" section is decoded; all other fields are ignored.
type braveResponse struct {
	Web struct {
		Results []braveResult `json:"results"`
	} `json:"web"`
}

// braveResult is a single web result entry from the Brave Search API.
type braveResult struct {
	URL         string `json:"url"`         // result link; mapped to SearchResult.URL
	Title       string `json:"title"`       // page title; mapped to SearchResult.Title
	Description string `json:"description"` // snippet text; mapped to SearchResult.Snippet
	Age         string `json:"age"`         // human-readable page age (e.g. "1 day"); mapped to SearchResult.PageAge
}
|
||||
119
backend/internal/pkg/websearch/brave_test.go
Normal file
119
backend/internal/pkg/websearch/brave_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBraveProvider_Name(t *testing.T) {
|
||||
p := NewBraveProvider("key", nil)
|
||||
require.Equal(t, "brave", p.Name())
|
||||
}
|
||||
|
||||
// TestBraveProvider_Search_Success verifies the full happy path: headers and
// query parameters sent to the API, and the mapping of the Brave response
// fields (url/title/description/age) onto SearchResult.
func TestBraveProvider_Search_Success(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Assert the auth header, accept header, and query params that
		// BraveProvider.Search is expected to send.
		require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
		require.Equal(t, "application/json", r.Header.Get("Accept"))
		require.Equal(t, "golang", r.URL.Query().Get("q"))
		require.Equal(t, "3", r.URL.Query().Get("count"))

		resp := braveResponse{}
		resp.Web.Results = []braveResult{
			{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
			{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
			{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer srv.Close()

	p := NewBraveProvider("test-key", srv.Client())
	// Override the endpoint for testing
	// NOTE(review): this mutates the package-level braveSearchURL, so none of
	// these tests can safely use t.Parallel (shared mutable global). Also,
	// http.NewRequest is used here only to parse srv.URL into a *url.URL;
	// url.Parse(srv.URL) would express that intent directly — consider a
	// shared test helper that swaps and restores the endpoint.
	origURL := *braveSearchURL
	u, _ := http.NewRequest("GET", srv.URL, nil)
	*braveSearchURL = *u.URL
	defer func() { *braveSearchURL = origURL }()

	resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
	require.NoError(t, err)
	require.Len(t, resp.Results, 3)
	require.Equal(t, "https://go.dev", resp.Results[0].URL)
	require.Equal(t, "Go lang", resp.Results[0].Snippet)
	require.Equal(t, "1 day", resp.Results[0].PageAge)
}
|
||||
|
||||
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
|
||||
var receivedCount string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedCount = r.URL.Query().Get("count")
|
||||
resp := braveResponse{}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
|
||||
require.Equal(t, "5", receivedCount)
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte("rate limited"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "brave: status 429")
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("not json"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "brave: decode response")
|
||||
}
|
||||
|
||||
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewBraveProvider("key", srv.Client())
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Results)
|
||||
}
|
||||
14
backend/internal/pkg/websearch/helpers.go
Normal file
14
backend/internal/pkg/websearch/helpers.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package websearch
|
||||
|
||||
// Shared limits for all search providers.
const (
	maxResponseSize   = 1 << 20 // cap on provider response bodies (1 MB)
	errorBodyTruncLen = 200     // max body bytes echoed into error messages
)

// truncateBody renders body as a string suitable for error messages,
// cutting it at errorBodyTruncLen bytes and appending a marker when cut.
func truncateBody(body []byte) string {
	if len(body) > errorBodyTruncLen {
		return string(body[:errorBodyTruncLen]) + "...(truncated)"
	}
	return string(body)
}
|
||||
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestTruncateBody_Short verifies bodies under the limit pass through unchanged.
func TestTruncateBody_Short(t *testing.T) {
	body := []byte("short body")
	require.Equal(t, "short body", truncateBody(body))
}

// TestTruncateBody_Long verifies oversized bodies are cut to errorBodyTruncLen
// bytes and suffixed with the truncation marker.
func TestTruncateBody_Long(t *testing.T) {
	body := []byte(strings.Repeat("x", 500))
	result := truncateBody(body)
	require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
	require.True(t, strings.HasSuffix(result, "...(truncated)"))
}

// TestTruncateBody_ExactBoundary verifies a body of exactly errorBodyTruncLen
// bytes is returned whole, with no marker appended.
func TestTruncateBody_ExactBoundary(t *testing.T) {
	body := []byte(strings.Repeat("x", errorBodyTruncLen))
	require.Equal(t, string(body), truncateBody(body))
}
|
||||
273
backend/internal/pkg/websearch/manager.go
Normal file
273
backend/internal/pkg/websearch/manager.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Quota refresh interval constants.
|
||||
const (
|
||||
QuotaRefreshDaily = "daily"
|
||||
QuotaRefreshWeekly = "weekly"
|
||||
QuotaRefreshMonthly = "monthly"
|
||||
)
|
||||
|
||||
// ProviderConfig holds the configuration for a single search provider.
|
||||
type ProviderConfig struct {
|
||||
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
|
||||
APIKey string `json:"api_key"` // secret
|
||||
Priority int `json:"priority"` // lower = higher priority
|
||||
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||
QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
|
||||
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
|
||||
}
|
||||
|
||||
// Manager selects providers by priority and tracks quota via Redis.
|
||||
type Manager struct {
|
||||
configs []ProviderConfig
|
||||
redis *redis.Client
|
||||
|
||||
clientMu sync.Mutex
|
||||
clientCache map[string]*http.Client
|
||||
}
|
||||
|
||||
const (
|
||||
quotaKeyPrefix = "websearch:quota:"
|
||||
searchRequestTimeout = 30 * time.Second
|
||||
quotaTTLBuffer = 24 * time.Hour
|
||||
maxCachedClients = 100
|
||||
)
|
||||
|
||||
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
|
||||
// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
|
||||
// Returns the new counter value.
|
||||
var quotaIncrScript = redis.NewScript(`
|
||||
local val = redis.call('INCR', KEYS[1])
|
||||
if val == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||
else
|
||||
-- Defensive: ensure TTL exists even if a prior EXPIRE failed
|
||||
local ttl = redis.call('TTL', KEYS[1])
|
||||
if ttl == -1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||
end
|
||||
end
|
||||
return val
|
||||
`)
|
||||
|
||||
// NewManager creates a Manager with the given provider configs and Redis client.
|
||||
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
|
||||
sorted := make([]ProviderConfig, len(configs))
|
||||
copy(sorted, configs)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
return &Manager{
|
||||
configs: sorted,
|
||||
redis: redisClient,
|
||||
clientCache: make(map[string]*http.Client),
|
||||
}
|
||||
}
|
||||
|
||||
// SearchWithBestProvider selects the highest-priority available provider,
|
||||
// reserves quota, executes the search, and rolls back quota on failure.
|
||||
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
|
||||
if strings.TrimSpace(req.Query) == "" {
|
||||
return nil, "", fmt.Errorf("websearch: empty search query")
|
||||
}
|
||||
for _, cfg := range m.configs {
|
||||
if !m.isProviderAvailable(cfg) {
|
||||
continue
|
||||
}
|
||||
allowed, incremented := m.tryReserveQuota(ctx, cfg)
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
resp, err := m.executeSearch(ctx, cfg, req)
|
||||
if err != nil {
|
||||
if incremented {
|
||||
m.rollbackQuota(ctx, cfg)
|
||||
}
|
||||
slog.Warn("websearch: provider search failed",
|
||||
"provider", cfg.Type, "error", err)
|
||||
continue
|
||||
}
|
||||
return resp, cfg.Type, nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
|
||||
}
|
||||
|
||||
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
|
||||
if cfg.APIKey == "" {
|
||||
return false
|
||||
}
|
||||
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
|
||||
slog.Info("websearch: provider expired, skipping",
|
||||
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tryReserveQuota atomically increments the counter via Lua script and checks limit.
// Returns (allowed, incremented): allowed=true means the request may proceed;
// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure).
//
// The sequence is deliberately INCR-first-then-check: the Lua script reserves
// a slot atomically, and an over-limit reservation is immediately compensated
// with a DECR. The failure policy is fail-open — if Redis is nil or the script
// errors, the request is allowed WITHOUT a reservation, so a Redis outage
// degrades quota accounting rather than disabling web search entirely.
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
	if cfg.QuotaLimit <= 0 {
		return true, false // unlimited, no INCR
	}
	if m.redis == nil {
		slog.Warn("websearch: Redis unavailable, quota check skipped",
			"provider", cfg.Type)
		return true, false // allowed but not incremented
	}
	key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
	ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())

	newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
	if err != nil {
		slog.Warn("websearch: quota Lua INCR failed, allowing request",
			"provider", cfg.Type, "error", err)
		return true, false // allowed but not incremented
	}
	if newVal > cfg.QuotaLimit {
		// Over the limit: undo our own INCR so the counter reflects real usage.
		// A failed DECR here leaks one count until the period key expires.
		if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
			slog.Warn("websearch: quota over-limit DECR failed",
				"provider", cfg.Type, "error", decrErr)
		}
		slog.Info("websearch: provider quota exhausted",
			"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
		return false, false // rejected, already rolled back
	}
	return true, true // allowed and incremented
}
|
||||
|
||||
// rollbackQuota decrements the counter after a search failure.
|
||||
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
|
||||
if cfg.QuotaLimit <= 0 || m.redis == nil {
|
||||
return
|
||||
}
|
||||
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
|
||||
if err := m.redis.Decr(ctx, key).Err(); err != nil {
|
||||
slog.Warn("websearch: quota rollback DECR failed",
|
||||
"provider", cfg.Type, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
|
||||
proxyURL := cfg.ProxyURL
|
||||
if req.ProxyURL != "" {
|
||||
proxyURL = req.ProxyURL
|
||||
}
|
||||
client := m.getOrCreateHTTPClient(proxyURL)
|
||||
provider := m.buildProvider(cfg, client)
|
||||
return provider.Search(ctx, req)
|
||||
}
|
||||
|
||||
// GetUsage returns the current usage count for the given provider.
|
||||
func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
|
||||
if m.redis == nil {
|
||||
return 0, nil
|
||||
}
|
||||
key := quotaRedisKey(providerType, refreshInterval)
|
||||
val, err := m.redis.Get(ctx, key).Int64()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
// GetAllUsage returns usage for every configured provider.
|
||||
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
|
||||
result := make(map[string]int64, len(m.configs))
|
||||
for _, cfg := range m.configs {
|
||||
used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
|
||||
result[cfg.Type] = used
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- HTTP client cache (bounded) ---
|
||||
|
||||
func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
|
||||
m.clientMu.Lock()
|
||||
defer m.clientMu.Unlock()
|
||||
|
||||
if c, ok := m.clientCache[proxyURL]; ok {
|
||||
return c
|
||||
}
|
||||
if len(m.clientCache) >= maxCachedClients {
|
||||
m.clientCache = make(map[string]*http.Client) // evict all
|
||||
}
|
||||
c := newHTTPClient(proxyURL)
|
||||
m.clientCache[proxyURL] = c
|
||||
return c
|
||||
}
|
||||
|
||||
func newHTTPClient(proxyURL string) *http.Client {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
if proxyURL != "" {
|
||||
if u, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(u)
|
||||
}
|
||||
}
|
||||
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
|
||||
}
|
||||
|
||||
// --- Provider factory ---
|
||||
|
||||
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
|
||||
switch cfg.Type {
|
||||
case braveProviderName:
|
||||
return NewBraveProvider(cfg.APIKey, client)
|
||||
case tavilyProviderName:
|
||||
return NewTavilyProvider(cfg.APIKey, client)
|
||||
default:
|
||||
slog.Warn("websearch: unknown provider type, falling back to brave",
|
||||
"type", cfg.Type)
|
||||
return NewBraveProvider(cfg.APIKey, client)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Redis key helpers ---
|
||||
|
||||
func quotaRedisKey(providerType, refreshInterval string) string {
|
||||
return quotaKeyPrefix + providerType + ":" + periodKey(refreshInterval)
|
||||
}
|
||||
|
||||
func periodKey(refreshInterval string) string {
|
||||
now := time.Now().UTC()
|
||||
switch refreshInterval {
|
||||
case QuotaRefreshDaily:
|
||||
return now.Format("2006-01-02")
|
||||
case QuotaRefreshWeekly:
|
||||
year, week := now.ISOWeek()
|
||||
return fmt.Sprintf("%d-W%02d", year, week)
|
||||
default: // QuotaRefreshMonthly
|
||||
return now.Format("2006-01")
|
||||
}
|
||||
}
|
||||
|
||||
func quotaTTL(refreshInterval string) time.Duration {
|
||||
switch refreshInterval {
|
||||
case QuotaRefreshDaily:
|
||||
return 24*time.Hour + quotaTTLBuffer
|
||||
case QuotaRefreshWeekly:
|
||||
return 7*24*time.Hour + quotaTTLBuffer
|
||||
default:
|
||||
return 31*24*time.Hour + quotaTTLBuffer
|
||||
}
|
||||
}
|
||||
149
backend/internal/pkg/websearch/manager_test.go
Normal file
149
backend/internal/pkg/websearch/manager_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewManager_SortsByPriority(t *testing.T) {
|
||||
configs := []ProviderConfig{
|
||||
{Type: "brave", APIKey: "k3", Priority: 30},
|
||||
{Type: "tavily", APIKey: "k1", Priority: 10},
|
||||
}
|
||||
m := NewManager(configs, nil)
|
||||
require.Equal(t, 10, m.configs[0].Priority)
|
||||
require.Equal(t, 30, m.configs[1].Priority)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
|
||||
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
|
||||
require.ErrorContains(t, err, "empty search query")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
|
||||
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", ExpiresAt: &past},
|
||||
}, nil)
|
||||
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.ErrorContains(t, err, "no available provider")
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_PriorityOrder(t *testing.T) {
|
||||
// Create two mock servers that return different results
|
||||
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srvBrave.Close()
|
||||
|
||||
// Override brave endpoint for test
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k1", Priority: 1},
|
||||
{Type: "tavily", APIKey: "k2", Priority: 2},
|
||||
}, nil)
|
||||
// Inject the test server's client
|
||||
m.clientCache[srvBrave.URL] = srvBrave.Client()
|
||||
m.clientCache[""] = srvBrave.Client()
|
||||
|
||||
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "brave", providerName)
|
||||
require.Len(t, resp.Results, 1)
|
||||
require.Equal(t, "from brave", resp.Results[0].Snippet)
|
||||
}
|
||||
|
||||
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
|
||||
// With nil Redis, quota check is skipped (always allowed)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
resp := braveResponse{}
|
||||
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := *braveSearchURL
|
||||
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||
*braveSearchURL = *u.URL
|
||||
defer func() { *braveSearchURL = origURL }()
|
||||
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", APIKey: "k", Priority: 1, QuotaLimit: 100},
|
||||
}, nil) // nil Redis
|
||||
m.clientCache[""] = srv.Client()
|
||||
|
||||
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Results, 1)
|
||||
}
|
||||
|
||||
func TestManager_GetUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager(nil, nil)
|
||||
used, err := m.GetUsage(context.Background(), "brave", "monthly")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), used)
|
||||
}
|
||||
|
||||
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
|
||||
m := NewManager([]ProviderConfig{
|
||||
{Type: "brave", QuotaRefreshInterval: "monthly"},
|
||||
}, nil)
|
||||
usage := m.GetAllUsage(context.Background())
|
||||
require.Equal(t, int64(0), usage["brave"])
|
||||
}
|
||||
|
||||
// --- Key/TTL helpers ---
|
||||
|
||||
func TestQuotaTTL_Daily(t *testing.T) {
|
||||
require.Equal(t, 24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshDaily))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Weekly(t *testing.T) {
|
||||
require.Equal(t, 7*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshWeekly))
|
||||
}
|
||||
|
||||
func TestQuotaTTL_Monthly(t *testing.T) {
|
||||
require.Equal(t, 31*24*time.Hour+quotaTTLBuffer, quotaTTL(QuotaRefreshMonthly))
|
||||
}
|
||||
|
||||
func TestPeriodKey_Daily(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshDaily)
|
||||
require.Regexp(t, `^\d{4}-\d{2}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Weekly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshWeekly)
|
||||
require.Regexp(t, `^\d{4}-W\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestPeriodKey_Monthly(t *testing.T) {
|
||||
key := periodKey(QuotaRefreshMonthly)
|
||||
require.Regexp(t, `^\d{4}-\d{2}$`, key)
|
||||
}
|
||||
|
||||
func TestQuotaRedisKey_Format(t *testing.T) {
|
||||
key := quotaRedisKey("brave", QuotaRefreshDaily)
|
||||
require.Contains(t, key, "websearch:quota:brave:")
|
||||
}
|
||||
11
backend/internal/pkg/websearch/provider.go
Normal file
11
backend/internal/pkg/websearch/provider.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package websearch
|
||||
|
||||
import "context"
|
||||
|
||||
// Provider is the interface every search backend must implement.
// Concrete implementations are constructed from ProviderConfig and driven
// by Manager, which handles provider selection and quota accounting.
type Provider interface {
	// Name returns the provider identifier ("brave" or "tavily").
	Name() string
	// Search executes a web search described by req and returns the results.
	// Implementations should honor ctx for cancellation and timeouts.
	Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
}
|
||||
107
backend/internal/pkg/websearch/tavily.go
Normal file
107
backend/internal/pkg/websearch/tavily.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
	// tavilySearchEndpoint is the Tavily Search API URL requests are POSTed to.
	tavilySearchEndpoint = "https://api.tavily.com/search"
	// tavilyProviderName is the identifier returned by (*TavilyProvider).Name.
	tavilyProviderName = "tavily"
	// tavilySearchDepthBasic selects Tavily's "basic" search depth.
	tavilySearchDepthBasic = "basic"
)
|
||||
|
||||
// TavilyProvider implements web search via the Tavily Search API.
type TavilyProvider struct {
	apiKey     string
	httpClient *http.Client
}

// NewTavilyProvider creates a Tavily Search provider.
// The caller is responsible for configuring the http.Client with proxy and
// timeout settings; a nil client falls back to http.DefaultClient.
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
	p := &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
	if p.httpClient == nil {
		p.httpClient = http.DefaultClient
	}
	return p
}
|
||||
|
||||
// Name returns the provider identifier "tavily".
func (t *TavilyProvider) Name() string { return tavilyProviderName }
|
||||
|
||||
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||
maxResults := req.MaxResults
|
||||
if maxResults <= 0 {
|
||||
maxResults = defaultMaxResults
|
||||
}
|
||||
|
||||
payload := tavilyRequest{
|
||||
APIKey: t.apiKey,
|
||||
Query: req.Query,
|
||||
MaxResults: maxResults,
|
||||
SearchDepth: tavilySearchDepthBasic,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: encode request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: build request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := t.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tavily: read body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||
}
|
||||
|
||||
var raw tavilyResponse
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil, fmt.Errorf("tavily: decode response: %w", err)
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0, len(raw.Results))
|
||||
for _, r := range raw.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL,
|
||||
Title: r.Title,
|
||||
Snippet: r.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||
}
|
||||
|
||||
// tavilyRequest is the JSON body POSTed to the Tavily /search endpoint.
// Tavily authenticates via the api_key field rather than an HTTP header.
type tavilyRequest struct {
	APIKey      string `json:"api_key"`
	Query       string `json:"query"`
	MaxResults  int    `json:"max_results"`
	SearchDepth string `json:"search_depth"`
}

// tavilyResponse mirrors the subset of the Tavily response body we consume.
type tavilyResponse struct {
	Results []tavilyResult `json:"results"`
}

// tavilyResult is a single search hit as returned by Tavily.
type tavilyResult struct {
	URL     string  `json:"url"`
	Title   string  `json:"title"`
	Content string  `json:"content"`
	// Score is Tavily's relevance score; parsed but not currently surfaced
	// in SearchResult.
	Score float64 `json:"score"`
}
|
||||
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTavilyProvider_Name(t *testing.T) {
|
||||
p := NewTavilyProvider("key", nil)
|
||||
require.Equal(t, "tavily", p.Name())
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
|
||||
// Verify tavilyRequest struct fields map correctly
|
||||
req := tavilyRequest{
|
||||
APIKey: "test-key",
|
||||
Query: "golang",
|
||||
MaxResults: 3,
|
||||
SearchDepth: tavilySearchDepthBasic,
|
||||
}
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal(data, &parsed))
|
||||
require.Equal(t, "test-key", parsed["api_key"])
|
||||
require.Equal(t, "golang", parsed["query"])
|
||||
require.Equal(t, float64(3), parsed["max_results"])
|
||||
require.Equal(t, "basic", parsed["search_depth"])
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
|
||||
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
|
||||
var resp tavilyResponse
|
||||
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
|
||||
require.Len(t, resp.Results, 1)
|
||||
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||
require.Equal(t, "Go programming language", resp.Results[0].Content)
|
||||
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
|
||||
|
||||
// Verify mapping to SearchResult
|
||||
results := make([]SearchResult, 0, len(resp.Results))
|
||||
for _, r := range resp.Results {
|
||||
results = append(results, SearchResult{
|
||||
URL: r.URL, Title: r.Title, Snippet: r.Content,
|
||||
})
|
||||
}
|
||||
require.Equal(t, "Go programming language", results[0].Snippet)
|
||||
require.Equal(t, "", results[0].PageAge)
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
|
||||
var resp tavilyResponse
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
|
||||
require.Empty(t, resp.Results)
|
||||
}
|
||||
|
||||
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
|
||||
var resp tavilyResponse
|
||||
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
|
||||
}
|
||||
30
backend/internal/pkg/websearch/types.go
Normal file
30
backend/internal/pkg/websearch/types.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package websearch
|
||||
|
||||
// SearchResult represents a single web search result.
type SearchResult struct {
	URL     string `json:"url"`
	Title   string `json:"title"`
	Snippet string `json:"snippet"`
	// PageAge holds the page's age/date when the provider supplies one;
	// providers that don't (e.g. Tavily) leave it empty.
	PageAge string `json:"page_age,omitempty"`
}

// SearchRequest describes a web search to perform.
type SearchRequest struct {
	Query      string
	MaxResults int    // defaults to defaultMaxResults if <= 0
	ProxyURL   string // optional HTTP proxy URL
}

// SearchResponse holds the results of a web search.
type SearchResponse struct {
	Results []SearchResult
	Query   string // the query that was actually executed
}

// defaultMaxResults caps the result count when SearchRequest.MaxResults is unset.
const defaultMaxResults = 5

// Provider type identifiers.
const (
	ProviderTypeBrave  = "brave"
	ProviderTypeTavily = "tavily"
)
|
||||
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -73,11 +77,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON []byte
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
@@ -85,6 +89,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@@ -107,10 +112,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
|
||||
WHERE id = $8`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -187,9 +196,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
@@ -203,11 +212,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -246,9 +256,34 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
|
||||
|
||||
var column string
|
||||
switch sortBy {
|
||||
case "":
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
case "id":
|
||||
column = "c.id"
|
||||
case "name":
|
||||
column = "c.name"
|
||||
case "status":
|
||||
column = "c.status"
|
||||
case "created_at":
|
||||
column = "c.created_at"
|
||||
default:
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
@@ -259,11 +294,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -431,6 +467,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
// marshalFeaturesConfig serializes a channel features_config map to JSON for
// the channels.features_config column. A nil or empty map is stored as "{}"
// so the column never holds a null/empty payload.
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
	if len(m) == 0 {
		return []byte("{}"), nil
	}
	encoded, err := json.Marshal(m)
	if err != nil {
		return nil, fmt.Errorf("marshal features_config: %w", err)
	}
	return encoded, nil
}
|
||||
|
||||
// unmarshalFeaturesConfig parses the features_config column payload.
// Empty or invalid JSON yields nil — treated as "no features configured" —
// matching the tolerant read path used for model_mapping.
func unmarshalFeaturesConfig(data []byte) map[string]any {
	var cfg map[string]any
	if len(data) == 0 || json.Unmarshal(data, &cfg) != nil {
		return nil
	}
	return cfg
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
|
||||
@@ -407,6 +407,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Web Search 模拟配置
|
||||
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
|
||||
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
|
||||
//
|
||||
// 新字段:accounts.extra.openai_passthrough。
|
||||
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
||||
@@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
|
||||
// 字段:accounts.extra.openai_ws_force_http。
|
||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
@@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||
}
|
||||
|
||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
|
||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
|
||||
// 字段:accounts.extra.anthropic_passthrough。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||
@@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
|
||||
// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
|
||||
// 字段:accounts.extra.web_search_emulation。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsWebSearchEmulationEnabled() bool {
|
||||
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
|
||||
// 字段:accounts.extra.codex_cli_only。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
||||
|
||||
71
backend/internal/service/account_websearch_test.go
Normal file
71
backend/internal/service/account_websearch_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.True(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: false},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: "true"},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
|
||||
a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
@@ -49,6 +49,21 @@ type Channel struct {
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := wse[platform].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
|
||||
@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
|
||||
}
|
||||
|
||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
||||
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
|
||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
|
||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
|
||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
@@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
}
|
||||
|
||||
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
|
||||
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
|
||||
func (s *ChannelService) storeErrorCache() {
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := populateChannelCache(channels, groupPlatforms)
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
|
||||
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
|
||||
channels, err := s.repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
// 收集所有 groupID,批量查询 platform
|
||||
var allGroupIDs []int64
|
||||
for i := range channels {
|
||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||
}
|
||||
|
||||
groupPlatforms := make(map[int64]string)
|
||||
if len(allGroupIDs) > 0 {
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
}
|
||||
return channels, groupPlatforms, nil
|
||||
}
|
||||
|
||||
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
|
||||
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
|
||||
cache := newEmptyChannelCache()
|
||||
cache.groupPlatform = groupPlatforms
|
||||
cache.byID = make(map[int64]*Channel, len(channels))
|
||||
@@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid]
|
||||
@@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
return cache
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
|
||||
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
||||
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
|
||||
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
|
||||
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
|
||||
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
||||
if groupPlatform == pricingPlatform {
|
||||
return true
|
||||
}
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
|
||||
}
|
||||
return false
|
||||
return groupPlatform == pricingPlatform
|
||||
}
|
||||
|
||||
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
|
||||
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
|
||||
// 各平台严格独立,只返回自身。
|
||||
func matchingPlatforms(groupPlatform string) []string {
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
|
||||
}
|
||||
return []string{groupPlatform}
|
||||
}
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
@@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
|
||||
return ""
|
||||
}
|
||||
|
||||
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
||||
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
||||
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
||||
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
||||
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
|
||||
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
|
||||
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||
@@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
||||
// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
|
||||
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
@@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
||||
// 确保跨平台同名模型各自独立匹配。
|
||||
// 各平台严格独立,只在本平台内查找定价。
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
@@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
|
||||
}
|
||||
if lk == nil {
|
||||
return false
|
||||
}
|
||||
@@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
||||
}
|
||||
|
||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
||||
// 只在本平台的定价列表中查找。
|
||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||
if !lk.channel.RestrictModels {
|
||||
return false
|
||||
@@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
|
||||
// Create 和 Update 共用此函数,避免重复。
|
||||
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
|
||||
if err := validateNoConflictingModels(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePricingIntervals(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateNoConflictingMappings(mapping); err != nil {
|
||||
return err
|
||||
}
|
||||
return validatePricingBillingMode(pricing)
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
|
||||
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
if err := checkBillingModeRequirements(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkPricesNotNegative(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkIntervalsHavePrices(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkBillingModeRequirements(p ChannelModelPricing) error {
|
||||
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return infraerrors.BadRequest(
|
||||
"BILLING_MODE_MISSING_PRICE",
|
||||
"per-request price or intervals required for per_request/image billing mode",
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkPricesNotNegative(p ChannelModelPricing) error {
|
||||
checks := []struct {
|
||||
field string
|
||||
val *float64
|
||||
}{
|
||||
{"input_price", p.InputPrice},
|
||||
{"output_price", p.OutputPrice},
|
||||
{"cache_write_price", p.CacheWritePrice},
|
||||
{"cache_read_price", p.CacheReadPrice},
|
||||
{"image_output_price", p.ImageOutputPrice},
|
||||
{"per_request_price", p.PerRequestPrice},
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.val != nil && *c.val < 0 {
|
||||
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIntervalsHavePrices(p ChannelModelPricing) error {
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return infraerrors.BadRequest(
|
||||
"INTERVAL_MISSING_PRICE",
|
||||
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatMaxTokens renders an interval upper bound for error messages.
// A nil bound means "no upper limit" and prints as "∞".
// The parameter is named maxTokens rather than max to avoid shadowing the
// built-in max function (Go 1.21+).
func formatMaxTokens(maxTokens *int) string {
	if maxTokens == nil {
		return "∞"
	}
	return fmt.Sprintf("%d", *maxTokens)
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
@@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
@@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
FeaturesConfig: input.FeaturesConfig,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
if input.Features != nil {
|
||||
channel.Features = *input.Features
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
||||
var oldGroupIDs []int64
|
||||
if s.authCacheInvalidator != nil {
|
||||
var err2 error
|
||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
||||
if err2 != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
||||
}
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldGroupIDs := s.getOldGroupIDs(ctx, id)
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效新旧分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
||||
for _, gid := range oldGroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
for _, gid := range channel.GroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// applyUpdateInput applies the fields of an update request onto the channel entity.
// Zero-value / nil fields in input are treated as "not provided" and leave the
// corresponding channel field unchanged (partial-update semantics).
// Returns ErrChannelExists on a name collision and ErrGroupAlreadyInChannel
// when a requested group already belongs to another channel.
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
	// A rename must stay unique among all other channels (this one excluded).
	if input.Name != "" && input.Name != channel.Name {
		exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
		if err != nil {
			return fmt.Errorf("check channel exists: %w", err)
		}
		if exists {
			return ErrChannelExists
		}
		channel.Name = input.Name
	}
	if input.Description != nil {
		channel.Description = *input.Description
	}
	if input.Status != "" {
		channel.Status = input.Status
	}
	if input.RestrictModels != nil {
		channel.RestrictModels = *input.RestrictModels
	}
	if input.Features != nil {
		channel.Features = *input.Features
	}
	if input.GroupIDs != nil {
		// Group re-association: reject groups already owned by another channel.
		if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
			return err
		}
		channel.GroupIDs = *input.GroupIDs
	}
	if input.ModelPricing != nil {
		channel.ModelPricing = *input.ModelPricing
	}
	if input.ModelMapping != nil {
		channel.ModelMapping = input.ModelMapping
	}
	if input.BillingModelSource != "" {
		channel.BillingModelSource = input.BillingModelSource
	}
	if input.FeaturesConfig != nil {
		channel.FeaturesConfig = input.FeaturesConfig
	}
	return nil
}
|
||||
|
||||
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
|
||||
// channelID 为当前渠道 ID(Create 时传 0)。
|
||||
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return ErrGroupAlreadyInChannel
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
|
||||
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return nil
|
||||
}
|
||||
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
|
||||
}
|
||||
return oldGroupIDs
|
||||
}
|
||||
|
||||
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
|
||||
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return
|
||||
}
|
||||
seen := make(map[int64]struct{})
|
||||
for _, ids := range groupIDSets {
|
||||
for _, gid := range ids {
|
||||
if _, ok := seen[gid]; ok {
|
||||
continue
|
||||
}
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
@@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
s.invalidateAuthCacheForGroups(ctx, groupIDs)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -847,6 +930,7 @@ type CreateChannelInput struct {
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
@@ -860,4 +944,5 @@ type UpdateChannelInput struct {
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
62
backend/internal/service/channel_websearch_test.go
Normal file
62
backend/internal/service/channel_websearch_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||
},
|
||||
}
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
|
||||
c := &Channel{FeaturesConfig: nil}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
|
||||
var c *Channel
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: true, // not a map
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
@@ -249,6 +249,10 @@ const (
|
||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||
|
||||
// Web Search Emulation
|
||||
// SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
|
||||
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
|
||||
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Web search emulation constants
|
||||
const (
|
||||
toolTypeWebSearchPrefix = "web_search"
|
||||
toolTypeGoogleSearch = "google_search"
|
||||
toolNameWebSearch = "web_search"
|
||||
toolNameGoogleSearch = "google_search"
|
||||
toolNameWebSearch2025 = "web_search_20250305"
|
||||
|
||||
webSearchDefaultMaxResults = 5
|
||||
defaultWebSearchModel = "claude-sonnet-4-6"
|
||||
webSearchMsgIDPrefix = "msg_ws_"
|
||||
webSearchToolUseIDPrefix = "srvtoolu_ws_"
|
||||
tokenEstimateDivisor = 4
|
||||
|
||||
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
|
||||
featureKeyWebSearchEmulation = "web_search_emulation"
|
||||
)
|
||||
|
||||
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
|
||||
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
|
||||
|
||||
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
|
||||
func SetWebSearchManager(m *websearch.Manager) {
|
||||
webSearchManagerPtr.Store(m)
|
||||
}
|
||||
|
||||
func getWebSearchManager() *websearch.Manager {
|
||||
return webSearchManagerPtr.Load()
|
||||
}
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
|
||||
// Note: channel-level control is enforced via the account's extra field; the channel toggle
|
||||
// in the admin UI sets the account's flag for all accounts in that channel's groups.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
}
|
||||
if !isOnlyWebSearchToolInBody(body) {
|
||||
return false
|
||||
}
|
||||
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if !account.IsWebSearchEmulationEnabled() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
func isOnlyWebSearchToolInBody(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
return false
|
||||
}
|
||||
return isWebSearchToolJSON(arr[0])
|
||||
}
|
||||
|
||||
func isWebSearchToolJSON(tool gjson.Result) bool {
|
||||
toolType := tool.Get("type").String()
|
||||
if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
|
||||
return true
|
||||
}
|
||||
switch tool.Get("name").String() {
|
||||
case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractSearchQueryFromBody extracts the last user message text as the search query.
|
||||
func extractSearchQueryFromBody(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
arr := messages.Array()
|
||||
if len(arr) == 0 {
|
||||
return ""
|
||||
}
|
||||
lastMsg := arr[len(arr)-1]
|
||||
if lastMsg.Get("role").String() != "user" {
|
||||
return ""
|
||||
}
|
||||
return extractWebSearchTextFromContent(lastMsg.Get("content"))
|
||||
}
|
||||
|
||||
func extractWebSearchTextFromContent(content gjson.Result) string {
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
if content.IsArray() {
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
if text := block.Get("text").String(); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// handleWebSearchEmulation intercepts a web-search-only request,
|
||||
// calls a third-party search API, and constructs an Anthropic-format response.
|
||||
func (s *GatewayService) handleWebSearchEmulation(
|
||||
ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
|
||||
) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Release the serial queue lock immediately — we don't need upstream.
|
||||
if parsed.OnUpstreamAccepted != nil {
|
||||
parsed.OnUpstreamAccepted()
|
||||
}
|
||||
|
||||
query := extractSearchQueryFromBody(parsed.Body)
|
||||
if query == "" {
|
||||
return nil, fmt.Errorf("web search emulation: no query found in messages")
|
||||
}
|
||||
|
||||
slog.Info("web search emulation: executing search",
|
||||
"account_id", account.ID, "account_name", account.Name, "query", query)
|
||||
|
||||
resp, providerName, err := doWebSearch(ctx, account, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("web search emulation: search completed",
|
||||
"provider", providerName, "results_count", len(resp.Results))
|
||||
|
||||
model := parsed.Model
|
||||
if model == "" {
|
||||
model = defaultWebSearchModel
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
return writeWebSearchStreamResponse(c, query, resp, model, startTime)
|
||||
}
|
||||
return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
|
||||
}
|
||||
|
||||
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
|
||||
proxyURL := resolveAccountProxyURL(account)
|
||||
mgr := getWebSearchManager()
|
||||
if mgr == nil {
|
||||
return nil, "", fmt.Errorf("web search emulation: manager not initialized")
|
||||
}
|
||||
resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
|
||||
Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("web search emulation: search failed", "error", err)
|
||||
return nil, "", fmt.Errorf("web search emulation: %w", err)
|
||||
}
|
||||
return resp, providerName, nil
|
||||
}
|
||||
|
||||
func resolveAccountProxyURL(account *Account) string {
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- SSE streaming response ---
|
||||
|
||||
// writeWebSearchStreamResponse writes an Anthropic SSE stream containing, in
// order: message_start, a server_tool_use block (index 0), a
// web_search_tool_result block (index 1), a text summary block (index 2),
// then message_delta/message_stop. Output tokens are a rough estimate
// (len(summary) / tokenEstimateDivisor).
func writeWebSearchStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]

	setSSEHeaders(c)
	// Only the first write's error is checked; once streaming begins the
	// remaining writes are best-effort (the status line is already sent).
	if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
		return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
	}
	writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
	writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
	textSummary := buildTextSummary(query, resp.Results)
	writeSSETextBlock(c.Writer, textSummary, 2)
	writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
	c.Writer.Flush()

	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
|
||||
|
||||
func setSSEHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
evt := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
}
|
||||
|
||||
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
|
||||
start := map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
},
|
||||
}
|
||||
_ = flushSSEJSON(w, "content_block_start", start)
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
|
||||
start := map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "web_search_tool_result", "tool_use_id": toolUseID,
|
||||
"content": buildSearchResultBlocks(results),
|
||||
},
|
||||
}
|
||||
_ = flushSSEJSON(w, "content_block_start", start)
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
|
||||
_ = flushSSEJSON(w, "content_block_start", map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{"type": "text", "text": ""},
|
||||
})
|
||||
_ = flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta", "index": index,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
|
||||
_ = flushSSEJSON(w, "message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": outputTokens},
|
||||
})
|
||||
_ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
|
||||
}
|
||||
|
||||
// flushSSEJSON marshals data to JSON and writes an SSE event. Returns error on marshal failure.
|
||||
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
slog.Error("web search emulation: failed to marshal SSE event",
|
||||
"event", event, "error", err)
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Non-streaming JSON response ---
|
||||
|
||||
// writeWebSearchNonStreamResponse writes a single Anthropic Messages JSON
// response containing a server_tool_use block, a web_search_tool_result
// block, and a text summary. Output tokens are a rough estimate
// (len(summary) / tokenEstimateDivisor).
func writeWebSearchNonStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	textSummary := buildTextSummary(query, resp.Results)

	msg := map[string]any{
		"id": msgID, "type": "message", "role": "assistant", "model": model,
		"content": []any{
			map[string]any{
				"type": "server_tool_use", "id": toolUseID,
				"name": toolNameWebSearch, "input": map[string]string{"query": query},
			},
			map[string]any{
				"type": "web_search_tool_result", "tool_use_id": toolUseID,
				"content": buildSearchResultBlocks(resp.Results),
			},
			map[string]any{"type": "text", "text": textSummary},
		},
		"stop_reason": "end_turn", "stop_sequence": nil,
		"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
	}

	body, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
	}
	c.Data(http.StatusOK, "application/json", body)

	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func buildTextSummary(query string, results []websearch.SearchResult) string {
|
||||
if len(results) == 0 {
|
||||
return "No search results found for: " + query
|
||||
}
|
||||
var sb strings.Builder
|
||||
fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
|
||||
for i, r := range results {
|
||||
fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
// The detector must accept any single web-search tool (by type prefix,
// google_search type, or known tool names) and reject everything else.

func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}

// More than one tool means the request is not web-search-only.
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody(
		[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}

func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}

func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}

// A non-array "tools" value must be rejected, not crash.
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}

// --- extractSearchQueryFromBody ---
// The query is the last message's text, and only when that message is a user
// message; malformed or empty message lists yield "".

func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"what is golang"}]}`
	require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
	require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}

// Only the last message counts, not earlier user turns.
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
	require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}

func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}

// Empty text blocks and non-text blocks are skipped in array content.
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
	require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
|
||||
|
||||
// --- buildSearchResultBlocks ---
|
||||
|
||||
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
results := []websearch.SearchResult{
|
||||
{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
|
||||
{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
|
||||
}
|
||||
blocks := buildSearchResultBlocks(results)
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks(nil)
|
||||
require.Empty(t, blocks)
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
|
||||
func TestBuildTextSummary_WithResults(t *testing.T) {
|
||||
results := []websearch.SearchResult{
|
||||
{URL: "https://a.com", Title: "A", Snippet: "desc a"},
|
||||
}
|
||||
summary := buildTextSummary("test query", results)
|
||||
require.Contains(t, summary, "test query")
|
||||
require.Contains(t, summary, "1. **A**")
|
||||
require.Contains(t, summary, "https://a.com")
|
||||
}
|
||||
|
||||
func TestBuildTextSummary_NoResults(t *testing.T) {
|
||||
summary := buildTextSummary("test", nil)
|
||||
require.Contains(t, summary, "No search results found for: test")
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
@@ -106,6 +107,7 @@ type SettingService struct {
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
version string // Application version
|
||||
webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
@@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||
|
||||
// Web search emulation: quick enabled check from the JSON config
|
||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||
var wsCfg WebSearchEmulationConfig
|
||||
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
|
||||
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -106,6 +106,9 @@ type SystemSettings struct {
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
|
||||
// Web Search Emulation (read-only quick check; full config via dedicated API)
|
||||
WebSearchEmulationEnabled bool
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
|
||||
253
backend/internal/service/websearch_config.go
Normal file
253
backend/internal/service/websearch_config.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// WebSearchEmulationConfig holds the global web search emulation configuration.
// It is persisted as one JSON value under SettingKeyWebSearchEmulationConfig
// and cached in-process (see GetWebSearchEmulationConfig).
type WebSearchEmulationConfig struct {
	Enabled   bool                      `json:"enabled"`   // global on/off switch
	Providers []WebSearchProviderConfig `json:"providers"` // validated: max 10 entries, unique types
}
|
||||
|
||||
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
// APIKey is effectively write-only at the API boundary: SanitizeWebSearchConfig
// blanks it and sets APIKeyConfigured before the config is returned to clients,
// and SaveWebSearchEmulationConfig restores the stored key when an update
// arrives with an empty one (see mergeExistingAPIKeys).
type WebSearchProviderConfig struct {
	Type                 string `json:"type"`                   // websearch.ProviderTypeBrave | websearch.ProviderTypeTavily
	APIKey               string `json:"api_key,omitempty"`      // secret — omitted in API responses; empty on save keeps the stored key
	APIKeyConfigured     bool   `json:"api_key_configured"`     // read-only mask: true when a key is stored
	Priority             int    `json:"priority"`               // lower = higher priority
	QuotaLimit           int64  `json:"quota_limit"`            // 0 = unlimited; negative values rejected by validation
	QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*; "" defaults to monthly
	QuotaUsed            int64  `json:"quota_used,omitempty"`   // read-only: current period usage
	ProxyID              *int64 `json:"proxy_id"`               // optional proxy association
	ExpiresAt            *int64 `json:"expires_at,omitempty"`   // optional expiration timestamp
}
|
||||
|
||||
// --- Validation ---

// maxWebSearchProviders caps the provider list accepted by validateWebSearchConfig.
const maxWebSearchProviders = 10

// validProviderTypes is the set of provider type strings accepted by validation.
var validProviderTypes = map[string]bool{
	websearch.ProviderTypeBrave:  true,
	websearch.ProviderTypeTavily: true,
}

// validQuotaIntervals is the set of accepted quota refresh intervals; the
// empty string is valid and treated as monthly downstream.
var validQuotaIntervals = map[string]bool{
	websearch.QuotaRefreshDaily:   true,
	websearch.QuotaRefreshWeekly:  true,
	websearch.QuotaRefreshMonthly: true,
	"":                            true, // defaults to monthly
}
|
||||
|
||||
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
if len(cfg.Providers) > maxWebSearchProviders {
|
||||
return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
|
||||
}
|
||||
seen := make(map[string]bool, len(cfg.Providers))
|
||||
for i, p := range cfg.Providers {
|
||||
if !validProviderTypes[p.Type] {
|
||||
return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
|
||||
}
|
||||
if !validQuotaIntervals[p.QuotaRefreshInterval] {
|
||||
return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval)
|
||||
}
|
||||
if p.QuotaLimit < 0 {
|
||||
return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
|
||||
}
|
||||
if seen[p.Type] {
|
||||
return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type)
|
||||
}
|
||||
seen[p.Type] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- In-process cache (same pattern as gateway forwarding settings) ---

// sfKeyWebSearchConfig is the singleflight key used to collapse concurrent
// cache-miss loads into a single settings-repo read.
const sfKeyWebSearchConfig = "web_search_emulation_config"

// cachedWebSearchEmulationConfig is the entry stored in webSearchEmulationCache.
type cachedWebSearchEmulationConfig struct {
	config    *WebSearchEmulationConfig
	expiresAt int64 // unix nano
}

var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
var webSearchEmulationSF singleflight.Group

const (
	webSearchEmulationCacheTTL  = 60 * time.Second // TTL for successfully loaded configs
	webSearchEmulationErrorTTL  = 5 * time.Second  // short TTL after a load error, so recovery retries soon
	webSearchEmulationDBTimeout = 5 * time.Second  // upper bound on the settings-repo read
)
|
||||
|
||||
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
|
||||
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
|
||||
if cached := webSearchEmulationCache.Load(); cached != nil {
|
||||
c := cached.(*cachedWebSearchEmulationConfig)
|
||||
if time.Now().UnixNano() < c.expiresAt {
|
||||
return c.config, nil
|
||||
}
|
||||
}
|
||||
result, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
|
||||
return s.loadWebSearchConfigFromDB()
|
||||
})
|
||||
if err != nil {
|
||||
return &WebSearchEmulationConfig{}, err
|
||||
}
|
||||
return result.(*WebSearchEmulationConfig), nil
|
||||
}
|
||||
|
||||
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
|
||||
dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
|
||||
if err != nil {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: &WebSearchEmulationConfig{},
|
||||
expiresAt: time.Now().Add(webSearchEmulationErrorTTL).UnixNano(),
|
||||
})
|
||||
return &WebSearchEmulationConfig{}, err
|
||||
}
|
||||
cfg := parseWebSearchConfigJSON(raw)
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: cfg,
|
||||
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||
})
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
|
||||
cfg := &WebSearchEmulationConfig{}
|
||||
if raw == "" {
|
||||
return cfg
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
slog.Warn("websearch: failed to parse config JSON", "error", err)
|
||||
return &WebSearchEmulationConfig{}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// SaveWebSearchEmulationConfig validates and persists the configuration.
|
||||
// Empty API keys in the input are preserved from the existing config.
|
||||
func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
|
||||
if err := validateWebSearchConfig(cfg); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
|
||||
}
|
||||
s.mergeExistingAPIKeys(ctx, cfg)
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websearch: marshal config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
|
||||
return fmt.Errorf("websearch: save config: %w", err)
|
||||
}
|
||||
// Invalidate: forget singleflight first, then store new value
|
||||
webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: cfg,
|
||||
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||
})
|
||||
|
||||
// Hot-reload: rebuild the global Manager with new config
|
||||
s.RebuildWebSearchManager(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
|
||||
func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
|
||||
existing, _ := s.getWebSearchEmulationConfigRaw(ctx)
|
||||
if existing == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
existingByType := make(map[string]string, len(existing.Providers))
|
||||
for _, p := range existing.Providers {
|
||||
if p.APIKey != "" {
|
||||
existingByType[p.Type] = p.APIKey
|
||||
}
|
||||
}
|
||||
for i := range cfg.Providers {
|
||||
if cfg.Providers[i].APIKey == "" {
|
||||
if key, ok := existingByType[cfg.Providers[i].Type]; ok {
|
||||
cfg.Providers[i].APIKey = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getWebSearchEmulationConfigRaw reads the config straight from the settings
// repo, bypassing the in-process cache, so mergeExistingAPIKeys always sees
// the latest persisted API keys rather than a possibly stale cached copy.
func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
	if err != nil {
		return nil, err
	}
	return parseWebSearchConfigJSON(raw), nil
}
|
||||
|
||||
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
|
||||
func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
|
||||
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cfg.Enabled && len(cfg.Providers) > 0
|
||||
}
|
||||
|
||||
// SetWebSearchRedisClient injects the Redis client used for quota tracking.
// Call after construction, before first use. Triggers the initial Manager
// build so startup and post-save hot-reload share one code path.
// NOTE(review): s.webSearchRedis is written without synchronization — assumes
// this is called once during startup before concurrent use; confirm in wiring.
func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
	s.webSearchRedis = redisClient
	s.RebuildWebSearchManager(ctx)
}
|
||||
|
||||
// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
|
||||
// Called on startup and after SaveWebSearchEmulationConfig.
|
||||
func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
|
||||
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||
if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
|
||||
SetWebSearchManager(nil)
|
||||
return
|
||||
}
|
||||
providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
|
||||
for _, p := range cfg.Providers {
|
||||
providerConfigs = append(providerConfigs, websearch.ProviderConfig{
|
||||
Type: p.Type,
|
||||
APIKey: p.APIKey,
|
||||
Priority: p.Priority,
|
||||
QuotaLimit: p.QuotaLimit,
|
||||
QuotaRefreshInterval: p.QuotaRefreshInterval,
|
||||
ExpiresAt: p.ExpiresAt,
|
||||
})
|
||||
}
|
||||
SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
|
||||
slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
|
||||
}
|
||||
|
||||
// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
|
||||
func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
out := *cfg
|
||||
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
|
||||
for i, p := range cfg.Providers {
|
||||
out.Providers[i] = p
|
||||
out.Providers[i].APIKeyConfigured = p.APIKey != ""
|
||||
out.Providers[i].APIKey = "" // never return the secret
|
||||
}
|
||||
return &out
|
||||
}
|
||||
148
backend/internal/service/websearch_config_test.go
Normal file
148
backend/internal/service/websearch_config_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- validateWebSearchConfig ---
|
||||
|
||||
func TestValidateWebSearchConfig_Nil(t *testing.T) {
|
||||
require.NoError(t, validateWebSearchConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_Valid(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
|
||||
{Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
|
||||
for i := range cfg.Providers {
|
||||
cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
|
||||
}
|
||||
err := validateWebSearchConfig(cfg)
|
||||
require.ErrorContains(t, err, "too many providers")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "bing"}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", Priority: 1},
|
||||
{Type: "brave", Priority: 2},
|
||||
},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
// --- parseWebSearchConfigJSON ---
|
||||
|
||||
func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
|
||||
raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
|
||||
cfg := parseWebSearchConfigJSON(raw)
|
||||
require.True(t, cfg.Enabled)
|
||||
require.Len(t, cfg.Providers, 1)
|
||||
require.Equal(t, "brave", cfg.Providers[0].Type)
|
||||
}
|
||||
|
||||
func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
|
||||
cfg := parseWebSearchConfigJSON("")
|
||||
require.False(t, cfg.Enabled)
|
||||
require.Empty(t, cfg.Providers)
|
||||
}
|
||||
|
||||
func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
|
||||
cfg := parseWebSearchConfigJSON("not{json")
|
||||
require.False(t, cfg.Enabled)
|
||||
require.Empty(t, cfg.Providers)
|
||||
}
|
||||
|
||||
// --- SanitizeWebSearchConfig ---
|
||||
|
||||
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-secret-xxx"},
|
||||
},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "", out.Providers[0].APIKey)
|
||||
require.True(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "", out.Providers[0].APIKey)
|
||||
require.False(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
|
||||
require.Nil(t, SanitizeWebSearchConfig(nil))
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
|
||||
},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.True(t, out.Enabled)
|
||||
require.Equal(t, 10, out.Providers[0].Priority)
|
||||
require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
|
||||
}
|
||||
_ = SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "secret", cfg.Providers[0].APIKey)
|
||||
}
|
||||
Reference in New Issue
Block a user