feat(gateway): add web search emulation for Anthropic API Key accounts

Inject web search capability for Claude Console (API Key) accounts that
don't natively support Anthropic's web_search tool. When a pure
web_search request is detected, the gateway calls Brave Search or Tavily
API directly and constructs an Anthropic-protocol-compliant SSE/JSON
response without forwarding to upstream.

Backend:
- New `pkg/websearch/` SDK: Brave and Tavily provider implementations
  with io.LimitReader, proxy support, and Redis-based quota tracking
  (Lua atomic INCR + TTL, DECR rollback on failure)
- Global config via `settings.web_search_emulation_config` (JSON) with
  in-process cache + singleflight, input validation, API key merge on
  save, and sanitized API responses
- Channel-level toggle via `channels.features_config` JSONB column
  (DB migration 101)
- Account-level toggle via `accounts.extra.web_search_emulation`
- Request interception in `Forward()` with SSE streaming response
  construction using json.Marshal (no manual string concatenation)
- Manager hot-reload: `RebuildWebSearchManager()` called on config save
  and startup via `SetWebSearchRedisClient()`
- 70 unit tests covering providers, manager, config validation,
  sanitization, tool detection, query extraction, and response building

Frontend:
- Settings → Gateway tab: Web Search Emulation config card with global
  toggle, provider list (add/remove, API key, priority, quota, proxy)
- Channels → Anthropic tab: web search emulation toggle with global
  state linkage (disabled when global off)
- Account Create/Edit modals: web search emulation toggle for API Key
  type with Toggle component
- Full i18n coverage (zh + en)
This commit is contained in:
erio
2026-04-12 00:02:26 +08:00
parent c738cfec93
commit 1b53ffcac7
37 changed files with 3507 additions and 238 deletions

View File

@@ -969,7 +969,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用自动透传(仅替换认证)
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"
//
// 新字段accounts.extra.openai_passthrough。
// 兼容字段accounts.extra.openai_oauth_passthrough历史 OAuth 开关)。
@@ -1133,7 +1133,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
// IsOpenAIWSForceHTTPEnabled 返回账号级强制 HTTP开关。
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1158,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用自动透传(仅替换认证)
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"
// 字段accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1169,18 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”
// IsWebSearchEmulationEnabled reports whether web search emulation is turned
// on for this account. It applies only to Anthropic API Key accounts; the
// flag lives at accounts.extra["web_search_emulation"]. A nil account, a nil
// Extra map, a missing key, or a non-bool value all count as disabled.
func (a *Account) IsWebSearchEmulationEnabled() bool {
	if a == nil || a.Extra == nil {
		return false
	}
	if a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey {
		return false
	}
	flag, isBool := a.Extra[featureKeyWebSearchEmulation].(bool)
	return isBool && flag
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {

View File

@@ -0,0 +1,71 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// Happy path: an Anthropic API Key account with the flag set to true.
func TestAccount_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.True(t, a.IsWebSearchEmulationEnabled())
}

// An explicit false flag keeps the feature off.
func TestAccount_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: false},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// A missing key in Extra defaults to disabled.
func TestAccount_IsWebSearchEmulationEnabled_MissingField(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// A non-bool value (here the string "true") is treated as disabled.
func TestAccount_IsWebSearchEmulationEnabled_WrongType(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: "true"},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// A nil Extra map must not panic and reads as disabled.
func TestAccount_IsWebSearchEmulationEnabled_NilExtra(t *testing.T) {
	a := &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey, Extra: nil}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// Calling the method on a nil receiver must not panic.
func TestAccount_IsWebSearchEmulationEnabled_NilAccount(t *testing.T) {
	var a *Account
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// The feature is restricted to the Anthropic platform.
func TestAccount_IsWebSearchEmulationEnabled_NonAnthropicPlatform(t *testing.T) {
	a := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

// The feature is restricted to API Key accounts (OAuth is excluded).
func TestAccount_IsWebSearchEmulationEnabled_NonAPIKeyType(t *testing.T) {
	a := &Account{
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Extra:    map[string]any{featureKeyWebSearchEmulation: true},
	}
	require.False(t, a.IsWebSearchEmulationEnabled())
}

View File

@@ -49,6 +49,21 @@ type Channel struct {
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}}
FeaturesConfig map[string]any
}
// IsWebSearchEmulationEnabled reports whether this channel has web search
// emulation enabled for the given platform. The flag is read from
// FeaturesConfig["web_search_emulation"][platform]; a nil channel, a nil
// config map, a wrongly-shaped entry, or a non-bool value all read as false.
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
	if c == nil || c.FeaturesConfig == nil {
		return false
	}
	raw, present := c.FeaturesConfig[featureKeyWebSearchEmulation]
	if !present {
		return false
	}
	perPlatform, isMap := raw.(map[string]any)
	if !isMap {
		return false
	}
	flag, isBool := perPlatform[platform].(bool)
	return isBool && flag
}
// ChannelModelPricing 渠道模型定价条目

View File

