feat: WebSearch tri-state, account stats pricing fix, quota cache fix, usage tooltip

WebSearch tri-state switch:
- Account-level web_search_emulation changed from bool to tri-state
  string: "default" (follow channel) / "enabled" / "disabled"
- shouldEmulateWebSearch checks channel config when account is "default"
- SQL migration converts old bool values
- Frontend select replaces toggle in Edit/CreateAccountModal

Account stats pricing:
- resolveAccountStatsCost uses upstream model (post-mapping) for matching
- Priority: custom rules → model pricing file (when toggle on) → default
- Custom rules always configurable, independent of toggle
- Account ID field changed to searchable selector filtered by platform
- Description updated to reflect new behavior

Quota notification cache fix:
- CheckAccountQuotaAfterIncrement fetches real-time account from DB
- Reconstructs pre-increment usage for accurate threshold crossing detection
- New AccountQuotaReader interface (minimal: GetByID only)

Usage tooltip:
- Per-request/image billing shows per-request price instead of $0 token price
- Token billing continues to show input/output price per million tokens
This commit is contained in:
erio
2026-04-13 11:37:08 +08:00
parent 11c4606874
commit 1262654d97
18 changed files with 346 additions and 79 deletions

View File

@@ -1169,15 +1169,30 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
// 字段accounts.extra.web_search_emulation。
// 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsWebSearchEmulationEnabled() bool {
// WebSearch 模拟三态常量
const (
WebSearchModeDefault = "default" // 跟随渠道配置
WebSearchModeEnabled = "enabled" // 强制开启
WebSearchModeDisabled = "disabled" // 强制关闭
)
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
// 三态default跟随渠道/ enabled强制开启/ disabled强制关闭
// 旧 bool 值需通过 SQL 迁移脚本转换Go 代码不做兼容。
func (a *Account) GetWebSearchEmulationMode() string {
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
return false
return WebSearchModeDefault
}
mode, ok := a.Extra[featureKeyWebSearchEmulation].(string)
if !ok {
return WebSearchModeDefault
}
switch mode {
case WebSearchModeEnabled, WebSearchModeDisabled:
return mode
default:
return WebSearchModeDefault
}
enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
return ok && enabled
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。

View File

@@ -8,11 +8,17 @@ import (
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖使用默认公式total_cost × account_rate_multiplier
// 仅匹配自定义规则AccountStatsPricingRules按数组顺序先命中为准。
// upstreamModel 是最终发往上游的模型 ID用于匹配自定义规则中的模型定价。
//
// 优先级(先命中为准):
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
// 2. ApplyPricingToAccountStats 启用时用模型定价文件LiteLLM中上游模型的标准价格计算
// 3. nil → 走默认公式
//
// upstreamModel 是最终发往上游的模型 ID。
func resolveAccountStatsCost(
ctx context.Context,
channelService *ChannelService,
billingService *BillingService,
accountID int64,
groupID int64,
upstreamModel string,
@@ -23,12 +29,39 @@ func resolveAccountStatsCost(
return nil
}
channel, err := channelService.GetChannelForGroup(ctx, groupID)
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
if err != nil || channel == nil {
return nil
}
platform := channelService.GetGroupPlatform(ctx, groupID)
return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount)
// 优先级 1自定义规则始终尝试
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
return cost
}
// 优先级 2模型定价文件LiteLLM/fallback中上游模型的标准价格
if channel.ApplyPricingToAccountStats && billingService != nil {
return tryModelFilePricing(billingService, upstreamModel, tokens)
}
return nil
}
// tryModelFilePricing 使用模型定价文件LiteLLM/fallback中的标准价格计算费用。
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
pricing, err := billingService.GetModelPricing(model)
if err != nil || pricing == nil {
return nil
}
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken
if cost <= 0 {
return nil
}
return &cost
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。

View File

@@ -27,17 +27,24 @@ var quotaDimLabels = map[string]string{
quotaDimTotal: "总限额 / Total",
}
// AccountQuotaReader provides read access to account quota data.
type AccountQuotaReader interface {
GetByID(ctx context.Context, id int64) (*Account, error)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
return &BalanceNotifyService{
emailService: emailService,
settingRepo: settingRepo,
accountRepo: accountRepo,
}
}
@@ -110,7 +117,7 @@ func buildQuotaDims(account *Account) []quotaDim {
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// The account's Extra fields contain pre-increment usage values.
// It fetches real-time quota usage from DB to avoid stale snapshot values.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
@@ -123,8 +130,29 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
return
}
freshAccount := s.fetchFreshAccount(ctx, account)
siteName := s.getSiteName(ctx)
for _, dim := range buildQuotaDims(account) {
s.checkQuotaDimCrossings(freshAccount, cost, adminEmails, siteName)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
if s.accountRepo == nil {
return snapshot
}
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
if err != nil {
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
"account_id", snapshot.ID, "error", err)
return snapshot
}
return fresh
}
// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings.
// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost.
func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) {
for _, dim := range buildQuotaDims(freshAccount) {
if !dim.enabled || dim.threshold <= 0 {
continue
}
@@ -132,9 +160,12 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
if effectiveThreshold <= 0 {
continue
}
newUsed := dim.oldUsed + cost
if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName)
// dim.oldUsed is actually the post-increment value from fresh DB data;
// reconstruct pre-increment value to detect threshold crossing.
newUsed := dim.oldUsed
oldUsed := dim.oldUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName)
}
}
}

View File

@@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
// GroupID 请求所属分组 ID来自 API Key
GroupID *int64
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func()

View File

@@ -3789,7 +3789,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
@@ -7588,7 +7588,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
upstreamModel = result.Model
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, s.channelService,
ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, upstreamModel,
UsageTokens{
InputTokens: result.Usage.InputTokens,

View File

@@ -49,10 +49,9 @@ func getWebSearchManager() *websearch.Manager {
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
// Note: channel-level control is enforced via the account's extra field; the channel toggle
// in the admin UI sets the account's flag for all accounts in that channel's groups.
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
}
@@ -62,10 +61,23 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
return false
}
if !account.IsWebSearchEmulationEnabled() {
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default: // "default" → follow channel config
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
return true
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.

View File

@@ -4580,7 +4580,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
statsModel = result.Model
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, s.channelService,
ctx, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, statsModel,
tokens, 1,
)

View File

@@ -476,8 +476,8 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
}
// ProvideBalanceNotifyService creates BalanceNotifyService
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
return NewBalanceNotifyService(emailService, settingRepo)
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
}
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.