refactor: unify post-usage billing logic and fix account quota calculation
- Extract postUsageBilling() to consolidate billing logic across GatewayService.RecordUsage, RecordUsageWithLongContext, and OpenAIGatewayService.RecordUsage, eliminating ~120 lines of duplicated code - Fix account quota to use TotalCost × accountRateMultiplier (was using raw TotalCost, inconsistent with account cost stats) - Fix RecordUsageWithLongContext API Key quota only updating in balance mode (now updates regardless of billing type) - Fix WebSocket client disconnect detection on Windows by adding "an established connection was aborted" to known disconnect errors
This commit is contained in:
@@ -6410,6 +6410,89 @@ type APIKeyQuotaUpdater interface {
|
|||||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// postUsageBillingParams 统一扣费所需的参数
|
||||||
|
type postUsageBillingParams struct {
|
||||||
|
Cost *CostBreakdown
|
||||||
|
User *User
|
||||||
|
APIKey *APIKey
|
||||||
|
Account *Account
|
||||||
|
Subscription *UserSubscription
|
||||||
|
IsSubscriptionBill bool
|
||||||
|
AccountRateMultiplier float64
|
||||||
|
APIKeyService APIKeyQuotaUpdater
|
||||||
|
}
|
||||||
|
|
||||||
|
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||||
|
// - 订阅/余额扣费
|
||||||
|
// - API Key 配额更新
|
||||||
|
// - API Key 限速用量更新
|
||||||
|
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||||
|
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||||
|
cost := p.Cost
|
||||||
|
|
||||||
|
// 1. 订阅 / 余额扣费
|
||||||
|
if p.IsSubscriptionBill {
|
||||||
|
if cost.TotalCost > 0 {
|
||||||
|
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||||
|
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||||
|
}
|
||||||
|
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if cost.ActualCost > 0 {
|
||||||
|
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||||
|
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||||
|
}
|
||||||
|
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. API Key 配额
|
||||||
|
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||||
|
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
|
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. API Key 限速用量
|
||||||
|
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||||
|
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
|
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
|
}
|
||||||
|
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||||
|
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.GetQuotaLimit() > 0 {
|
||||||
|
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||||
|
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||||
|
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 更新账号最近使用时间
|
||||||
|
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||||
|
type billingDeps struct {
|
||||||
|
accountRepo AccountRepository
|
||||||
|
userRepo UserRepository
|
||||||
|
userSubRepo UserSubscriptionRepository
|
||||||
|
billingCacheService *BillingCacheService
|
||||||
|
deferredService *DeferredService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) billingDeps() *billingDeps {
|
||||||
|
return &billingDeps{
|
||||||
|
accountRepo: s.accountRepo,
|
||||||
|
userRepo: s.userRepo,
|
||||||
|
userSubRepo: s.userSubRepo,
|
||||||
|
billingCacheService: s.billingCacheService,
|
||||||
|
deferredService: s.deferredService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
@@ -6573,52 +6656,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
|
|
||||||
shouldBill := inserted || err != nil
|
shouldBill := inserted || err != nil
|
||||||
|
|
||||||
// 根据计费类型执行扣费
|
if shouldBill {
|
||||||
if isSubscriptionBilling {
|
postUsageBilling(ctx, &postUsageBillingParams{
|
||||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
Cost: cost,
|
||||||
if shouldBill && cost.TotalCost > 0 {
|
User: user,
|
||||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
APIKey: apiKey,
|
||||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
Account: account,
|
||||||
}
|
Subscription: subscription,
|
||||||
// 异步更新订阅缓存
|
IsSubscriptionBill: isSubscriptionBilling,
|
||||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
AccountRateMultiplier: accountRateMultiplier,
|
||||||
}
|
APIKeyService: input.APIKeyService,
|
||||||
|
}, s.billingDeps())
|
||||||
} else {
|
} else {
|
||||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||||
if shouldBill && cost.ActualCost > 0 {
|
|
||||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
|
||||||
}
|
|
||||||
// 异步更新余额缓存
|
|
||||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新 API Key 配额(如果设置了配额限制)
|
|
||||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
|
||||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update API Key rate limit usage
|
|
||||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
|
||||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
|
||||||
}
|
|
||||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新 API Key 账号配额用量
|
|
||||||
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
|
|
||||||
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
|
|
||||||
slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Schedule batch update for account last_used_at
|
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -6778,51 +6830,21 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
|
|
||||||
shouldBill := inserted || err != nil
|
shouldBill := inserted || err != nil
|
||||||
|
|
||||||
// 根据计费类型执行扣费
|
if shouldBill {
|
||||||
if isSubscriptionBilling {
|
postUsageBilling(ctx, &postUsageBillingParams{
|
||||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
Cost: cost,
|
||||||
if shouldBill && cost.TotalCost > 0 {
|
User: user,
|
||||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
APIKey: apiKey,
|
||||||
logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err)
|
Account: account,
|
||||||
}
|
Subscription: subscription,
|
||||||
// 异步更新订阅缓存
|
IsSubscriptionBill: isSubscriptionBilling,
|
||||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
AccountRateMultiplier: accountRateMultiplier,
|
||||||
}
|
APIKeyService: input.APIKeyService,
|
||||||
|
}, s.billingDeps())
|
||||||
} else {
|
} else {
|
||||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||||
if shouldBill && cost.ActualCost > 0 {
|
|
||||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err)
|
|
||||||
}
|
|
||||||
// 异步更新余额缓存
|
|
||||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
|
||||||
// API Key 独立配额扣费
|
|
||||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
|
||||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update API Key rate limit usage
|
|
||||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
|
||||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err)
|
|
||||||
}
|
|
||||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新 API Key 账号配额用量
|
|
||||||
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
|
|
||||||
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
|
|
||||||
slog.Error("increment account quota used failed", "account_id", account.ID, "cost", cost.TotalCost, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Schedule batch update for account last_used_at
|
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,16 @@ func NewOpenAIGatewayService(
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) billingDeps() *billingDeps {
|
||||||
|
return &billingDeps{
|
||||||
|
accountRepo: s.accountRepo,
|
||||||
|
userRepo: s.userRepo,
|
||||||
|
userSubRepo: s.userSubRepo,
|
||||||
|
billingCacheService: s.billingCacheService,
|
||||||
|
deferredService: s.deferredService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
|
// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
|
||||||
// 应在应用优雅关闭时调用。
|
// 应在应用优雅关闭时调用。
|
||||||
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
|
func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
|
||||||
@@ -3474,44 +3484,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
|
|
||||||
shouldBill := inserted || err != nil
|
shouldBill := inserted || err != nil
|
||||||
|
|
||||||
// Deduct based on billing type
|
if shouldBill {
|
||||||
if isSubscriptionBilling {
|
postUsageBilling(ctx, &postUsageBillingParams{
|
||||||
if shouldBill && cost.TotalCost > 0 {
|
Cost: cost,
|
||||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
User: user,
|
||||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
APIKey: apiKey,
|
||||||
}
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
IsSubscriptionBill: isSubscriptionBilling,
|
||||||
|
AccountRateMultiplier: accountRateMultiplier,
|
||||||
|
APIKeyService: input.APIKeyService,
|
||||||
|
}, s.billingDeps())
|
||||||
} else {
|
} else {
|
||||||
if shouldBill && cost.ActualCost > 0 {
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
|
||||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update API key quota if applicable (only for balance mode with quota set)
|
|
||||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
|
||||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update API Key rate limit usage
|
|
||||||
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
|
|
||||||
if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err)
|
|
||||||
}
|
|
||||||
s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新 API Key 账号配额用量
|
|
||||||
if shouldBill && cost.TotalCost > 0 && account.Type == AccountTypeAPIKey && account.GetQuotaLimit() > 0 {
|
|
||||||
if err := s.accountRepo.IncrementQuotaUsed(ctx, account.ID, cost.TotalCost); err != nil {
|
|
||||||
logger.LegacyPrintf("service.openai_gateway", "increment account quota used failed: account_id=%d cost=%f error=%v", account.ID, cost.TotalCost, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Schedule batch update for account last_used_at
|
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -864,7 +864,8 @@ func isOpenAIWSClientDisconnectError(err error) bool {
|
|||||||
strings.Contains(message, "unexpected eof") ||
|
strings.Contains(message, "unexpected eof") ||
|
||||||
strings.Contains(message, "use of closed network connection") ||
|
strings.Contains(message, "use of closed network connection") ||
|
||||||
strings.Contains(message, "connection reset by peer") ||
|
strings.Contains(message, "connection reset by peer") ||
|
||||||
strings.Contains(message, "broken pipe")
|
strings.Contains(message, "broken pipe") ||
|
||||||
|
strings.Contains(message, "an established connection was aborted")
|
||||||
}
|
}
|
||||||
|
|
||||||
func classifyOpenAIWSReadFallbackReason(err error) string {
|
func classifyOpenAIWSReadFallbackReason(err error) string {
|
||||||
|
|||||||
Reference in New Issue
Block a user