@@ -197,10 +197,8 @@ func newEmptyChannelCache() *channelCache {
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目
// 缓存 key 使用定价条目的原始平台pricing.Platform而非分组平台
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
// 各平台严格独立:antigravity 分组只匹配 antigravity 定价,不会匹配 anthropic/gemini 的定价。
// 查找时通过 lookupPricingAcrossPlatforms() 在本平台内查找。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
@@ -226,8 +224,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型
// 缓存 key 使用映射条目的原始平台mappingPlatform避免跨平台同名映射覆盖。
// 各平台严格独立:antigravity 分组只匹配 antigravity 映射
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
@@ -251,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
// storeErrorCache stores a short-TTL empty cache snapshot so that a DB failure
// does not cause a tight rebuild-retry loop on every request.
// Backdating loadedAt makes the snapshot's remaining TTL equal channelErrorTTL.
func (s *ChannelService) storeErrorCache() {
	errorCache := newEmptyChannelCache()
	// Shift loadedAt into the past so (channelCacheTTL - elapsed) == channelErrorTTL.
	errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
	s.cache.Store(errorCache)
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
channels, err := s.repo.ListAll(dbCtx)
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
return nil, err
}
cache := populateChannelCache(channels, groupPlatforms)
s.cache.Store(cache)
return cache, nil
}
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil {
slog.Warn("failed to build channel cache", "error", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("get group platforms: %w", err)
}
}
return channels, groupPlatforms, nil
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
@@ -293,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
@@ -302,32 +316,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
return cache
}
// invalidateCache 使缓存失效,让下次读取时自然重建
// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。
// antigravity 平台同时服务 Claudeanthropic)和 Geminigemini模型
// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。
// 各平台(antigravity / anthropic / gemini / openai严格独立不跨平台匹配。
func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool {
if groupPlatform == pricingPlatform {
return true
}
if groupPlatform == PlatformAntigravity {
return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini
}
return false
return groupPlatform == pricingPlatform
}
// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。
// matchingPlatforms 返回分组平台对应的可匹配平台列表。
// 各平台严格独立,只返回自身。
func matchingPlatforms(groupPlatform string) []string {
if groupPlatform == PlatformAntigravity {
return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}
}
return []string{groupPlatform}
}
func (s *ChannelService) invalidateCache() {
@@ -364,10 +366,8 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return ""
}
// lookupPricingAcrossPlatforms 在所有匹配平台查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台antigravity → anthropic → gemini
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
// lookupPricingAcrossPlatforms 在分组平台查找模型定价。
// 各平台严格独立,只在本平台内查找(先精确匹配,再通配符)。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
@@ -384,7 +384,7 @@ func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatf
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台查找模型映射。
// lookupMappingAcrossPlatforms 在分组平台查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
@@ -442,8 +442,7 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台antigravity → anthropic → gemini
// 确保跨平台同名模型各自独立匹配。
// 各平台严格独立,只在本平台内查找定价。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
@@ -481,7 +480,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
}
if lk == nil {
return false
}
@@ -524,7 +526,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
}
// checkRestricted 基于已查找的渠道信息检查模型是否被限制。
// antigravity 分组依次尝试所有匹配平台的定价列表。
// 只在本平台的定价列表中查找
func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels {
return false
@@ -552,6 +554,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
// validateChannelConfig validates a channel's pricing and mapping
// configuration: model conflicts, interval sanity, mapping conflicts, and
// billing-mode rules. Create and Update share this function to avoid
// duplicating the checks.
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
	checks := []func() error{
		func() error { return validateNoConflictingModels(pricing) },
		func() error { return validatePricingIntervals(pricing) },
		func() error { return validateNoConflictingMappings(mapping) },
		func() error { return validatePricingBillingMode(pricing) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}
// validatePricingBillingMode checks billing-mode rules for every pricing
// entry: per-request/image modes must carry a price or intervals, no price
// field may be negative, and every interval must set at least one price field.
// The first failing check wins.
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
	validators := []func(ChannelModelPricing) error{
		checkBillingModeRequirements,
		checkPricesNotNegative,
		checkIntervalsHavePrices,
	}
	for _, entry := range pricing {
		for _, validate := range validators {
			if err := validate(entry); err != nil {
				return err
			}
		}
	}
	return nil
}
// checkBillingModeRequirements ensures per_request/image billing modes come
// with either a per-request price or at least one pricing interval.
func checkBillingModeRequirements(p ChannelModelPricing) error {
	perUnitMode := p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage
	if !perUnitMode {
		return nil
	}
	if p.PerRequestPrice != nil || len(p.Intervals) > 0 {
		return nil
	}
	return infraerrors.BadRequest(
		"BILLING_MODE_MISSING_PRICE",
		"per-request price or intervals required for per_request/image billing mode",
	)
}
// checkPricesNotNegative rejects any pricing field carrying a negative value.
// Fields are checked in a fixed order so the reported field is deterministic.
func checkPricesNotNegative(p ChannelModelPricing) error {
	names := []string{
		"input_price",
		"output_price",
		"cache_write_price",
		"cache_read_price",
		"image_output_price",
		"per_request_price",
	}
	values := []*float64{
		p.InputPrice,
		p.OutputPrice,
		p.CacheWritePrice,
		p.CacheReadPrice,
		p.ImageOutputPrice,
		p.PerRequestPrice,
	}
	for i, v := range values {
		if v != nil && *v < 0 {
			return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", names[i]))
		}
	}
	return nil
}
// checkIntervalsHavePrices verifies every pricing interval sets at least one
// price field; an interval with no prices at all is a configuration mistake.
func checkIntervalsHavePrices(p ChannelModelPricing) error {
	for _, iv := range p.Intervals {
		hasAnyPrice := iv.InputPrice != nil || iv.OutputPrice != nil ||
			iv.CacheWritePrice != nil || iv.CacheReadPrice != nil ||
			iv.PerRequestPrice != nil
		if hasAnyPrice {
			continue
		}
		return infraerrors.BadRequest(
			"INTERVAL_MISSING_PRICE",
			fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
				iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
		)
	}
	return nil
}
// formatMaxTokens renders an interval upper bound for error messages;
// a nil bound means "unbounded" and is shown as ∞.
func formatMaxTokens(upper *int) string {
	if upper == nil {
		return "∞"
	}
	return fmt.Sprintf("%d", *upper)
}
// --- CRUD ---
// Create 创建渠道
@@ -564,15 +651,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
// 检查分组冲突
if len(input.GroupIDs) > 0 {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
return nil, err
}
channel := &Channel{
@@ -585,18 +665,13 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
Features: input.Features,
FeaturesConfig: input.FeaturesConfig,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
@@ -620,105 +695,118 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
if input.Features != nil {
channel.Features = *input.Features
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
// 先获取旧分组Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
// applyUpdateInput copies the fields set on an update request onto the channel
// entity in place. Optional fields use pointer/nil (or empty-string) sentinels
// so callers can express "leave unchanged" by omitting them. Returns an error
// when the new name collides with another channel or when any of the new
// groups already belongs to a different channel; on error the channel may be
// partially mutated, so callers must not persist it.
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
	// A rename must stay unique across all other channels.
	if input.Name != "" && input.Name != channel.Name {
		exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
		if err != nil {
			return fmt.Errorf("check channel exists: %w", err)
		}
		if exists {
			return ErrChannelExists
		}
		channel.Name = input.Name
	}
	if input.Description != nil {
		channel.Description = *input.Description
	}
	if input.Status != "" {
		channel.Status = input.Status
	}
	if input.RestrictModels != nil {
		channel.RestrictModels = *input.RestrictModels
	}
	if input.Features != nil {
		channel.Features = *input.Features
	}
	// Group reassignment is validated before being applied: a group may belong
	// to at most one channel.
	if input.GroupIDs != nil {
		if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
			return err
		}
		channel.GroupIDs = *input.GroupIDs
	}
	if input.ModelPricing != nil {
		channel.ModelPricing = *input.ModelPricing
	}
	if input.ModelMapping != nil {
		channel.ModelMapping = input.ModelMapping
	}
	if input.BillingModelSource != "" {
		channel.BillingModelSource = input.BillingModelSource
	}
	if input.FeaturesConfig != nil {
		channel.FeaturesConfig = input.FeaturesConfig
	}
	return nil
}
// checkGroupConflicts reports an error when any of groupIDs is already bound
// to a different channel. channelID identifies the current channel and is
// excluded from the conflict scan (pass 0 on Create).
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
	if len(groupIDs) == 0 {
		return nil
	}
	taken, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
	if err != nil {
		return fmt.Errorf("check group conflicts: %w", err)
	}
	if len(taken) == 0 {
		return nil
	}
	return ErrGroupAlreadyInChannel
}
// getOldGroupIDs loads the channel's group associations before an update so
// their auth caches can be invalidated afterwards. Returns nil when no
// invalidator is wired (nothing to invalidate); on a repo failure it logs a
// warning and returns whatever was loaded (best effort).
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
	if s.authCacheInvalidator == nil {
		return nil
	}
	ids, err := s.repo.GetGroupIDs(ctx, channelID)
	if err != nil {
		slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
	}
	return ids
}
// invalidateAuthCacheForGroups deduplicates the given group-ID sets and
// invalidates the auth cache once per distinct group. It is a no-op when no
// invalidator is configured.
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
	if s.authCacheInvalidator == nil {
		return
	}
	done := make(map[int64]struct{})
	for _, set := range groupIDSets {
		for _, gid := range set {
			if _, dup := done[gid]; dup {
				continue
			}
			done[gid] = struct{}{}
			s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
		}
	}
}
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
@@ -729,12 +817,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
s.invalidateAuthCacheForGroups(ctx, groupIDs)
return nil
}
@@ -847,6 +930,7 @@ type CreateChannelInput struct {
BillingModelSource string
RestrictModels bool
Features string
FeaturesConfig map[string]any
}
// UpdateChannelInput 更新渠道输入
@@ -860,4 +944,5 @@ type UpdateChannelInput struct {
BillingModelSource string
RestrictModels *bool
Features *string
FeaturesConfig map[string]any
}

