From 02dea7b09b2126f2a91f14398279e88c7031bcf2 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 6 Mar 2026 00:37:37 +0800 Subject: [PATCH] refactor: unify post-usage billing logic and fix account quota calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract postUsageBilling() to consolidate billing logic across GatewayService.RecordUsage, RecordUsageWithLongContext, and OpenAIGatewayService.RecordUsage, eliminating ~120 lines of duplicated code - Fix account quota to use TotalCost × accountRateMultiplier (was using raw TotalCost, inconsistent with account cost stats) - Fix RecordUsageWithLongContext API Key quota only updating in balance mode (now updates regardless of billing type) - Fix WebSocket client disconnect detection on Windows by adding "an established connection was aborted" to known disconnect errors --- backend/internal/service/gateway_service.go | 192 ++++++++++-------- .../service/openai_gateway_service.go | 57 ++---- .../internal/service/openai_ws_forwarder.go | 3 +- 3 files changed, 131 insertions(+), 121 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 006d4bc3..177c4631 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -6410,6 +6410,89 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +// postUsageBilling 统一处理使用量记录后的扣费逻辑: +// - 订阅/余额扣费 +// - API Key 配额更新 +// - API Key 限速用量更新 +// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) +func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + cost := p.Cost + + // 1. 订阅 / 余额扣费 + if p.IsSubscriptionBill { + if cost.TotalCost > 0 { + if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) + } + } + + // 2. API Key 配额 + if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + // 3. API Key 限速用量 + if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) + } + + // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) + if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + // 5. 更新账号最近使用时间 + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + deferredService *DeferredService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6573,52 +6656,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // 更新 API Key 配额(如果设置了配额限制) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // 更新 API Key 账号配额用量 - if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 { - if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil { - slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } @@ -6778,51 +6830,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * shouldBill := inserted || err != nil - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if shouldBill && cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if shouldBill && cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - // API Key 独立配额扣费 - if input.APIKeyService != nil && apiKey.Quota > 0 { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) - } - } - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // 更新 API Key 账号配额用量 - if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 { - if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil { - slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6752d18b..84fe351c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -319,6 +319,16 @@ func NewOpenAIGatewayService( return svc } +func (s *OpenAIGatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + } +} + // CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 // 应在应用优雅关闭时调用。 func (s *OpenAIGatewayService) CloseOpenAIWSPool() { @@ -3474,44 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec shouldBill := inserted || err != nil - // Deduct based on billing type - if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } + if shouldBill { + postUsageBilling(ctx, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps()) } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } + s.deferredService.ScheduleLastUsedUpdate(account.ID) } - // Update API key quota if applicable (only for balance mode with quota set) - if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) - } - } - - // Update API Key rate limit usage - if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { - if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err) - } - s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) - } - - // 更新 API Key 账号配额用量 - if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 { - if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil { - logger.LegacyPrintf("service.openai_gateway", "increment account quota used failed: account_id=%d cost=%f error=%v", account.ID, cost.TotalCost, err) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index a5c2fd7a..7b6591fa 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool { strings.Contains(message, "unexpected eof") || strings.Contains(message, "use of closed network connection") || strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") } func classifyOpenAIWSReadFallbackReason(err error) string {