refactor: extract helpers to reduce duplication and function length in gateway billing
- Extract resolveChannelPricing to DRY the resolver pattern shared by calculateImageCost/calculateTokenCost - Remove unnecessary IIFE wrapper and pass accountRateMultiplier as parameter - Extract resolveBillingMode, resolveMediaType, optionalSubscriptionID to simplify buildRecordUsageLog (104→65 lines) - Extract shouldDeductAPIKeyQuota/shouldUpdateRateLimits/shouldUpdateAccountQuota methods on postUsageBillingParams to unify duplicated billing conditions
This commit is contained in:
@@ -7451,6 +7451,18 @@ type postUsageBillingParams struct {
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
|
||||
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
|
||||
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
|
||||
}
|
||||
|
||||
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
|
||||
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
||||
}
|
||||
|
||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||
// - 订阅/余额扣费
|
||||
// - API Key 配额更新
|
||||
@@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if p.shouldDeductAPIKeyQuota() {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if p.shouldUpdateRateLimits() {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
if p.shouldUpdateAccountQuota() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
@@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if p.shouldDeductAPIKeyQuota() {
|
||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if p.shouldUpdateRateLimits() {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
||||
if p.shouldUpdateAccountQuota() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
@@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
@@ -7890,21 +7903,17 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
||||
}
|
||||
|
||||
requestID := usageLog.RequestID
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
@@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost(
|
||||
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
|
||||
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
if s.resolver == nil || apiKey.Group == nil {
|
||||
return nil
|
||||
}
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
return resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
|
||||
func (s *GatewayService) calculateImageCost(
|
||||
ctx context.Context,
|
||||
@@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost(
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
) *CostBreakdown {
|
||||
hasChannelPricing := false
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
hasChannelPricing = true
|
||||
}
|
||||
}
|
||||
if hasChannelPricing {
|
||||
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -8036,34 +8051,26 @@ func (s *GatewayService) calculateTokenCost(
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
|
||||
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
||||
useUnified := false
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
// 优先尝试渠道定价 → CalculateCostUnified
|
||||
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
useUnified = true
|
||||
}
|
||||
}
|
||||
if !useUnified {
|
||||
if opts.LongContextThreshold > 0 {
|
||||
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(
|
||||
billingModel, tokens, multiplier,
|
||||
opts.LongContextThreshold, opts.LongContextMultiplier,
|
||||
)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
} else if opts.LongContextThreshold > 0 {
|
||||
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(
|
||||
billingModel, tokens, multiplier,
|
||||
opts.LongContextThreshold, opts.LongContextMultiplier,
|
||||
)
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
@@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
subscription *UserSubscription,
|
||||
requestedModel string,
|
||||
multiplier float64,
|
||||
accountRateMultiplier float64,
|
||||
billingType int8,
|
||||
cacheTTLOverridden bool,
|
||||
cost *CostBreakdown,
|
||||
opts *recordUsageOpts,
|
||||
) *UsageLog {
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
var mediaType *string
|
||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
@@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
BillingMode: resolveBillingMode(opts, result, cost),
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
MediaType: mediaType,
|
||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||
MediaType: resolveMediaType(opts, result),
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
|
||||
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
|
||||
GroupID: apiKey.GroupID,
|
||||
SubscriptionID: optionalSubscriptionID(subscription),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if cost != nil {
|
||||
@@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
usageLog.ActualCost = cost.ActualCost
|
||||
}
|
||||
|
||||
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过
|
||||
return usageLog
|
||||
}
|
||||
|
||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
|
||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||
isSoraMedia := opts.EnableClaudePath &&
|
||||
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
||||
if !isSoraMedia {
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if result.ImageCount > 0 {
|
||||
billingMode := string(BillingModeImage)
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
if isSoraMedia {
|
||||
return nil
|
||||
}
|
||||
var mode string
|
||||
switch {
|
||||
case cost != nil && cost.BillingMode != "":
|
||||
mode = cost.BillingMode
|
||||
case result.ImageCount > 0:
|
||||
mode = string(BillingModeImage)
|
||||
default:
|
||||
mode = string(BillingModeToken)
|
||||
}
|
||||
return &mode
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||
return &result.MediaType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
return &subscription.ID
|
||||
}
|
||||
|
||||
return usageLog
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolveChannelMapping 委托渠道服务解析模型映射
|
||||
|
||||
Reference in New Issue
Block a user