View File

@@ -0,0 +1,62 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// Happy path: the platform key under web_search_emulation is true.
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Enablement is per-platform: anthropic=true must not enable openai.
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}

// An explicit false flag keeps the feature off.
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// A nil FeaturesConfig must not panic and reads as disabled.
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
	c := &Channel{FeaturesConfig: nil}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// Calling the method on a nil receiver must not panic.
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
	var c *Channel
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// The config value must be a per-platform map; a bare bool is rejected.
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: true, // not a map
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

// A non-bool platform value (here the string "yes") reads as disabled.
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
	c := &Channel{
		FeaturesConfig: map[string]any{
			featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
		},
	}
	require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

View File

@@ -249,6 +249,10 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false
SettingKeyEnableCCHSigning = "enable_cch_signing"
// Web Search Emulation
// SettingKeyWebSearchEmulationConfig 全局 web search 模拟配置JSON
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -3785,6 +3785,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request")
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body
passthroughModel := parsed.Model

View File

@@ -0,0 +1,358 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// Web search emulation constants.
const (
	// Tool identifiers accepted as "web search" in incoming requests:
	// matched against the tool's "type" (prefix for web_search, exact for
	// google_search) and against its "name".
	toolTypeWebSearchPrefix = "web_search"
	toolTypeGoogleSearch    = "google_search"
	toolNameWebSearch       = "web_search"
	toolNameGoogleSearch    = "google_search"
	toolNameWebSearch2025   = "web_search_20250305"
	// Defaults used when synthesizing the emulated response.
	webSearchDefaultMaxResults = 5
	defaultWebSearchModel      = "claude-sonnet-4-6"
	// ID prefixes for the synthesized message and server_tool_use block.
	webSearchMsgIDPrefix     = "msg_ws_"
	webSearchToolUseIDPrefix = "srvtoolu_ws_"
	// Rough characters-per-token divisor for output-token estimation.
	tokenEstimateDivisor = 4
	// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
	featureKeyWebSearchEmulation = "web_search_emulation"
)
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]

// SetWebSearchManager wires the websearch.Manager into the gateway
// (goroutine-safe). Storing a new manager atomically replaces any previous
// one, which is how config hot-reload swaps providers without a restart.
func SetWebSearchManager(m *websearch.Manager) {
	webSearchManagerPtr.Store(m)
}

