Files
sub2api/backend/internal/service/channel.go
erio 1b53ffcac7 feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that
don't natively support Anthropic's web_search tool. When a pure
web_search request is detected, the gateway calls Brave Search or Tavily
API directly and constructs an Anthropic-protocol-compliant SSE/JSON
response without forwarding to upstream.

Backend:
- New `pkg/websearch/` SDK: Brave and Tavily provider implementations
  with io.LimitReader, proxy support, and Redis-based quota tracking
  (Lua atomic INCR + TTL, DECR rollback on failure)
- Global config via `settings.web_search_emulation_config` (JSON) with
  in-process cache + singleflight, input validation, API key merge on
  save, and sanitized API responses
- Channel-level toggle via `channels.features_config` JSONB column
  (DB migration 101)
- Account-level toggle via `accounts.extra.web_search_emulation`
- Request interception in `Forward()` with SSE streaming response
  construction using json.Marshal (no manual string concatenation)
- Manager hot-reload: `RebuildWebSearchManager()` called on config save
  and startup via `SetWebSearchRedisClient()`
- 70 unit tests covering providers, manager, config validation,
  sanitization, tool detection, query extraction, and response building

Frontend:
- Settings → Gateway tab: Web Search Emulation config card with global
  toggle, provider list (add/remove, API key, priority, quota, proxy)
- Channels → Anthropic tab: web search emulation toggle with global
  state linkage (disabled when global off)
- Account Create/Edit modals: web search emulation toggle for API Key
  type with Toggle component
- Full i18n coverage (zh + en)
2026-04-14 09:20:39 +08:00

294 lines
9.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"fmt"
"sort"
"strings"
"time"
)
// BillingMode 计费模式
type BillingMode string
const (
BillingModeToken BillingMode = "token" // 按 token 区间计费
BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层)
BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费)
)
// IsValid 检查 BillingMode 是否为合法值
func (m BillingMode) IsValid() bool {
switch m {
case BillingModeToken, BillingModePerRequest, BillingModeImage, "":
return true
}
return false
}
const (
BillingModelSourceRequested = "requested"
BillingModelSourceUpstream = "upstream"
BillingModelSourceChannelMapped = "channel_mapped"
)
// Channel 渠道实体
type Channel struct {
ID int64
Name string
Description string
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述JSON 数组),用于支付页面展示
CreatedAt time.Time
UpdatedAt time.Time
// 关联的分组 ID 列表
GroupIDs []int64
// 模型定价列表(每条含 Platform 字段)
ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string
// 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}}
FeaturesConfig map[string]any
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// ChannelModelPricing 渠道模型定价条目
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台anthropic/openai/gemini/...
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格USD— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格USD
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格USD
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
}
// PricingInterval 定价区间token 区间 / 按次分层 / 图片分辨率分层)
type PricingInterval struct {
ID int64
PricingID int64
MinTokens int // 区间下界(含)
MaxTokens *int // 区间上界不含nil = 无上限
TierLabel string // 层级标签(按次/图片模式1K, 2K, 4K, HD 等)
InputPrice *float64 // token 模式:每 token 输入价
OutputPrice *float64 // token 模式:每 token 输出价
CacheWritePrice *float64 // token 模式:缓存写入价
CacheReadPrice *float64 // token 模式:缓存读取价
PerRequestPrice *float64 // 按次/图片模式:每次请求价格
SortOrder int
CreatedAt time.Time
UpdatedAt time.Time
}
// IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
modelLower := strings.ToLower(model)
for i := range c.ModelPricing {
for _, m := range c.ModelPricing[i].Models {
if strings.ToLower(m) == modelLower {
cp := c.ModelPricing[i].Clone()
return &cp
}
}
}
return nil
}
// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。
// 区间为左开右闭 (min, max]min 不含max 包含。
// 第一个区间 min=0 时0 token 不匹配任何区间(回退到默认价格)。
func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval {
for i := range intervals {
iv := &intervals[i]
if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) {
return iv
}
}
return nil
}
// GetIntervalForContext 根据总 context token 数查找匹配的区间。
func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval {
return FindMatchingInterval(p.Intervals, totalTokens)
}
// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式)
func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval {
labelLower := strings.ToLower(label)
for i := range p.Intervals {
if strings.ToLower(p.Intervals[i].TierLabel) == labelLower {
return &p.Intervals[i]
}
}
return nil
}
// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全)
func (p ChannelModelPricing) Clone() ChannelModelPricing {
cp := p
if p.Models != nil {
cp.Models = make([]string, len(p.Models))
copy(cp.Models, p.Models)
}
if p.Intervals != nil {
cp.Intervals = make([]PricingInterval, len(p.Intervals))
copy(cp.Intervals, p.Intervals)
}
return cp
}
// Clone 返回 Channel 的深拷贝
func (c *Channel) Clone() *Channel {
if c == nil {
return nil
}
cp := *c
if c.GroupIDs != nil {
cp.GroupIDs = make([]int64, len(c.GroupIDs))
copy(cp.GroupIDs, c.GroupIDs)
}
if c.ModelPricing != nil {
cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing))
for i := range c.ModelPricing {
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping))
for platform, mapping := range c.ModelMapping {
inner := make(map[string]string, len(mapping))
for k, v := range mapping {
inner[k] = v
}
cp.ModelMapping[platform] = inner
}
}
return &cp
}
// ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);
// 无界区间MaxTokens=nil必须是最后一个。间隙允许回退默认价格
func ValidateIntervals(intervals []PricingInterval) error {
if len(intervals) == 0 {
return nil
}
sorted := make([]PricingInterval, len(intervals))
copy(sorted, intervals)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].MinTokens < sorted[j].MinTokens
})
for i := range sorted {
if err := validateSingleInterval(&sorted[i], i); err != nil {
return err
}
}
return validateIntervalOverlap(sorted)
}
// validateSingleInterval 校验单个区间的字段合法性
func validateSingleInterval(iv *PricingInterval, idx int) error {
if iv.MinTokens < 0 {
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
}
if iv.MaxTokens != nil {
if *iv.MaxTokens <= 0 {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
}
if *iv.MaxTokens <= iv.MinTokens {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
idx+1, *iv.MaxTokens, iv.MinTokens)
}
}
return validateIntervalPrices(iv, idx)
}
// validateIntervalPrices 校验区间内所有价格字段 >= 0
func validateIntervalPrices(iv *PricingInterval, idx int) error {
prices := []struct {
name string
val *float64
}{
{"input_price", iv.InputPrice},
{"output_price", iv.OutputPrice},
{"cache_write_price", iv.CacheWritePrice},
{"cache_read_price", iv.CacheReadPrice},
{"per_request_price", iv.PerRequestPrice},
}
for _, p := range prices {
if p.val != nil && *p.val < 0 {
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
}
}
return nil
}
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
func validateIntervalOverlap(sorted []PricingInterval) error {
for i, iv := range sorted {
// 无界区间必须是最后一个
if iv.MaxTokens == nil && i < len(sorted)-1 {
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
i+1)
}
if i == 0 {
continue
}
prev := sorted[i-1]
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
// (min, max] 语义prev 覆盖 (prev.Min, prev.Max]cur 覆盖 (cur.Min, cur.Max]
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
}
}
return nil
}
func formatMaxTokensLabel(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}