feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
@@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
|
||||
//
|
||||
// 新字段:accounts.extra.openai_passthrough。
|
||||
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
||||
@@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
||||
return resolvedDefault
|
||||
}
|
||||
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
||||
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
|
||||
// 字段:accounts.extra.openai_ws_force_http。
|
||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||
@@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||
}
|
||||
|
||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
|
||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
|
||||
// 字段:accounts.extra.anthropic_passthrough。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||
@@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
|
||||
// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
|
||||
// 字段:accounts.extra.web_search_emulation。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsWebSearchEmulationEnabled() bool {
|
||||
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
|
||||
// 字段:accounts.extra.codex_cli_only。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
||||
|
||||
71
backend/internal/service/account_websearch_test.go
Normal file
71
backend/internal/service/account_websearch_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.True(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: false},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: "true"},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
|
||||
a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
|
||||
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
|
||||
a := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||
}
|
||||
require.False(t, a.IsWebSearchEmulationEnabled())
|
||||
}
|
||||
@@ -49,6 +49,21 @@ type Channel struct {
|
||||
ModelPricing []ChannelModelPricing
|
||||
// 渠道级模型映射(按平台分组:platform → {src→dst})
|
||||
ModelMapping map[string]map[string]string
|
||||
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}})
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
|
||||
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := wse[platform].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// ChannelModelPricing 渠道模型定价条目
|
||||
|
||||
@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
|
||||
}
|
||||
|
||||
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
|
||||
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
|
||||
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
|
||||
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
|
||||
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for j := range ch.ModelPricing {
|
||||
pricing := &ch.ModelPricing[j]
|
||||
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
|
||||
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
|
||||
// antigravity 平台同时服务 Claude 和 Gemini 模型。
|
||||
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
|
||||
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射。
|
||||
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
|
||||
for _, mappingPlatform := range matchingPlatforms(platform) {
|
||||
platformMapping, ok := ch.ModelMapping[mappingPlatform]
|
||||
@@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
||||
}
|
||||
}
|
||||
|
||||
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
|
||||
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
|
||||
func (s *ChannelService) storeErrorCache() {
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
}
|
||||
|
||||
// buildCache 从数据库构建渠道缓存。
|
||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
channels, err := s.repo.ListAll(dbCtx)
|
||||
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
|
||||
if err != nil {
|
||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("list all channels: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := populateChannelCache(channels, groupPlatforms)
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
|
||||
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
|
||||
channels, err := s.repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to build channel cache", "error", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("list all channels: %w", err)
|
||||
}
|
||||
|
||||
// 收集所有 groupID,批量查询 platform
|
||||
var allGroupIDs []int64
|
||||
for i := range channels {
|
||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||
}
|
||||
|
||||
groupPlatforms := make(map[int64]string)
|
||||
if len(allGroupIDs) > 0 {
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
||||
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||
errorCache := newEmptyChannelCache()
|
||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||
s.cache.Store(errorCache)
|
||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
||||
s.storeErrorCache()
|
||||
return nil, nil, fmt.Errorf("get group platforms: %w", err)
|
||||
}
|
||||
}
|
||||
return channels, groupPlatforms, nil
|
||||
}
|
||||
|
||||
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
|
||||
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
|
||||
cache := newEmptyChannelCache()
|
||||
cache.groupPlatform = groupPlatforms
|
||||
cache.byID = make(map[int64]*Channel, len(channels))
|
||||
@@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
for i := range channels {
|
||||
ch := &channels[i]
|
||||
cache.byID[ch.ID] = ch
|
||||
|
||||
for _, gid := range ch.GroupIDs {
|
||||
cache.channelByGroupID[gid] = ch
|
||||
platform := groupPlatforms[gid]
|
||||
@@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
||||
|
||||
s.cache.Store(cache)
|
||||
return cache, nil
|
||||
return cache
|
||||
}
|
||||
|
||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||
|
||||
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
|
||||
// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型,
|
||||
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
|
||||
// 各平台(antigravity / anthropic / gemini / openai)严格独立,不跨平台匹配。
|
||||
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
|
||||
if groupPlatform == pricingPlatform {
|
||||
return true
|
||||
}
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
|
||||
}
|
||||
return false
|
||||
return groupPlatform == pricingPlatform
|
||||
}
|
||||
|
||||
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
|
||||
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
|
||||
// 各平台严格独立,只返回自身。
|
||||
func matchingPlatforms(groupPlatform string) []string {
|
||||
if groupPlatform == PlatformAntigravity {
|
||||
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
|
||||
}
|
||||
return []string{groupPlatform}
|
||||
}
|
||||
func (s *ChannelService) invalidateCache() {
|
||||
@@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
|
||||
return ""
|
||||
}
|
||||
|
||||
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
|
||||
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
|
||||
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
|
||||
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
|
||||
// lookupPricingAcrossPlatforms 在分组平台内查找模型定价。
|
||||
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
|
||||
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
|
||||
@@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
|
||||
return nil
|
||||
}
|
||||
|
||||
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
|
||||
// lookupMappingAcrossPlatforms 在分组平台内查找模型映射。
|
||||
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
|
||||
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
|
||||
for _, p := range matchingPlatforms(groupPlatform) {
|
||||
@@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
|
||||
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
|
||||
// 确保跨平台同名模型各自独立匹配。
|
||||
// 各平台严格独立,只在本平台内查找定价。
|
||||
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
@@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
||||
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
|
||||
}
|
||||
if lk == nil {
|
||||
return false
|
||||
}
|
||||
@@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
|
||||
}
|
||||
|
||||
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
|
||||
// antigravity 分组依次尝试所有匹配平台的定价列表。
|
||||
// 只在本平台的定价列表中查找。
|
||||
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
|
||||
if !lk.channel.RestrictModels {
|
||||
return false
|
||||
@@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||
return newBody
|
||||
}
|
||||
|
||||
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
|
||||
// Create 和 Update 共用此函数,避免重复。
|
||||
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
|
||||
if err := validateNoConflictingModels(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePricingIntervals(pricing); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateNoConflictingMappings(mapping); err != nil {
|
||||
return err
|
||||
}
|
||||
return validatePricingBillingMode(pricing)
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
|
||||
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
if err := checkBillingModeRequirements(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkPricesNotNegative(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkIntervalsHavePrices(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkBillingModeRequirements(p ChannelModelPricing) error {
|
||||
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return infraerrors.BadRequest(
|
||||
"BILLING_MODE_MISSING_PRICE",
|
||||
"per-request price or intervals required for per_request/image billing mode",
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkPricesNotNegative(p ChannelModelPricing) error {
|
||||
checks := []struct {
|
||||
field string
|
||||
val *float64
|
||||
}{
|
||||
{"input_price", p.InputPrice},
|
||||
{"output_price", p.OutputPrice},
|
||||
{"cache_write_price", p.CacheWritePrice},
|
||||
{"cache_read_price", p.CacheReadPrice},
|
||||
{"image_output_price", p.ImageOutputPrice},
|
||||
{"per_request_price", p.PerRequestPrice},
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.val != nil && *c.val < 0 {
|
||||
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIntervalsHavePrices(p ChannelModelPricing) error {
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return infraerrors.BadRequest(
|
||||
"INTERVAL_MISSING_PRICE",
|
||||
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatMaxTokens(max *int) string {
|
||||
if max == nil {
|
||||
return "∞"
|
||||
}
|
||||
return fmt.Sprintf("%d", *max)
|
||||
}
|
||||
|
||||
// --- CRUD ---
|
||||
|
||||
// Create 创建渠道
|
||||
@@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if len(input.GroupIDs) > 0 {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channel := &Channel{
|
||||
@@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
ModelPricing: input.ModelPricing,
|
||||
ModelMapping: input.ModelMapping,
|
||||
Features: input.Features,
|
||||
FeaturesConfig: input.FeaturesConfig,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
if input.Features != nil {
|
||||
channel.Features = *input.Features
|
||||
}
|
||||
|
||||
// 检查分组冲突
|
||||
if input.GroupIDs != nil {
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return nil, ErrGroupAlreadyInChannel
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
|
||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
||||
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
||||
var oldGroupIDs []int64
|
||||
if s.authCacheInvalidator != nil {
|
||||
var err2 error
|
||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
||||
if err2 != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
||||
}
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldGroupIDs := s.getOldGroupIDs(ctx, id)
|
||||
|
||||
if err := s.repo.Update(ctx, channel); err != nil {
|
||||
return nil, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
// 失效新旧分组的 auth 缓存
|
||||
if s.authCacheInvalidator != nil {
|
||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
||||
for _, gid := range oldGroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
for _, gid := range channel.GroupIDs {
|
||||
if _, ok := seen[gid]; !ok {
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
|
||||
if input.Name != "" && input.Name != channel.Name {
|
||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check channel exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return ErrChannelExists
|
||||
}
|
||||
channel.Name = input.Name
|
||||
}
|
||||
if input.Description != nil {
|
||||
channel.Description = *input.Description
|
||||
}
|
||||
if input.Status != "" {
|
||||
channel.Status = input.Status
|
||||
}
|
||||
if input.RestrictModels != nil {
|
||||
channel.RestrictModels = *input.RestrictModels
|
||||
}
|
||||
if input.Features != nil {
|
||||
channel.Features = *input.Features
|
||||
}
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
channel.GroupIDs = *input.GroupIDs
|
||||
}
|
||||
if input.ModelPricing != nil {
|
||||
channel.ModelPricing = *input.ModelPricing
|
||||
}
|
||||
if input.ModelMapping != nil {
|
||||
channel.ModelMapping = input.ModelMapping
|
||||
}
|
||||
if input.BillingModelSource != "" {
|
||||
channel.BillingModelSource = input.BillingModelSource
|
||||
}
|
||||
if input.FeaturesConfig != nil {
|
||||
channel.FeaturesConfig = input.FeaturesConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
|
||||
// channelID 为当前渠道 ID(Create 时传 0)。
|
||||
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check group conflicts: %w", err)
|
||||
}
|
||||
if len(conflicting) > 0 {
|
||||
return ErrGroupAlreadyInChannel
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
|
||||
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return nil
|
||||
}
|
||||
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
|
||||
}
|
||||
return oldGroupIDs
|
||||
}
|
||||
|
||||
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
|
||||
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
|
||||
if s.authCacheInvalidator == nil {
|
||||
return
|
||||
}
|
||||
seen := make(map[int64]struct{})
|
||||
for _, ids := range groupIDSets {
|
||||
for _, gid := range ids {
|
||||
if _, ok := seen[gid]; ok {
|
||||
continue
|
||||
}
|
||||
seen[gid] = struct{}{}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete 删除渠道
|
||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
// 先获取关联分组用于失效缓存
|
||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||
@@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, gid := range groupIDs {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||
}
|
||||
}
|
||||
s.invalidateAuthCacheForGroups(ctx, groupIDs)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -847,6 +930,7 @@ type CreateChannelInput struct {
|
||||
BillingModelSource string
|
||||
RestrictModels bool
|
||||
Features string
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
// UpdateChannelInput 更新渠道输入
|
||||
@@ -860,4 +944,5 @@ type UpdateChannelInput struct {
|
||||
BillingModelSource string
|
||||
RestrictModels *bool
|
||||
Features *string
|
||||
FeaturesConfig map[string]any
|
||||
}
|
||||
|
||||
62
backend/internal/service/channel_websearch_test.go
Normal file
62
backend/internal/service/channel_websearch_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||
},
|
||||
}
|
||||
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
|
||||
c := &Channel{FeaturesConfig: nil}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
|
||||
var c *Channel
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: true, // not a map
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
|
||||
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
|
||||
}
|
||||
@@ -249,6 +249,10 @@ const (
|
||||
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
|
||||
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
|
||||
SettingKeyEnableCCHSigning = "enable_cch_signing"
|
||||
|
||||
// Web Search Emulation
|
||||
// SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置(JSON)
|
||||
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
|
||||
@@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
passthroughBody := parsed.Body
|
||||
passthroughModel := parsed.Model
|
||||
|
||||
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
358
backend/internal/service/gateway_websearch_emulation.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// Web search emulation constants
|
||||
const (
|
||||
toolTypeWebSearchPrefix = "web_search"
|
||||
toolTypeGoogleSearch = "google_search"
|
||||
toolNameWebSearch = "web_search"
|
||||
toolNameGoogleSearch = "google_search"
|
||||
toolNameWebSearch2025 = "web_search_20250305"
|
||||
|
||||
webSearchDefaultMaxResults = 5
|
||||
defaultWebSearchModel = "claude-sonnet-4-6"
|
||||
webSearchMsgIDPrefix = "msg_ws_"
|
||||
webSearchToolUseIDPrefix = "srvtoolu_ws_"
|
||||
tokenEstimateDivisor = 4
|
||||
|
||||
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
|
||||
featureKeyWebSearchEmulation = "web_search_emulation"
|
||||
)
|
||||
|
||||
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
|
||||
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
|
||||
|
||||
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
|
||||
func SetWebSearchManager(m *websearch.Manager) {
|
||||
webSearchManagerPtr.Store(m)
|
||||
}
|
||||
|
||||
func getWebSearchManager() *websearch.Manager {
|
||||
return webSearchManagerPtr.Load()
|
||||
}
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
|
||||
// Note: channel-level control is enforced via the account's extra field; the channel toggle
|
||||
// in the admin UI sets the account's flag for all accounts in that channel's groups.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
}
|
||||
if !isOnlyWebSearchToolInBody(body) {
|
||||
return false
|
||||
}
|
||||
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if !account.IsWebSearchEmulationEnabled() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
func isOnlyWebSearchToolInBody(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
return false
|
||||
}
|
||||
return isWebSearchToolJSON(arr[0])
|
||||
}
|
||||
|
||||
func isWebSearchToolJSON(tool gjson.Result) bool {
|
||||
toolType := tool.Get("type").String()
|
||||
if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
|
||||
return true
|
||||
}
|
||||
switch tool.Get("name").String() {
|
||||
case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractSearchQueryFromBody extracts the last user message text as the search query.
|
||||
func extractSearchQueryFromBody(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
arr := messages.Array()
|
||||
if len(arr) == 0 {
|
||||
return ""
|
||||
}
|
||||
lastMsg := arr[len(arr)-1]
|
||||
if lastMsg.Get("role").String() != "user" {
|
||||
return ""
|
||||
}
|
||||
return extractWebSearchTextFromContent(lastMsg.Get("content"))
|
||||
}
|
||||
|
||||
func extractWebSearchTextFromContent(content gjson.Result) string {
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
if content.IsArray() {
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
if text := block.Get("text").String(); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// handleWebSearchEmulation intercepts a web-search-only request,
|
||||
// calls a third-party search API, and constructs an Anthropic-format response.
|
||||
func (s *GatewayService) handleWebSearchEmulation(
|
||||
ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
|
||||
) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Release the serial queue lock immediately — we don't need upstream.
|
||||
if parsed.OnUpstreamAccepted != nil {
|
||||
parsed.OnUpstreamAccepted()
|
||||
}
|
||||
|
||||
query := extractSearchQueryFromBody(parsed.Body)
|
||||
if query == "" {
|
||||
return nil, fmt.Errorf("web search emulation: no query found in messages")
|
||||
}
|
||||
|
||||
slog.Info("web search emulation: executing search",
|
||||
"account_id", account.ID, "account_name", account.Name, "query", query)
|
||||
|
||||
resp, providerName, err := doWebSearch(ctx, account, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("web search emulation: search completed",
|
||||
"provider", providerName, "results_count", len(resp.Results))
|
||||
|
||||
model := parsed.Model
|
||||
if model == "" {
|
||||
model = defaultWebSearchModel
|
||||
}
|
||||
|
||||
if parsed.Stream {
|
||||
return writeWebSearchStreamResponse(c, query, resp, model, startTime)
|
||||
}
|
||||
return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
|
||||
}
|
||||
|
||||
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
|
||||
proxyURL := resolveAccountProxyURL(account)
|
||||
mgr := getWebSearchManager()
|
||||
if mgr == nil {
|
||||
return nil, "", fmt.Errorf("web search emulation: manager not initialized")
|
||||
}
|
||||
resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
|
||||
Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("web search emulation: search failed", "error", err)
|
||||
return nil, "", fmt.Errorf("web search emulation: %w", err)
|
||||
}
|
||||
return resp, providerName, nil
|
||||
}
|
||||
|
||||
func resolveAccountProxyURL(account *Account) string {
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
return account.Proxy.URL()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --- SSE streaming response ---
|
||||
|
||||
func writeWebSearchStreamResponse(
|
||||
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
msgID := webSearchMsgIDPrefix + uuid.New().String()
|
||||
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
|
||||
|
||||
setSSEHeaders(c)
|
||||
if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
|
||||
return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
|
||||
}
|
||||
writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
|
||||
writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
|
||||
textSummary := buildTextSummary(query, resp.Results)
|
||||
writeSSETextBlock(c.Writer, textSummary, 2)
|
||||
writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
|
||||
c.Writer.Flush()
|
||||
|
||||
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
|
||||
}
|
||||
|
||||
func setSSEHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
|
||||
evt := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
return flushSSEJSON(w, "message_start", evt)
|
||||
}
|
||||
|
||||
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
|
||||
start := map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
},
|
||||
}
|
||||
_ = flushSSEJSON(w, "content_block_start", start)
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
|
||||
start := map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{
|
||||
"type": "web_search_tool_result", "tool_use_id": toolUseID,
|
||||
"content": buildSearchResultBlocks(results),
|
||||
},
|
||||
}
|
||||
_ = flushSSEJSON(w, "content_block_start", start)
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
|
||||
_ = flushSSEJSON(w, "content_block_start", map[string]any{
|
||||
"type": "content_block_start", "index": index,
|
||||
"content_block": map[string]any{"type": "text", "text": ""},
|
||||
})
|
||||
_ = flushSSEJSON(w, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta", "index": index,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
})
|
||||
_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
|
||||
}
|
||||
|
||||
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
|
||||
_ = flushSSEJSON(w, "message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
|
||||
"usage": map[string]int{"output_tokens": outputTokens},
|
||||
})
|
||||
_ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
|
||||
}
|
||||
|
||||
// flushSSEJSON marshals data to JSON and writes an SSE event. Returns error on marshal failure.
|
||||
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
slog.Error("web search emulation: failed to marshal SSE event",
|
||||
"event", event, "error", err)
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Non-streaming JSON response ---
|
||||
|
||||
func writeWebSearchNonStreamResponse(
|
||||
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
msgID := webSearchMsgIDPrefix + uuid.New().String()
|
||||
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
|
||||
textSummary := buildTextSummary(query, resp.Results)
|
||||
|
||||
msg := map[string]any{
|
||||
"id": msgID, "type": "message", "role": "assistant", "model": model,
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "server_tool_use", "id": toolUseID,
|
||||
"name": toolNameWebSearch, "input": map[string]string{"query": query},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "web_search_tool_result", "tool_use_id": toolUseID,
|
||||
"content": buildSearchResultBlocks(resp.Results),
|
||||
},
|
||||
map[string]any{"type": "text", "text": textSummary},
|
||||
},
|
||||
"stop_reason": "end_turn", "stop_sequence": nil,
|
||||
"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
|
||||
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
|
||||
blocks := make([]map[string]string, 0, len(results))
|
||||
for _, r := range results {
|
||||
block := map[string]string{
|
||||
"type": "web_search_result",
|
||||
"url": r.URL,
|
||||
"title": r.Title,
|
||||
}
|
||||
if r.Snippet != "" {
|
||||
block["page_content"] = r.Snippet
|
||||
}
|
||||
if r.PageAge != "" {
|
||||
block["page_age"] = r.PageAge
|
||||
}
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
return blocks
|
||||
}
|
||||
|
||||
func buildTextSummary(query string, results []websearch.SearchResult) string {
|
||||
if len(results) == 0 {
|
||||
return "No search results found for: " + query
|
||||
}
|
||||
var sb strings.Builder
|
||||
fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
|
||||
for i, r := range results {
|
||||
fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
142
backend/internal/service/gateway_websearch_emulation_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- isOnlyWebSearchToolInBody ---
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
|
||||
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
|
||||
require.False(t, isOnlyWebSearchToolInBody(
|
||||
[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
|
||||
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
|
||||
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
|
||||
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
|
||||
}
|
||||
|
||||
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
|
||||
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
|
||||
}
|
||||
|
||||
// --- extractSearchQueryFromBody ---
|
||||
|
||||
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":"what is golang"}]}`
|
||||
require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
|
||||
require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
|
||||
require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
|
||||
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
|
||||
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
|
||||
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
|
||||
require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
|
||||
body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
|
||||
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
|
||||
}
|
||||
|
||||
// --- buildSearchResultBlocks ---
|
||||
|
||||
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
|
||||
results := []websearch.SearchResult{
|
||||
{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
|
||||
{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
|
||||
}
|
||||
blocks := buildSearchResultBlocks(results)
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "web_search_result", blocks[0]["type"])
|
||||
require.Equal(t, "https://a.com", blocks[0]["url"])
|
||||
require.Equal(t, "snippet a", blocks[0]["page_content"])
|
||||
require.Equal(t, "2 days", blocks[0]["page_age"])
|
||||
// Second result has no PageAge
|
||||
require.Equal(t, "https://b.com", blocks[1]["url"])
|
||||
_, hasPageAge := blocks[1]["page_age"]
|
||||
require.False(t, hasPageAge)
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks(nil)
|
||||
require.Empty(t, blocks)
|
||||
}
|
||||
|
||||
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
|
||||
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
|
||||
_, hasContent := blocks[0]["page_content"]
|
||||
require.False(t, hasContent)
|
||||
}
|
||||
|
||||
// --- buildTextSummary ---
|
||||
|
||||
func TestBuildTextSummary_WithResults(t *testing.T) {
|
||||
results := []websearch.SearchResult{
|
||||
{URL: "https://a.com", Title: "A", Snippet: "desc a"},
|
||||
}
|
||||
summary := buildTextSummary("test query", results)
|
||||
require.Contains(t, summary, "test query")
|
||||
require.Contains(t, summary, "1. **A**")
|
||||
require.Contains(t, summary, "https://a.com")
|
||||
}
|
||||
|
||||
func TestBuildTextSummary_NoResults(t *testing.T) {
|
||||
summary := buildTextSummary("test", nil)
|
||||
require.Contains(t, summary, "No search results found for: test")
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
@@ -106,6 +107,7 @@ type SettingService struct {
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
version string // Application version
|
||||
webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
@@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
|
||||
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
|
||||
|
||||
// Web search emulation: quick enabled check from the JSON config
|
||||
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
|
||||
var wsCfg WebSearchEmulationConfig
|
||||
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
|
||||
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -106,6 +106,9 @@ type SystemSettings struct {
|
||||
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
|
||||
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
|
||||
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
|
||||
|
||||
// Web Search Emulation (read-only quick check; full config via dedicated API)
|
||||
WebSearchEmulationEnabled bool
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
|
||||
253
backend/internal/service/websearch_config.go
Normal file
253
backend/internal/service/websearch_config.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// WebSearchEmulationConfig holds the global web search emulation configuration.
|
||||
type WebSearchEmulationConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Providers []WebSearchProviderConfig `json:"providers"`
|
||||
}
|
||||
|
||||
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
|
||||
type WebSearchProviderConfig struct {
|
||||
Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
|
||||
APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
|
||||
APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
|
||||
Priority int `json:"priority"` // lower = higher priority
|
||||
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||
QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
|
||||
QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current period usage
|
||||
ProxyID *int64 `json:"proxy_id"` // optional proxy association
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
|
||||
}
|
||||
|
||||
// --- Validation ---
|
||||
|
||||
const maxWebSearchProviders = 10
|
||||
|
||||
var validProviderTypes = map[string]bool{
|
||||
websearch.ProviderTypeBrave: true,
|
||||
websearch.ProviderTypeTavily: true,
|
||||
}
|
||||
|
||||
var validQuotaIntervals = map[string]bool{
|
||||
websearch.QuotaRefreshDaily: true,
|
||||
websearch.QuotaRefreshWeekly: true,
|
||||
websearch.QuotaRefreshMonthly: true,
|
||||
"": true, // defaults to monthly
|
||||
}
|
||||
|
||||
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
if len(cfg.Providers) > maxWebSearchProviders {
|
||||
return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
|
||||
}
|
||||
seen := make(map[string]bool, len(cfg.Providers))
|
||||
for i, p := range cfg.Providers {
|
||||
if !validProviderTypes[p.Type] {
|
||||
return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
|
||||
}
|
||||
if !validQuotaIntervals[p.QuotaRefreshInterval] {
|
||||
return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, p.QuotaRefreshInterval)
|
||||
}
|
||||
if p.QuotaLimit < 0 {
|
||||
return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
|
||||
}
|
||||
if seen[p.Type] {
|
||||
return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type)
|
||||
}
|
||||
seen[p.Type] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- In-process cache (same pattern as gateway forwarding settings) ---
|
||||
|
||||
const sfKeyWebSearchConfig = "web_search_emulation_config"
|
||||
|
||||
type cachedWebSearchEmulationConfig struct {
|
||||
config *WebSearchEmulationConfig
|
||||
expiresAt int64 // unix nano
|
||||
}
|
||||
|
||||
var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
|
||||
var webSearchEmulationSF singleflight.Group
|
||||
|
||||
const (
|
||||
webSearchEmulationCacheTTL = 60 * time.Second
|
||||
webSearchEmulationErrorTTL = 5 * time.Second
|
||||
webSearchEmulationDBTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
|
||||
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
|
||||
if cached := webSearchEmulationCache.Load(); cached != nil {
|
||||
c := cached.(*cachedWebSearchEmulationConfig)
|
||||
if time.Now().UnixNano() < c.expiresAt {
|
||||
return c.config, nil
|
||||
}
|
||||
}
|
||||
result, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
|
||||
return s.loadWebSearchConfigFromDB()
|
||||
})
|
||||
if err != nil {
|
||||
return &WebSearchEmulationConfig{}, err
|
||||
}
|
||||
return result.(*WebSearchEmulationConfig), nil
|
||||
}
|
||||
|
||||
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
|
||||
dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
|
||||
if err != nil {
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: &WebSearchEmulationConfig{},
|
||||
expiresAt: time.Now().Add(webSearchEmulationErrorTTL).UnixNano(),
|
||||
})
|
||||
return &WebSearchEmulationConfig{}, err
|
||||
}
|
||||
cfg := parseWebSearchConfigJSON(raw)
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: cfg,
|
||||
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||
})
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
|
||||
cfg := &WebSearchEmulationConfig{}
|
||||
if raw == "" {
|
||||
return cfg
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
slog.Warn("websearch: failed to parse config JSON", "error", err)
|
||||
return &WebSearchEmulationConfig{}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// SaveWebSearchEmulationConfig validates and persists the configuration.
|
||||
// Empty API keys in the input are preserved from the existing config.
|
||||
func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
|
||||
if err := validateWebSearchConfig(cfg); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
|
||||
}
|
||||
s.mergeExistingAPIKeys(ctx, cfg)
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websearch: marshal config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
|
||||
return fmt.Errorf("websearch: save config: %w", err)
|
||||
}
|
||||
// Invalidate: forget singleflight first, then store new value
|
||||
webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
|
||||
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
|
||||
config: cfg,
|
||||
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
|
||||
})
|
||||
|
||||
// Hot-reload: rebuild the global Manager with new config
|
||||
s.RebuildWebSearchManager(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
|
||||
func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
|
||||
existing, _ := s.getWebSearchEmulationConfigRaw(ctx)
|
||||
if existing == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
existingByType := make(map[string]string, len(existing.Providers))
|
||||
for _, p := range existing.Providers {
|
||||
if p.APIKey != "" {
|
||||
existingByType[p.Type] = p.APIKey
|
||||
}
|
||||
}
|
||||
for i := range cfg.Providers {
|
||||
if cfg.Providers[i].APIKey == "" {
|
||||
if key, ok := existingByType[cfg.Providers[i].Type]; ok {
|
||||
cfg.Providers[i].APIKey = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseWebSearchConfigJSON(raw), nil
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
|
||||
func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
|
||||
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cfg.Enabled && len(cfg.Providers) > 0
|
||||
}
|
||||
|
||||
// SetWebSearchRedisClient injects the Redis client used for quota tracking.
|
||||
// Call after construction, before first use. Triggers initial Manager build.
|
||||
func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
|
||||
s.webSearchRedis = redisClient
|
||||
s.RebuildWebSearchManager(ctx)
|
||||
}
|
||||
|
||||
// RebuildWebSearchManager reads the current config and (re)creates the global websearch.Manager.
|
||||
// Called on startup and after SaveWebSearchEmulationConfig.
|
||||
func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
|
||||
cfg, err := s.GetWebSearchEmulationConfig(ctx)
|
||||
if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
|
||||
SetWebSearchManager(nil)
|
||||
return
|
||||
}
|
||||
providerConfigs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
|
||||
for _, p := range cfg.Providers {
|
||||
providerConfigs = append(providerConfigs, websearch.ProviderConfig{
|
||||
Type: p.Type,
|
||||
APIKey: p.APIKey,
|
||||
Priority: p.Priority,
|
||||
QuotaLimit: p.QuotaLimit,
|
||||
QuotaRefreshInterval: p.QuotaRefreshInterval,
|
||||
ExpiresAt: p.ExpiresAt,
|
||||
})
|
||||
}
|
||||
SetWebSearchManager(websearch.NewManager(providerConfigs, s.webSearchRedis))
|
||||
slog.Info("websearch: manager rebuilt", "provider_count", len(providerConfigs))
|
||||
}
|
||||
|
||||
// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
|
||||
func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
out := *cfg
|
||||
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
|
||||
for i, p := range cfg.Providers {
|
||||
out.Providers[i] = p
|
||||
out.Providers[i].APIKeyConfigured = p.APIKey != ""
|
||||
out.Providers[i].APIKey = "" // never return the secret
|
||||
}
|
||||
return &out
|
||||
}
|
||||
148
backend/internal/service/websearch_config_test.go
Normal file
148
backend/internal/service/websearch_config_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- validateWebSearchConfig ---
|
||||
|
||||
func TestValidateWebSearchConfig_Nil(t *testing.T) {
|
||||
require.NoError(t, validateWebSearchConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_Valid(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
|
||||
{Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
|
||||
},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
|
||||
for i := range cfg.Providers {
|
||||
cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
|
||||
}
|
||||
err := validateWebSearchConfig(cfg)
|
||||
require.ErrorContains(t, err, "too many providers")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "bing"}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", Priority: 1},
|
||||
{Type: "brave", Priority: 2},
|
||||
},
|
||||
}
|
||||
require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
|
||||
}
|
||||
require.NoError(t, validateWebSearchConfig(cfg))
|
||||
}
|
||||
|
||||
// --- parseWebSearchConfigJSON ---
|
||||
|
||||
func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
|
||||
raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
|
||||
cfg := parseWebSearchConfigJSON(raw)
|
||||
require.True(t, cfg.Enabled)
|
||||
require.Len(t, cfg.Providers, 1)
|
||||
require.Equal(t, "brave", cfg.Providers[0].Type)
|
||||
}
|
||||
|
||||
func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
|
||||
cfg := parseWebSearchConfigJSON("")
|
||||
require.False(t, cfg.Enabled)
|
||||
require.Empty(t, cfg.Providers)
|
||||
}
|
||||
|
||||
func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
|
||||
cfg := parseWebSearchConfigJSON("not{json")
|
||||
require.False(t, cfg.Enabled)
|
||||
require.Empty(t, cfg.Providers)
|
||||
}
|
||||
|
||||
// --- SanitizeWebSearchConfig ---
|
||||
|
||||
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "sk-secret-xxx"},
|
||||
},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "", out.Providers[0].APIKey)
|
||||
require.True(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "", out.Providers[0].APIKey)
|
||||
require.False(t, out.Providers[0].APIKeyConfigured)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
|
||||
require.Nil(t, SanitizeWebSearchConfig(nil))
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Enabled: true,
|
||||
Providers: []WebSearchProviderConfig{
|
||||
{Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
|
||||
},
|
||||
}
|
||||
out := SanitizeWebSearchConfig(cfg)
|
||||
require.True(t, out.Enabled)
|
||||
require.Equal(t, 10, out.Providers[0].Priority)
|
||||
require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
|
||||
}
|
||||
|
||||
func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
|
||||
cfg := &WebSearchEmulationConfig{
|
||||
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
|
||||
}
|
||||
_ = SanitizeWebSearchConfig(cfg)
|
||||
require.Equal(t, "secret", cfg.Providers[0].APIKey)
|
||||
}
|
||||
Reference in New Issue
Block a user