// getWebSearchManager returns the current manager, or nil when none has been
// configured yet; callers treat nil as "emulation unavailable".
func getWebSearchManager() *websearch.Manager {
	return webSearchManagerPtr.Load()
}
// shouldEmulateWebSearch checks whether a request should be intercepted for
// web search emulation.
//
// All of the following must hold, checked cheapest-first: a websearch manager
// is wired, the request's only tool is a web_search tool, the global setting
// is enabled, and the account has the feature enabled. Channel-level control
// is enforced via the account's extra field; the channel toggle in the admin
// UI sets the account flag for all accounts in that channel's groups.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
	switch {
	case getWebSearchManager() == nil:
		return false
	case !isOnlyWebSearchToolInBody(body):
		return false
	case !s.settingService.IsWebSearchEmulationEnabled(ctx):
		return false
	default:
		return account.IsWebSearchEmulationEnabled()
	}
}
// isOnlyWebSearchToolInBody reports whether the request body declares exactly
// one tool and that tool is a web-search tool. Requests with additional tools
// are left for the upstream to handle.
func isOnlyWebSearchToolInBody(body []byte) bool {
	tools := gjson.GetBytes(body, "tools")
	if !tools.IsArray() {
		return false
	}
	list := tools.Array()
	return len(list) == 1 && isWebSearchToolJSON(list[0])
}
// isWebSearchToolJSON reports whether a single tool entry is a web-search
// tool, matching either its "type" (web_search* prefix or google_search) or
// its "name" against the known web-search tool names.
func isWebSearchToolJSON(tool gjson.Result) bool {
	typ := tool.Get("type").String()
	if typ == toolTypeGoogleSearch || strings.HasPrefix(typ, toolTypeWebSearchPrefix) {
		return true
	}
	name := tool.Get("name").String()
	return name == toolNameWebSearch || name == toolNameGoogleSearch || name == toolNameWebSearch2025
}
// extractSearchQueryFromBody pulls the search query out of the request body:
// the text content of the final message, provided that message came from the
// user. Returns "" when no usable query exists.
func extractSearchQueryFromBody(body []byte) string {
	messages := gjson.GetBytes(body, "messages")
	if !messages.IsArray() {
		return ""
	}
	msgs := messages.Array()
	if len(msgs) == 0 {
		return ""
	}
	last := msgs[len(msgs)-1]
	if last.Get("role").String() != "user" {
		return ""
	}
	return extractWebSearchTextFromContent(last.Get("content"))
}
// extractWebSearchTextFromContent returns a message's text: either the
// content string itself, or the first non-empty "text" block when content is
// an array of blocks. Returns "" for any other shape.
func extractWebSearchTextFromContent(content gjson.Result) string {
	switch {
	case content.Type == gjson.String:
		return content.String()
	case content.IsArray():
		for _, block := range content.Array() {
			if block.Get("type").String() != "text" {
				continue
			}
			if text := block.Get("text").String(); text != "" {
				return text
			}
		}
	}
	return ""
}
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response.
// The upstream provider is never contacted: the search runs through the
// configured websearch manager and the reply is synthesized locally, as an
// SSE stream or a single JSON body depending on parsed.Stream.
// Returns an error when no user query can be extracted or the search fails.
func (s *GatewayService) handleWebSearchEmulation(
	ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
) (*ForwardResult, error) {
	startTime := time.Now()
	// Release the serial queue lock immediately — we don't need upstream.
	if parsed.OnUpstreamAccepted != nil {
		parsed.OnUpstreamAccepted()
	}
	// The query is the text of the last user message; without one there is
	// nothing to search for.
	query := extractSearchQueryFromBody(parsed.Body)
	if query == "" {
		return nil, fmt.Errorf("web search emulation: no query found in messages")
	}
	slog.Info("web search emulation: executing search",
		"account_id", account.ID, "account_name", account.Name, "query", query)
	resp, providerName, err := doWebSearch(ctx, account, query)
	if err != nil {
		return nil, err
	}
	slog.Info("web search emulation: search completed",
		"provider", providerName, "results_count", len(resp.Results))
	// Fall back to a default model name when the request carried none, so the
	// synthesized response stays protocol-valid.
	model := parsed.Model
	if model == "" {
		model = defaultWebSearchModel
	}
	if parsed.Stream {
		return writeWebSearchStreamResponse(c, query, resp, model, startTime)
	}
	return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
}
// doWebSearch runs the query through the best available search provider and
// returns the results together with the chosen provider's name (for logging).
// The account's proxy, when configured, is passed down to the provider.
// It re-checks the manager for nil so the function is safe to call directly.
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
	proxyURL := resolveAccountProxyURL(account)
	mgr := getWebSearchManager()
	if mgr == nil {
		return nil, "", fmt.Errorf("web search emulation: manager not initialized")
	}
	resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
		Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
	})
	if err != nil {
		slog.Error("web search emulation: search failed", "error", err)
		return nil, "", fmt.Errorf("web search emulation: %w", err)
	}
	return resp, providerName, nil
}
// resolveAccountProxyURL returns the account's proxy URL when a proxy is both
// referenced (ProxyID set) and loaded (Proxy non-nil); otherwise "".
func resolveAccountProxyURL(account *Account) string {
	if account.ProxyID == nil || account.Proxy == nil {
		return ""
	}
	return account.Proxy.URL()
}
// --- SSE streaming response ---

// writeWebSearchStreamResponse writes the whole emulated SSE stream for a
// streaming request: message_start, a server_tool_use block, a
// web_search_tool_result block, a text-summary block, then the closing events
// via writeSSEMessageEnd. Only the first write's error is surfaced; once the
// stream is open, later writes are best-effort. Output tokens are estimated
// from the summary length via tokenEstimateDivisor.
func writeWebSearchStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	msgID := webSearchMsgIDPrefix + uuid.New().String()
	// Tool-use IDs only need local uniqueness; a truncated UUID keeps them short.
	toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	setSSEHeaders(c)
	if err := writeSSEMessageStart(c.Writer, msgID, model); err != nil {
		return nil, fmt.Errorf("web search emulation: SSE write: %w", err)
	}
	writeSSEServerToolUse(c.Writer, toolUseID, query, 0)
	writeSSEToolResult(c.Writer, toolUseID, resp.Results, 1)
	textSummary := buildTextSummary(query, resp.Results)
	writeSSETextBlock(c.Writer, textSummary, 2)
	writeSSEMessageEnd(c.Writer, len(textSummary)/tokenEstimateDivisor)
	c.Writer.Flush()
	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// setSSEHeaders prepares the response for server-sent-events streaming and
