feat(gateway): add web search emulation for Anthropic API Key accounts
Inject web search capability for Claude Console (API Key) accounts that don't natively support Anthropic's web_search tool. When a pure web_search request is detected, the gateway calls Brave Search or Tavily API directly and constructs an Anthropic-protocol-compliant SSE/JSON response without forwarding to upstream. Backend: - New `pkg/websearch/` SDK: Brave and Tavily provider implementations with io.LimitReader, proxy support, and Redis-based quota tracking (Lua atomic INCR + TTL, DECR rollback on failure) - Global config via `settings.web_search_emulation_config` (JSON) with in-process cache + singleflight, input validation, API key merge on save, and sanitized API responses - Channel-level toggle via `channels.features_config` JSONB column (DB migration 101) - Account-level toggle via `accounts.extra.web_search_emulation` - Request interception in `Forward()` with SSE streaming response construction using json.Marshal (no manual string concatenation) - Manager hot-reload: `RebuildWebSearchManager()` called on config save and startup via `SetWebSearchRedisClient()` - 70 unit tests covering providers, manager, config validation, sanitization, tool detection, query extraction, and response building Frontend: - Settings → Gateway tab: Web Search Emulation config card with global toggle, provider list (add/remove, API key, priority, quota, proxy) - Channels → Anthropic tab: web search emulation toggle with global state linkage (disabled when global off) - Account Create/Edit modals: web search emulation toggle for API Key type with Toggle component - Full i18n coverage (zh + en)
This commit is contained in:
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -73,11 +77,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON []byte
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
@@ -85,6 +89,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@@ -107,10 +112,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, updated_at = NOW()
|
||||
WHERE id = $8`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ID,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -187,9 +196,9 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
@@ -203,11 +212,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -246,9 +256,34 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
|
||||
|
||||
var column string
|
||||
switch sortBy {
|
||||
case "":
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
case "id":
|
||||
column = "c.id"
|
||||
case "name":
|
||||
column = "c.name"
|
||||
case "status":
|
||||
column = "c.status"
|
||||
case "created_at":
|
||||
column = "c.created_at"
|
||||
default:
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, created_at, updated_at FROM channels ORDER BY id`,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
@@ -259,11 +294,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -431,6 +467,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user