refactor: unify post-usage billing logic and fix account quota calculation

- Extract postUsageBilling() to consolidate billing logic across
  GatewayService.RecordUsage, RecordUsageWithLongContext, and
  OpenAIGatewayService.RecordUsage, eliminating ~120 lines of
  duplicated code
- Fix account quota to use TotalCost × accountRateMultiplier
  (was using raw TotalCost, inconsistent with account cost stats)
- Fix RecordUsageWithLongContext API Key quota only updating in
  balance mode (now updates regardless of billing type)
- Fix WebSocket client disconnect detection on Windows by adding
  "an established connection was aborted" to known disconnect errors
This commit is contained in:
erio
2026-03-06 00:37:37 +08:00
parent c26f93c4a0
commit 02dea7b09b
3 changed files with 131 additions and 121 deletions

View File

@@ -6410,6 +6410,89 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
} }
// postUsageBillingParams 统一扣费所需的参数
type postUsageBillingParams struct {
Cost *CostBreakdown
User *User
APIKey *APIKey
Account *Account
Subscription *UserSubscription
IsSubscriptionBill bool
AccountRateMultiplier float64
APIKeyService APIKeyQuotaUpdater
}
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
// - 订阅/余额扣费
// - API Key 配额更新
// - API Key 限速用量更新
// - 账号配额用量更新账号口径TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill {
if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
}
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
}
}
// 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
}
}
// 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
}
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
}
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
}
}
// 5. 更新账号最近使用时间
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
}
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct {
accountRepo AccountRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
deferredService *DeferredService
}
func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result result := input.Result
@@ -6573,52 +6656,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
shouldBill := inserted || err != nil shouldBill := inserted || err != nil
// 根据计费类型执行扣费 if shouldBill {
if isSubscriptionBilling { postUsageBilling(ctx, &postUsageBillingParams{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) Cost: cost,
if shouldBill && cost.TotalCost > 0 { User: user,
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { APIKey: apiKey,
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) Account: account,
} Subscription: subscription,
// 异步更新订阅缓存 IsSubscriptionBill: isSubscriptionBilling,
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) AccountRateMultiplier: accountRateMultiplier,
} APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) s.deferredService.ScheduleLastUsedUpdate(account.ID)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
} }
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// 更新 API Key 账号配额用量
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
@@ -6778,51 +6830,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
shouldBill := inserted || err != nil shouldBill := inserted || err != nil
// 根据计费类型执行扣费 if shouldBill {
if isSubscriptionBilling { postUsageBilling(ctx, &postUsageBillingParams{
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) Cost: cost,
if shouldBill && cost.TotalCost > 0 { User: user,
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { APIKey: apiKey,
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) Account: account,
} Subscription: subscription,
// 异步更新订阅缓存 IsSubscriptionBill: isSubscriptionBilling,
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) AccountRateMultiplier: accountRateMultiplier,
} APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else { } else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) s.deferredService.ScheduleLastUsedUpdate(account.ID)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
}
}
}
} }
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// 更新 API Key 账号配额用量
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }

View File

@@ -319,6 +319,16 @@ func NewOpenAIGatewayService(
return svc return svc
} }
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{
accountRepo: s.accountRepo,
userRepo: s.userRepo,
userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService,
deferredService: s.deferredService,
}
}
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 // CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
// 应在应用优雅关闭时调用。 // 应在应用优雅关闭时调用。
func (s *OpenAIGatewayService) CloseOpenAIWSPool() { func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
@@ -3474,44 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
shouldBill := inserted || err != nil shouldBill := inserted || err != nil
// Deduct based on billing type if shouldBill {
if isSubscriptionBilling { postUsageBilling(ctx, &postUsageBillingParams{
if shouldBill && cost.TotalCost > 0 { Cost: cost,
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) User: user,
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) APIKey: apiKey,
} Account: account,
Subscription: subscription,
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else { } else {
if shouldBill && cost.ActualCost > 0 { s.deferredService.ScheduleLastUsedUpdate(account.ID)
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
} }
// Update API key quota if applicable (only for balance mode with quota set)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
}
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
}
// 更新 API Key 账号配额用量
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
logger.LegacyPrintf("service.openai_gateway", "increment account quota used failed: account_id=%d cost=%f error=%v", account.ID, cost.TotalCost, err)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }

View File

@@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
strings.Contains(message, "unexpected eof") || strings.Contains(message, "unexpected eof") ||
strings.Contains(message, "use of closed network connection") || strings.Contains(message, "use of closed network connection") ||
strings.Contains(message, "connection reset by peer") || strings.Contains(message, "connection reset by peer") ||
strings.Contains(message, "broken pipe") strings.Contains(message, "broken pipe") ||
strings.Contains(message, "an established connection was aborted")
} }
func classifyOpenAIWSReadFallbackReason(err error) string { func classifyOpenAIWSReadFallbackReason(err error) string {