// commits a 200 status.
func setSSEHeaders(c *gin.Context) {
	header := c.Writer.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	// Disable reverse-proxy buffering (nginx) so events reach the client live.
	header.Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)
}
// writeSSEMessageStart emits the message_start event opening an
// Anthropic-style streaming response with an empty assistant message and
// zeroed usage counters.
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
	message := map[string]any{
		"id":            msgID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       []any{},
		"stop_reason":   nil,
		"stop_sequence": nil,
		"usage":         map[string]int{"input_tokens": 0, "output_tokens": 0},
	}
	return flushSSEJSON(w, "message_start", map[string]any{
		"type":    "message_start",
		"message": message,
	})
}
// writeSSEServerToolUse emits a server_tool_use content block (start + stop)
// describing the emulated web search invocation at the given block index.
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) {
	block := map[string]any{
		"type":  "server_tool_use",
		"id":    toolUseID,
		"name":  toolNameWebSearch,
		"input": map[string]string{"query": query},
	}
	_ = flushSSEJSON(w, "content_block_start", map[string]any{
		"type": "content_block_start", "index": index, "content_block": block,
	})
	_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
// writeSSEToolResult emits a web_search_tool_result content block
// (start + stop) carrying the converted search result blocks.
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) {
	block := map[string]any{
		"type":        "web_search_tool_result",
		"tool_use_id": toolUseID,
		"content":     buildSearchResultBlocks(results),
	}
	_ = flushSSEJSON(w, "content_block_start", map[string]any{
		"type": "content_block_start", "index": index, "content_block": block,
	})
	_ = flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
// writeSSETextBlock emits a text content block as start/delta/stop events,
// delivering the whole text in a single text_delta event.
func writeSSETextBlock(w http.ResponseWriter, text string, index int) {
	start := map[string]any{
		"type":          "content_block_start",
		"index":         index,
		"content_block": map[string]any{"type": "text", "text": ""},
	}
	delta := map[string]any{
		"type":  "content_block_delta",
		"index": index,
		"delta": map[string]string{"type": "text_delta", "text": text},
	}
	stop := map[string]any{"type": "content_block_stop", "index": index}
	_ = flushSSEJSON(w, "content_block_start", start)
	_ = flushSSEJSON(w, "content_block_delta", delta)
	_ = flushSSEJSON(w, "content_block_stop", stop)
}
// writeSSEMessageEnd closes the stream with a message_delta carrying the
// end_turn stop reason and estimated output token usage, then message_stop.
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) {
	finalDelta := map[string]any{
		"type":  "message_delta",
		"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
		"usage": map[string]int{"output_tokens": outputTokens},
	}
	_ = flushSSEJSON(w, "message_delta", finalDelta)
	_ = flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
}
// flushSSEJSON marshals data to JSON and writes it as a single SSE event
// ("event: <name>" / "data: <json>" / blank line), flushing the response so
// the client receives it without buffering delay.
//
// Returns an error when JSON marshaling fails (logged; event dropped) or
// when writing to the underlying connection fails. Previously the write
// error from Fprintf was silently discarded, hiding broken-pipe conditions
// from callers that do check the return value.
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
	b, err := json.Marshal(data)
	if err != nil {
		slog.Error("web search emulation: failed to marshal SSE event",
			"event", event, "error", err)
		return err
	}
	if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
		// Client likely disconnected mid-stream; surface it to the caller.
		return fmt.Errorf("web search emulation: write SSE event %q: %w", event, err)
	}
	// Flush when the writer supports it so each event is delivered immediately.
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}
// --- Non-streaming JSON response ---

// writeWebSearchNonStreamResponse writes a single Anthropic-protocol JSON
// message containing the emulated server_tool_use block, its result blocks,
// and a human-readable text summary.
func writeWebSearchNonStreamResponse(
	c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
	messageID := webSearchMsgIDPrefix + uuid.New().String()
	toolCallID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
	summary := buildTextSummary(query, resp.Results)

	content := []any{
		map[string]any{
			"type":  "server_tool_use",
			"id":    toolCallID,
			"name":  toolNameWebSearch,
			"input": map[string]string{"query": query},
		},
		map[string]any{
			"type":        "web_search_tool_result",
			"tool_use_id": toolCallID,
			"content":     buildSearchResultBlocks(resp.Results),
		},
		map[string]any{"type": "text", "text": summary},
	}
	msg := map[string]any{
		"id":            messageID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       content,
		"stop_reason":   "end_turn",
		"stop_sequence": nil,
		// Rough output token estimate derived from the summary length.
		"usage": map[string]int{"input_tokens": 0, "output_tokens": len(summary) / tokenEstimateDivisor},
	}
	body, err := json.Marshal(msg)
	if err != nil {
		return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
	}
	c.Data(http.StatusOK, "application/json", body)
	return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// --- Helpers ---

// buildSearchResultBlocks converts provider search results into Anthropic
// web_search_result content blocks, omitting the optional page_content and
// page_age keys when the source fields are empty.
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
	blocks := make([]map[string]string, 0, len(results))
	for i := range results {
		r := &results[i]
		b := map[string]string{
			"type":  "web_search_result",
			"url":   r.URL,
			"title": r.Title,
		}
		if r.Snippet != "" {
			b["page_content"] = r.Snippet
		}
		if r.PageAge != "" {
			b["page_age"] = r.PageAge
		}
		blocks = append(blocks, b)
	}
	return blocks
}
// buildTextSummary renders a numbered markdown list of the search results,
// or a fixed fallback sentence when there are none.
func buildTextSummary(query string, results []websearch.SearchResult) string {
	if len(results) == 0 {
		return "No search results found for: " + query
	}
	var b strings.Builder
	fmt.Fprintf(&b, "Here are the search results for \"%s\":\n\n", query)
	for i := range results {
		r := &results[i]
		fmt.Fprintf(&b, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
	}
	return b.String()
}

View File

@@ -0,0 +1,142 @@
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- isOnlyWebSearchToolInBody ---

// Each of the three recognized tool "type" identifiers must be detected.
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}

