refactor: extract helpers to reduce duplication and function length in gateway billing
- Extract resolveChannelPricing to DRY the resolver pattern shared by calculateImageCost/calculateTokenCost - Remove unnecessary IIFE wrapper and pass accountRateMultiplier as parameter - Extract resolveBillingMode, resolveMediaType, optionalSubscriptionID to simplify buildRecordUsageLog (104→65 lines) - Extract shouldDeductAPIKeyQuota/shouldUpdateRateLimits/shouldUpdateAccountQuota methods on postUsageBillingParams to unify duplicated billing conditions
This commit is contained in:
@@ -7451,6 +7451,18 @@ type postUsageBillingParams struct {
|
|||||||
APIKeyService APIKeyQuotaUpdater
|
APIKeyService APIKeyQuotaUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool {
|
||||||
|
return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postUsageBillingParams) shouldUpdateRateLimits() bool {
|
||||||
|
return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
|
||||||
|
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
||||||
|
}
|
||||||
|
|
||||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
||||||
// - 订阅/余额扣费
|
// - 订阅/余额扣费
|
||||||
// - API Key 配额更新
|
// - API Key 配额更新
|
||||||
@@ -7480,21 +7492,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. API Key 配额
|
// 2. API Key 配额
|
||||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
if p.shouldDeductAPIKeyQuota() {
|
||||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. API Key 限速用量
|
// 3. API Key 限速用量
|
||||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
if p.shouldUpdateRateLimits() {
|
||||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||||
if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
if p.shouldUpdateAccountQuota() {
|
||||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||||
@@ -7576,13 +7588,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
|||||||
cmd.BalanceCost = p.Cost.ActualCost
|
cmd.BalanceCost = p.Cost.ActualCost
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
if p.shouldDeductAPIKeyQuota() {
|
||||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||||
}
|
}
|
||||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
if p.shouldUpdateRateLimits() {
|
||||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||||
}
|
}
|
||||||
if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() {
|
if p.shouldUpdateAccountQuota() {
|
||||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7879,8 +7891,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建使用日志
|
// 创建使用日志
|
||||||
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||||
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts)
|
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||||
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||||
@@ -7890,21 +7903,17 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestID := usageLog.RequestID
|
requestID := usageLog.RequestID
|
||||||
accountRateMultiplier := account.BillingRateMultiplier()
|
_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||||
billingErr := func() error {
|
Cost: cost,
|
||||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
User: user,
|
||||||
Cost: cost,
|
APIKey: apiKey,
|
||||||
User: user,
|
Account: account,
|
||||||
APIKey: apiKey,
|
Subscription: subscription,
|
||||||
Account: account,
|
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||||
Subscription: subscription,
|
IsSubscriptionBill: isSubscriptionBilling,
|
||||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
AccountRateMultiplier: accountRateMultiplier,
|
||||||
IsSubscriptionBill: isSubscriptionBilling,
|
APIKeyService: input.APIKeyService,
|
||||||
AccountRateMultiplier: accountRateMultiplier,
|
}, s.billingDeps(), s.usageBillingRepo)
|
||||||
APIKeyService: input.APIKeyService,
|
|
||||||
}, s.billingDeps(), s.usageBillingRepo)
|
|
||||||
return err
|
|
||||||
}()
|
|
||||||
|
|
||||||
if billingErr != nil {
|
if billingErr != nil {
|
||||||
return billingErr
|
return billingErr
|
||||||
@@ -7964,6 +7973,20 @@ func (s *GatewayService) calculateSoraMediaCost(
|
|||||||
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
|
||||||
|
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
|
||||||
|
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||||
|
if s.resolver == nil || apiKey.Group == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
gid := apiKey.Group.ID
|
||||||
|
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||||
|
if resolved.Source == PricingSourceChannel {
|
||||||
|
return resolved
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
|
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
|
||||||
func (s *GatewayService) calculateImageCost(
|
func (s *GatewayService) calculateImageCost(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -7972,15 +7995,7 @@ func (s *GatewayService) calculateImageCost(
|
|||||||
billingModel string,
|
billingModel string,
|
||||||
multiplier float64,
|
multiplier float64,
|
||||||
) *CostBreakdown {
|
) *CostBreakdown {
|
||||||
hasChannelPricing := false
|
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
|
||||||
gid := apiKey.Group.ID
|
|
||||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
|
||||||
if resolved.Source == PricingSourceChannel {
|
|
||||||
hasChannelPricing = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasChannelPricing {
|
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
@@ -8036,34 +8051,26 @@ func (s *GatewayService) calculateTokenCost(
|
|||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
// 优先尝试渠道定价 → CalculateCostUnified
|
||||||
useUnified := false
|
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||||
if resolved.Source == PricingSourceChannel {
|
Ctx: ctx,
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
Model: billingModel,
|
||||||
Ctx: ctx,
|
GroupID: &gid,
|
||||||
Model: billingModel,
|
Tokens: tokens,
|
||||||
GroupID: &gid,
|
RequestCount: 1,
|
||||||
Tokens: tokens,
|
RateMultiplier: multiplier,
|
||||||
RequestCount: 1,
|
Resolver: s.resolver,
|
||||||
RateMultiplier: multiplier,
|
})
|
||||||
Resolver: s.resolver,
|
} else if opts.LongContextThreshold > 0 {
|
||||||
})
|
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
||||||
useUnified = true
|
cost, err = s.billingService.CalculateCostWithLongContext(
|
||||||
}
|
billingModel, tokens, multiplier,
|
||||||
}
|
opts.LongContextThreshold, opts.LongContextMultiplier,
|
||||||
if !useUnified {
|
)
|
||||||
if opts.LongContextThreshold > 0 {
|
} else {
|
||||||
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||||
cost, err = s.billingService.CalculateCostWithLongContext(
|
|
||||||
billingModel, tokens, multiplier,
|
|
||||||
opts.LongContextThreshold, opts.LongContextMultiplier,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||||
@@ -8083,21 +8090,13 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
subscription *UserSubscription,
|
subscription *UserSubscription,
|
||||||
requestedModel string,
|
requestedModel string,
|
||||||
multiplier float64,
|
multiplier float64,
|
||||||
|
accountRateMultiplier float64,
|
||||||
billingType int8,
|
billingType int8,
|
||||||
cacheTTLOverridden bool,
|
cacheTTLOverridden bool,
|
||||||
cost *CostBreakdown,
|
cost *CostBreakdown,
|
||||||
opts *recordUsageOpts,
|
opts *recordUsageOpts,
|
||||||
) *UsageLog {
|
) *UsageLog {
|
||||||
durationMs := int(result.Duration.Milliseconds())
|
durationMs := int(result.Duration.Milliseconds())
|
||||||
var imageSize *string
|
|
||||||
if result.ImageSize != "" {
|
|
||||||
imageSize = &result.ImageSize
|
|
||||||
}
|
|
||||||
var mediaType *string
|
|
||||||
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
|
||||||
mediaType = &result.MediaType
|
|
||||||
}
|
|
||||||
accountRateMultiplier := account.BillingRateMultiplier()
|
|
||||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||||
usageLog := &UsageLog{
|
usageLog := &UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
@@ -8120,15 +8119,20 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
AccountRateMultiplier: &accountRateMultiplier,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
|
BillingMode: resolveBillingMode(opts, result, cost),
|
||||||
Stream: result.Stream,
|
Stream: result.Stream,
|
||||||
DurationMs: &durationMs,
|
DurationMs: &durationMs,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: imageSize,
|
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||||
MediaType: mediaType,
|
MediaType: resolveMediaType(opts, result),
|
||||||
CacheTTLOverridden: cacheTTLOverridden,
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||||
|
UserAgent: optionalTrimmedStringPtr(input.UserAgent),
|
||||||
|
IPAddress: optionalTrimmedStringPtr(input.IPAddress),
|
||||||
|
GroupID: apiKey.GroupID,
|
||||||
|
SubscriptionID: optionalSubscriptionID(subscription),
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
if cost != nil {
|
if cost != nil {
|
||||||
@@ -8141,41 +8145,41 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
usageLog.ActualCost = cost.ActualCost
|
usageLog.ActualCost = cost.ActualCost
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过
|
return usageLog
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||||
|
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
|
||||||
|
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||||
isSoraMedia := opts.EnableClaudePath &&
|
isSoraMedia := opts.EnableClaudePath &&
|
||||||
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
||||||
if !isSoraMedia {
|
if isSoraMedia {
|
||||||
if cost != nil && cost.BillingMode != "" {
|
return nil
|
||||||
billingMode := cost.BillingMode
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
} else if result.ImageCount > 0 {
|
|
||||||
billingMode := string(BillingModeImage)
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
} else {
|
|
||||||
billingMode := string(BillingModeToken)
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
var mode string
|
||||||
|
switch {
|
||||||
|
case cost != nil && cost.BillingMode != "":
|
||||||
|
mode = cost.BillingMode
|
||||||
|
case result.ImageCount > 0:
|
||||||
|
mode = string(BillingModeImage)
|
||||||
|
default:
|
||||||
|
mode = string(BillingModeToken)
|
||||||
|
}
|
||||||
|
return &mode
|
||||||
|
}
|
||||||
|
|
||||||
// 添加 UserAgent
|
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||||
if input.UserAgent != "" {
|
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||||
usageLog.UserAgent = &input.UserAgent
|
return &result.MediaType
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 添加 IPAddress
|
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||||
if input.IPAddress != "" {
|
|
||||||
usageLog.IPAddress = &input.IPAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加分组和订阅关联
|
|
||||||
if apiKey.GroupID != nil {
|
|
||||||
usageLog.GroupID = apiKey.GroupID
|
|
||||||
}
|
|
||||||
if subscription != nil {
|
if subscription != nil {
|
||||||
usageLog.SubscriptionID = &subscription.ID
|
return &subscription.ID
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return usageLog
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveChannelMapping 委托渠道服务解析模型映射
|
// ResolveChannelMapping 委托渠道服务解析模型映射
|
||||||
|
|||||||
Reference in New Issue
Block a user