feat: WebSearch tri-state, account stats pricing fix, quota cache fix, usage tooltip
WebSearch tri-state switch: - Account-level web_search_emulation changed from bool to tri-state string: "default" (follow channel) / "enabled" / "disabled" - shouldEmulateWebSearch checks channel config when account is "default" - SQL migration converts old bool values - Frontend select replaces toggle in Edit/CreateAccountModal Account stats pricing: - resolveAccountStatsCost uses upstream model (post-mapping) for matching - Priority: custom rules → model pricing file (when toggle on) → default - Custom rules always configurable, independent of toggle - Account ID field changed to searchable selector filtered by platform - Description updated to reflect new behavior Quota notification cache fix: - CheckAccountQuotaAfterIncrement fetches real-time account from DB - Reconstructs pre-increment usage for accurate threshold crossing detection - New AccountQuotaReader interface (minimal: GetByID only) Usage tooltip: - Per-request/image billing shows per-request price instead of $0 token price - Token billing continues to show input/output price per million tokens
This commit is contained in:
@@ -1169,15 +1169,30 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsWebSearchEmulationEnabled 返回 Anthropic API Key 账号是否启用 web search 模拟。
|
||||
// 字段:accounts.extra.web_search_emulation。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsWebSearchEmulationEnabled() bool {
|
||||
// WebSearch 模拟三态常量
|
||||
const (
|
||||
WebSearchModeDefault = "default" // 跟随渠道配置
|
||||
WebSearchModeEnabled = "enabled" // 强制开启
|
||||
WebSearchModeDisabled = "disabled" // 强制关闭
|
||||
)
|
||||
|
||||
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
|
||||
// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
|
||||
// 旧 bool 值需通过 SQL 迁移脚本转换,Go 代码不做兼容。
|
||||
func (a *Account) GetWebSearchEmulationMode() string {
|
||||
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
|
||||
return false
|
||||
return WebSearchModeDefault
|
||||
}
|
||||
mode, ok := a.Extra[featureKeyWebSearchEmulation].(string)
|
||||
if !ok {
|
||||
return WebSearchModeDefault
|
||||
}
|
||||
switch mode {
|
||||
case WebSearchModeEnabled, WebSearchModeDisabled:
|
||||
return mode
|
||||
default:
|
||||
return WebSearchModeDefault
|
||||
}
|
||||
enabled, ok := a.Extra[featureKeyWebSearchEmulation].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
|
||||
|
||||
@@ -8,11 +8,17 @@ import (
|
||||
|
||||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||
// 仅匹配自定义规则(AccountStatsPricingRules),按数组顺序先命中为准。
|
||||
// upstreamModel 是最终发往上游的模型 ID,用于匹配自定义规则中的模型定价。
|
||||
//
|
||||
// 优先级(先命中为准):
|
||||
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
|
||||
// 2. ApplyPricingToAccountStats 启用时,用模型定价文件(LiteLLM)中上游模型的标准价格计算
|
||||
// 3. nil → 走默认公式
|
||||
//
|
||||
// upstreamModel 是最终发往上游的模型 ID。
|
||||
func resolveAccountStatsCost(
|
||||
ctx context.Context,
|
||||
channelService *ChannelService,
|
||||
billingService *BillingService,
|
||||
accountID int64,
|
||||
groupID int64,
|
||||
upstreamModel string,
|
||||
@@ -23,12 +29,39 @@ func resolveAccountStatsCost(
|
||||
return nil
|
||||
}
|
||||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||
if err != nil || channel == nil || !channel.ApplyPricingToAccountStats {
|
||||
if err != nil || channel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||
return tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount)
|
||||
|
||||
// 优先级 1:自定义规则(始终尝试)
|
||||
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
|
||||
return cost
|
||||
}
|
||||
|
||||
// 优先级 2:模型定价文件(LiteLLM/fallback)中上游模型的标准价格
|
||||
if channel.ApplyPricingToAccountStats && billingService != nil {
|
||||
return tryModelFilePricing(billingService, upstreamModel, tokens)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
|
||||
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
|
||||
pricing, err := billingService.GetModelPricing(model)
|
||||
if err != nil || pricing == nil {
|
||||
return nil
|
||||
}
|
||||
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
|
||||
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
|
||||
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
|
||||
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &cost
|
||||
}
|
||||
|
||||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||
|
||||
@@ -27,17 +27,24 @@ var quotaDimLabels = map[string]string{
|
||||
quotaDimTotal: "总限额 / Total",
|
||||
}
|
||||
|
||||
// AccountQuotaReader provides read access to account quota data.
|
||||
type AccountQuotaReader interface {
|
||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||
}
|
||||
|
||||
// BalanceNotifyService handles balance and quota threshold notifications.
|
||||
type BalanceNotifyService struct {
|
||||
emailService *EmailService
|
||||
settingRepo SettingRepository
|
||||
accountRepo AccountQuotaReader
|
||||
}
|
||||
|
||||
// NewBalanceNotifyService creates a new BalanceNotifyService.
|
||||
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
|
||||
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
|
||||
return &BalanceNotifyService{
|
||||
emailService: emailService,
|
||||
settingRepo: settingRepo,
|
||||
accountRepo: accountRepo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,7 +117,7 @@ func buildQuotaDims(account *Account) []quotaDim {
|
||||
}
|
||||
|
||||
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
|
||||
// The account's Extra fields contain pre-increment usage values.
|
||||
// It fetches real-time quota usage from DB to avoid stale snapshot values.
|
||||
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64) {
|
||||
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
|
||||
return
|
||||
@@ -123,8 +130,29 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
|
||||
return
|
||||
}
|
||||
|
||||
freshAccount := s.fetchFreshAccount(ctx, account)
|
||||
siteName := s.getSiteName(ctx)
|
||||
for _, dim := range buildQuotaDims(account) {
|
||||
s.checkQuotaDimCrossings(freshAccount, cost, adminEmails, siteName)
|
||||
}
|
||||
|
||||
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
|
||||
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
|
||||
if s.accountRepo == nil {
|
||||
return snapshot
|
||||
}
|
||||
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
|
||||
"account_id", snapshot.ID, "error", err)
|
||||
return snapshot
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
// checkQuotaDimCrossings iterates quota dimensions and sends alerts for threshold crossings.
|
||||
// freshAccount has post-increment values; oldUsed is reconstructed as freshUsed - cost.
|
||||
func (s *BalanceNotifyService) checkQuotaDimCrossings(freshAccount *Account, cost float64, adminEmails []string, siteName string) {
|
||||
for _, dim := range buildQuotaDims(freshAccount) {
|
||||
if !dim.enabled || dim.threshold <= 0 {
|
||||
continue
|
||||
}
|
||||
@@ -132,9 +160,12 @@ func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Conte
|
||||
if effectiveThreshold <= 0 {
|
||||
continue
|
||||
}
|
||||
newUsed := dim.oldUsed + cost
|
||||
if dim.oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
|
||||
s.asyncSendQuotaAlert(adminEmails, account.Name, dim, newUsed, effectiveThreshold, siteName)
|
||||
// dim.oldUsed is actually the post-increment value from fresh DB data;
|
||||
// reconstruct pre-increment value to detect threshold crossing.
|
||||
newUsed := dim.oldUsed
|
||||
oldUsed := dim.oldUsed - cost
|
||||
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
|
||||
s.asyncSendQuotaAlert(adminEmails, freshAccount.Name, dim, newUsed, effectiveThreshold, siteName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +75,9 @@ type ParsedRequest struct {
|
||||
MaxTokens int // max_tokens 值(用于探测请求拦截)
|
||||
SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变)
|
||||
|
||||
// GroupID 请求所属分组 ID(来自 API Key)
|
||||
GroupID *int64
|
||||
|
||||
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
|
||||
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
|
||||
OnUpstreamAccepted func()
|
||||
|
||||
@@ -3789,7 +3789,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.Body) {
|
||||
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
|
||||
return s.handleWebSearchEmulation(ctx, c, account, parsed)
|
||||
}
|
||||
|
||||
@@ -7588,7 +7588,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
upstreamModel = result.Model
|
||||
}
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService,
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, upstreamModel,
|
||||
UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
|
||||
@@ -49,10 +49,9 @@ func getWebSearchManager() *websearch.Manager {
|
||||
|
||||
// shouldEmulateWebSearch checks whether a request should be intercepted.
|
||||
//
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account enabled.
|
||||
// Note: channel-level control is enforced via the account's extra field; the channel toggle
|
||||
// in the admin UI sets the account's flag for all accounts in that channel's groups.
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, body []byte) bool {
|
||||
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
|
||||
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
|
||||
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
|
||||
if getWebSearchManager() == nil {
|
||||
return false
|
||||
}
|
||||
@@ -62,10 +61,23 @@ func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Ac
|
||||
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if !account.IsWebSearchEmulationEnabled() {
|
||||
|
||||
mode := account.GetWebSearchEmulationMode()
|
||||
switch mode {
|
||||
case WebSearchModeEnabled:
|
||||
return true
|
||||
case WebSearchModeDisabled:
|
||||
return false
|
||||
default: // "default" → follow channel config
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsWebSearchEmulationEnabled(account.Platform)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
|
||||
|
||||
@@ -4580,7 +4580,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
statsModel = result.Model
|
||||
}
|
||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||
ctx, s.channelService,
|
||||
ctx, s.channelService, s.billingService,
|
||||
account.ID, *apiKey.GroupID, statsModel,
|
||||
tokens, 1,
|
||||
)
|
||||
|
||||
@@ -476,8 +476,8 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
|
||||
}
|
||||
|
||||
// ProvideBalanceNotifyService creates BalanceNotifyService
|
||||
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository) *BalanceNotifyService {
|
||||
return NewBalanceNotifyService(emailService, settingRepo)
|
||||
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
|
||||
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
|
||||
}
|
||||
|
||||
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
|
||||
|
||||
Reference in New Issue
Block a user