// The same identifiers must also be honored when supplied via "name".
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
	require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}

// Interception requires web_search to be the ONLY tool in the request.
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody(
		[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}

// Absent, empty, foreign, or malformed "tools" must not trigger emulation.
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}

func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}

func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}

func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
	require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}
// --- extractSearchQueryFromBody ---

// Plain string content of the final user message is used verbatim as the query.
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"what is golang"}]}`
	require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}

// Structured content: a text block supplies the query.
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
	require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}

// Only the last message of a multi-turn conversation is considered.
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
	require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}

// No query is extracted when the conversation does not end with a user turn.
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
	body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}

func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}

func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
	require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}

// Non-text and empty-text blocks are skipped when locating the query text.
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
	require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}

// Content with no text block at all yields an empty query.
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
	body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
	require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
// --- buildSearchResultBlocks ---

// Blocks always carry type/url/title; page_content and page_age appear only
// when the source fields are non-empty.
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
		{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
	}
	blocks := buildSearchResultBlocks(results)
	require.Len(t, blocks, 2)
	require.Equal(t, "web_search_result", blocks[0]["type"])
	require.Equal(t, "https://a.com", blocks[0]["url"])
	require.Equal(t, "snippet a", blocks[0]["page_content"])
	require.Equal(t, "2 days", blocks[0]["page_age"])
	// Second result has no PageAge
	require.Equal(t, "https://b.com", blocks[1]["url"])
	_, hasPageAge := blocks[1]["page_age"]
	require.False(t, hasPageAge)
}

// A nil result list yields an empty slice rather than a panic.
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
	blocks := buildSearchResultBlocks(nil)
	require.Empty(t, blocks)
}

// An empty snippet must not produce an empty page_content key.
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
	blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
	_, hasContent := blocks[0]["page_content"]
	require.False(t, hasContent)
}
// --- buildTextSummary ---

// The summary embeds the query plus a numbered, markdown-bold entry per result.
func TestBuildTextSummary_WithResults(t *testing.T) {
	results := []websearch.SearchResult{
		{URL: "https://a.com", Title: "A", Snippet: "desc a"},
	}
	summary := buildTextSummary("test query", results)
	require.Contains(t, summary, "test query")
	require.Contains(t, summary, "1. **A**")
	require.Contains(t, summary, "https://a.com")
}

// An empty result set falls back to a fixed "no results" sentence.
func TestBuildTextSummary_NoResults(t *testing.T) {
	summary := buildTextSummary("test", nil)
	require.Contains(t, summary, "No search results found for: test")
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
@@ -106,6 +107,7 @@ type SettingService struct {
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchRedis *redis.Client // optional: Redis client for web search quota tracking
}
// NewSettingService 创建系统设置服务实例
@@ -1217,6 +1219,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
var wsCfg WebSearchEmulationConfig
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
return result
}

View File

@@ -106,6 +106,9 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata默认 false
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false
// Web Search Emulation (read-only quick check; full config via dedicated API)
WebSearchEmulationEnabled bool
}
type DefaultSubscriptionSetting struct {

View File

@@ -0,0 +1,253 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync/atomic"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/singleflight"
)
// WebSearchEmulationConfig holds the global web search emulation configuration.
// It is persisted as a JSON blob under SettingKeyWebSearchEmulationConfig and
// cached in-process (see GetWebSearchEmulationConfig).
type WebSearchEmulationConfig struct {
	Enabled   bool                      `json:"enabled"`   // global feature switch
	Providers []WebSearchProviderConfig `json:"providers"` // configured search providers
}
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
// APIKey is write-only from the admin API; APIKeyConfigured and QuotaUsed are
// server-populated, read-only display fields (see SanitizeWebSearchConfig).
type WebSearchProviderConfig struct {
	Type                 string `json:"type"`                   // websearch.ProviderTypeBrave | Tavily
	APIKey               string `json:"api_key,omitempty"`      // secret — omitted in API responses
	APIKeyConfigured     bool   `json:"api_key_configured"`     // read-only mask
	Priority             int    `json:"priority"`               // lower = higher priority
	QuotaLimit           int64  `json:"quota_limit"`            // 0 = unlimited
	QuotaRefreshInterval string `json:"quota_refresh_interval"` // websearch.QuotaRefresh*
	QuotaUsed            int64  `json:"quota_used,omitempty"`   // read-only: current period usage
	ProxyID              *int64 `json:"proxy_id"`               // optional proxy association
	ExpiresAt            *int64 `json:"expires_at,omitempty"`   // optional expiration timestamp
}
// --- Validation ---

// maxWebSearchProviders caps the provider list size accepted on save.
const maxWebSearchProviders = 10

// validProviderTypes enumerates the provider implementations available in
// the websearch package.
var validProviderTypes = map[string]bool{
	websearch.ProviderTypeBrave:  true,
	websearch.ProviderTypeTavily: true,
}

// validQuotaIntervals lists the accepted quota refresh periods; the empty
// string is accepted and treated as monthly.
var validQuotaIntervals = map[string]bool{
	websearch.QuotaRefreshDaily:   true,
	websearch.QuotaRefreshWeekly:  true,
	websearch.QuotaRefreshMonthly: true,
	"":                            true, // defaults to monthly
}
// validateWebSearchConfig checks the provider count, type validity, quota
// settings, and uniqueness of provider types. A nil config is valid.
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
	if cfg == nil {
		return nil
	}
	if len(cfg.Providers) > maxWebSearchProviders {
		return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
	}
	seen := make(map[string]bool, len(cfg.Providers))
	for i, provider := range cfg.Providers {
		switch {
		case !validProviderTypes[provider.Type]:
			return fmt.Errorf("provider[%d]: invalid type %q", i, provider.Type)
		case !validQuotaIntervals[provider.QuotaRefreshInterval]:
			return fmt.Errorf("provider[%d]: invalid quota_refresh_interval %q", i, provider.QuotaRefreshInterval)
		case provider.QuotaLimit < 0:
			return fmt.Errorf("provider[%d]: quota_limit must be >= 0", i)
		case seen[provider.Type]:
			return fmt.Errorf("provider[%d]: duplicate type %q", i, provider.Type)
		}
		seen[provider.Type] = true
	}
	return nil
}
// --- In-process cache (same pattern as gateway forwarding settings) ---

// sfKeyWebSearchConfig is the singleflight key that collapses concurrent DB loads.
const sfKeyWebSearchConfig = "web_search_emulation_config"

// cachedWebSearchEmulationConfig is a single immutable cache entry.
type cachedWebSearchEmulationConfig struct {
	config    *WebSearchEmulationConfig
	expiresAt int64 // unix nano
}

var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
var webSearchEmulationSF singleflight.Group

const (
	webSearchEmulationCacheTTL  = 60 * time.Second // lifetime of a successful load
	webSearchEmulationErrorTTL  = 5 * time.Second  // short negative-cache after a DB error
	webSearchEmulationDBTimeout = 5 * time.Second  // deadline for a single DB read
)
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
// On error a non-nil empty config is returned alongside the error.
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
	if v := webSearchEmulationCache.Load(); v != nil {
		entry := v.(*cachedWebSearchEmulationConfig)
		if time.Now().UnixNano() < entry.expiresAt {
			return entry.config, nil
		}
	}
	// Collapse concurrent cache misses into one DB load.
	v, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
		return s.loadWebSearchConfigFromDB()
	})
	if err != nil {
		return &WebSearchEmulationConfig{}, err
	}
	return v.(*WebSearchEmulationConfig), nil
}
// loadWebSearchConfigFromDB reads and parses the persisted config, caching
// the outcome: a short negative-cache on DB failure, the normal TTL on success.
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
	store := func(cfg *WebSearchEmulationConfig, ttl time.Duration) {
		webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
			config:    cfg,
			expiresAt: time.Now().Add(ttl).UnixNano(),
		})
	}

	dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
	defer cancel()

	raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
	if err != nil {
		// Negative-cache briefly so a failing DB is not hammered per request.
		store(&WebSearchEmulationConfig{}, webSearchEmulationErrorTTL)
		return &WebSearchEmulationConfig{}, err
	}

	cfg := parseWebSearchConfigJSON(raw)
	store(cfg, webSearchEmulationCacheTTL)
	return cfg, nil
}
// parseWebSearchConfigJSON decodes the stored JSON value; empty or invalid
// input degrades to a zero-value (disabled) configuration.
func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
	if raw == "" {
		return &WebSearchEmulationConfig{}
	}
	var cfg WebSearchEmulationConfig
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		slog.Warn("websearch: failed to parse config JSON", "error", err)
		return &WebSearchEmulationConfig{}
	}
	return &cfg
}
// SaveWebSearchEmulationConfig validates and persists the configuration.
// Empty API keys in the input are preserved from the existing config.
// On success the in-process cache is refreshed and the global manager is
// hot-reloaded so the new providers take effect immediately.
func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
	if err := validateWebSearchConfig(cfg); err != nil {
		return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
	}
	// Keep previously stored secrets when the caller submits masked (empty) keys.
	s.mergeExistingAPIKeys(ctx, cfg)

	payload, err := json.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("websearch: marshal config: %w", err)
	}
	if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(payload)); err != nil {
		return fmt.Errorf("websearch: save config: %w", err)
	}

	// Invalidate: forget the singleflight key first, then publish the fresh value.
	webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
	webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
		config:    cfg,
		expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
	})

	// Hot-reload: rebuild the global Manager with the new config.
	s.RebuildWebSearchManager(ctx)
	return nil
}
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
// Load errors are ignored: incoming keys are then left exactly as submitted.
func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
	existing, _ := s.getWebSearchEmulationConfigRaw(ctx) // best-effort read
	if existing == nil || cfg == nil {
		return
	}
	keysByType := make(map[string]string, len(existing.Providers))
	for _, prov := range existing.Providers {
		if prov.APIKey != "" {
			keysByType[prov.Type] = prov.APIKey
		}
	}
	for i, prov := range cfg.Providers {
		if prov.APIKey != "" {
			continue // caller supplied a new key; keep it
		}
		if key, ok := keysByType[prov.Type]; ok {
			cfg.Providers[i].APIKey = key
		}
	}
}
// getWebSearchEmulationConfigRaw loads and parses the config directly from
// the settings repository, bypassing the in-process cache.
func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
	if err != nil {
		return nil, err
	}
	return parseWebSearchConfigJSON(raw), nil
}
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
// The feature counts as enabled only when the flag is set AND at least one
// provider is configured; config load errors report it as disabled.
func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
	cfg, err := s.GetWebSearchEmulationConfig(ctx)
	return err == nil && cfg.Enabled && len(cfg.Providers) > 0
}
// SetWebSearchRedisClient injects the Redis client used for quota tracking.
// Call after construction, before first use. Triggers initial Manager build.
func (s *SettingService) SetWebSearchRedisClient(ctx context.Context, redisClient *redis.Client) {
	s.webSearchRedis = redisClient
	// Build the manager right away so quota tracking is live from startup.
	s.RebuildWebSearchManager(ctx)
}
// RebuildWebSearchManager reads the current config and (re)creates the global
// websearch.Manager. Called on startup and after SaveWebSearchEmulationConfig.
// When the config is unavailable, disabled, or empty, the manager is cleared
// so the gateway stops intercepting web search requests.
func (s *SettingService) RebuildWebSearchManager(ctx context.Context) {
	cfg, err := s.GetWebSearchEmulationConfig(ctx)
	if err != nil || !cfg.Enabled || len(cfg.Providers) == 0 {
		SetWebSearchManager(nil)
		return
	}
	providers := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
	for _, p := range cfg.Providers {
		providers = append(providers, websearch.ProviderConfig{
			Type:                 p.Type,
			APIKey:               p.APIKey,
			Priority:             p.Priority,
			QuotaLimit:           p.QuotaLimit,
			QuotaRefreshInterval: p.QuotaRefreshInterval,
			ExpiresAt:            p.ExpiresAt,
		})
	}
	SetWebSearchManager(websearch.NewManager(providers, s.webSearchRedis))
	slog.Info("websearch: manager rebuilt", "provider_count", len(providers))
}
// SanitizeWebSearchConfig returns a copy with api_key fields masked for API responses.
// The original config is never mutated; api_key_configured reflects whether a
// secret was present before masking.
func SanitizeWebSearchConfig(cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
	if cfg == nil {
		return nil
	}
	sanitized := &WebSearchEmulationConfig{
		Enabled:   cfg.Enabled,
		Providers: make([]WebSearchProviderConfig, len(cfg.Providers)),
	}
	for i, p := range cfg.Providers {
		p.APIKeyConfigured = p.APIKey != "" // expose presence, not the value
		p.APIKey = ""                       // never return the secret
		sanitized.Providers[i] = p
	}
	return sanitized
}

View File

@@ -0,0 +1,148 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// --- validateWebSearchConfig ---

// A nil config is treated as valid (nothing to check).
func TestValidateWebSearchConfig_Nil(t *testing.T) {
	require.NoError(t, validateWebSearchConfig(nil))
}

// A well-formed two-provider config passes.
func TestValidateWebSearchConfig_Valid(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Enabled: true,
		Providers: []WebSearchProviderConfig{
			{Type: "brave", Priority: 1, QuotaLimit: 1000, QuotaRefreshInterval: "monthly"},
			{Type: "tavily", Priority: 2, QuotaLimit: 500, QuotaRefreshInterval: "daily"},
		},
	}
	require.NoError(t, validateWebSearchConfig(cfg))
}

// 11 providers exceeds the documented maximum of 10.
func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
	cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
	for i := range cfg.Providers {
		cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
	}
	err := validateWebSearchConfig(cfg)
	require.ErrorContains(t, err, "too many providers")
}

// Only known provider types (brave, tavily) are accepted.
func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "bing"}},
	}
	require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
}

func TestValidateWebSearchConfig_InvalidQuotaInterval(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: "hourly"}},
	}
	require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid quota_refresh_interval")
}

func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: -1}},
	}
	require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be >= 0")
}

// Each provider type may appear at most once.
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{
			{Type: "brave", Priority: 1},
			{Type: "brave", Priority: 2},
		},
	}
	require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
}

// An empty interval (defaulted) and a zero quota (unlimited) are both valid.
func TestValidateWebSearchConfig_EmptyQuotaInterval(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", QuotaRefreshInterval: ""}},
	}
	require.NoError(t, validateWebSearchConfig(cfg))
}

func TestValidateWebSearchConfig_ZeroQuotaLimit(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: 0}},
	}
	require.NoError(t, validateWebSearchConfig(cfg))
}
// --- parseWebSearchConfigJSON ---

// Valid JSON round-trips into the typed config.
func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
	raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
	cfg := parseWebSearchConfigJSON(raw)
	require.True(t, cfg.Enabled)
	require.Len(t, cfg.Providers, 1)
	require.Equal(t, "brave", cfg.Providers[0].Type)
}

// Empty and malformed input both degrade to a disabled zero-value config.
func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
	cfg := parseWebSearchConfigJSON("")
	require.False(t, cfg.Enabled)
	require.Empty(t, cfg.Providers)
}

func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
	cfg := parseWebSearchConfigJSON("not{json")
	require.False(t, cfg.Enabled)
	require.Empty(t, cfg.Providers)
}
// --- SanitizeWebSearchConfig ---

// The secret is blanked while api_key_configured reports its presence.
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Enabled: true,
		Providers: []WebSearchProviderConfig{
			{Type: "brave", APIKey: "sk-secret-xxx"},
		},
	}
	out := SanitizeWebSearchConfig(cfg)
	require.Equal(t, "", out.Providers[0].APIKey)
	require.True(t, out.Providers[0].APIKeyConfigured)
}

// An absent key reports api_key_configured=false.
func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
	}
	out := SanitizeWebSearchConfig(cfg)
	require.Equal(t, "", out.Providers[0].APIKey)
	require.False(t, out.Providers[0].APIKeyConfigured)
}

func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
	require.Nil(t, SanitizeWebSearchConfig(nil))
}

// Non-secret fields pass through unchanged.
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Enabled: true,
		Providers: []WebSearchProviderConfig{
			{Type: "brave", APIKey: "secret", Priority: 10, QuotaLimit: 1000},
		},
	}
	out := SanitizeWebSearchConfig(cfg)
	require.True(t, out.Enabled)
	require.Equal(t, 10, out.Providers[0].Priority)
	require.Equal(t, int64(1000), out.Providers[0].QuotaLimit)
}

// Sanitization must operate on a copy, never mutate the live config.
func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
	cfg := &WebSearchEmulationConfig{
		Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
	}
	_ = SanitizeWebSearchConfig(cfg)
	require.Equal(t, "secret", cfg.Providers[0].APIKey)
}