From b4a42a640d27ba8e55d62cbfc976013f5f3436cc Mon Sep 17 00:00:00 2001 From: erio Date: Thu, 2 Apr 2026 03:28:52 +0800 Subject: [PATCH] refactor: extract helpers to reduce duplication and function length in gateway billing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract resolveChannelPricing to DRY the resolver pattern shared by calculateImageCost/calculateTokenCost - Remove unnecessary IIFE wrapper and pass accountRateMultiplier as parameter - Extract resolveBillingMode, resolveMediaType, optionalSubscriptionID to simplify buildRecordUsageLog (104→65 lines) - Extract shouldDeductAPIKeyQuota/shouldUpdateRateLimits/shouldUpdateAccountQuota methods on postUsageBillingParams to unify duplicated billing conditions --- backend/internal/service/gateway_service.go | 196 ++++++++++---------- 1 file changed, 100 insertions(+), 96 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 49e2b412..e42f3702 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7451,6 +7451,18 @@ type postUsageBillingParams struct { APIKeyService APIKeyQuotaUpdater } +func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { + return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateRateLimits() bool { + return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { + return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() +} + // postUsageBilling 统一处理使用量记录后的扣费逻辑: // - 订阅/余额扣费 // - API Key 配额更新 @@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } // 2. API Key 配额 - if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. API Key 限速用量 - if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) - if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) @@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.BalanceCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { cmd.APIKeyQuotaCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { cmd.APIKeyRateLimitCost = p.Cost.ActualCost } - if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier } @@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage } // 创建使用日志 + accountRateMultiplier := account.BillingRateMultiplier() usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, - requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts) + requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") @@ -7890,21 +7903,17 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage } requestID := usageLog.RequestID - accountRateMultiplier := account.BillingRateMultiplier() - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() + _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) if billingErr != nil { return billingErr @@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost( return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) } +// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 +// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 +func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey.Group == nil { + return nil + } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + // calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 func (s *GatewayService) calculateImageCost( ctx context.Context, @@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { - hasChannelPricing := false - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - hasChannelPricing = true - } - } - if hasChannelPricing { + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -8036,34 +8051,26 @@ func (s *GatewayService) calculateTokenCost( var cost *CostBreakdown var err error - // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) - useUnified := false - if s.resolver != nil && apiKey.Group != nil { + // 优先尝试渠道定价 → CalculateCostUnified + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - }) - useUnified = true - } - } - if !useUnified { - if opts.LongContextThreshold > 0 { - // 长上下文双倍计费(如 Gemini 200K 阈值) - cost, err = s.billingService.CalculateCostWithLongContext( - billingModel, tokens, multiplier, - opts.LongContextThreshold, opts.LongContextMultiplier, - ) - } else { - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) - } + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + }) + } else if opts.LongContextThreshold > 0 { + // 长上下文双倍计费(如 Gemini 200K 阈值) + cost, err = s.billingService.CalculateCostWithLongContext( + billingModel, tokens, multiplier, + opts.LongContextThreshold, opts.LongContextMultiplier, + ) + } else { + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) } if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) @@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog( subscription *UserSubscription, requestedModel string, multiplier float64, + accountRateMultiplier float64, billingType int8, cacheTTLOverridden bool, cost *CostBreakdown, opts *recordUsageOpts, ) *UsageLog { durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } - var mediaType *string - if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { - mediaType = &result.MediaType - } - accountRateMultiplier := account.BillingRateMultiplier() requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, @@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, + BillingMode: resolveBillingMode(opts, result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, - ImageSize: imageSize, - MediaType: mediaType, + ImageSize: optionalTrimmedStringPtr(result.ImageSize), + MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), + UserAgent: optionalTrimmedStringPtr(input.UserAgent), + IPAddress: optionalTrimmedStringPtr(input.IPAddress), + GroupID: apiKey.GroupID, + SubscriptionID: optionalSubscriptionID(subscription), CreatedAt: time.Now(), } if cost != nil { @@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog( usageLog.ActualCost = cost.ActualCost } - // 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过 + return usageLog +} + +// resolveBillingMode 根据计费结果和请求类型确定计费模式。 +// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 +func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { isSoraMedia := opts.EnableClaudePath && (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) - if !isSoraMedia { - if cost != nil && cost.BillingMode != "" { - billingMode := cost.BillingMode - usageLog.BillingMode = &billingMode - } else if result.ImageCount > 0 { - billingMode := string(BillingModeImage) - usageLog.BillingMode = &billingMode - } else { - billingMode := string(BillingModeToken) - usageLog.BillingMode = &billingMode - } + if isSoraMedia { + return nil } + var mode string + switch { + case cost != nil && cost.BillingMode != "": + mode = cost.BillingMode + case result.ImageCount > 0: + mode = string(BillingModeImage) + default: + mode = string(BillingModeToken) + } + return &mode +} - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent +func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { + if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { + return &result.MediaType } + return nil +} - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress - } - - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID - } +func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { - usageLog.SubscriptionID = &subscription.ID + return &subscription.ID } - - return usageLog + return nil } // ResolveChannelMapping 委托渠道服务解析模型映射