refactor: merge RecordUsage and RecordUsageWithLongContext into shared core
- Extract recordUsageCore with recordUsageOpts for parameterized differences - RecordUsage (276 lines) → thin wrapper (~40 lines) - RecordUsageWithLongContext (251 lines) → thin wrapper (~20 lines) - Split billing logic into calculateSoraMediaCost, calculateImageCost, calculateTokenCost sub-functions - Extract buildRecordUsageLog for usage log construction - Net reduction: -79 lines, eliminated ~170 lines of duplication
This commit is contained in:
@@ -7706,8 +7706,109 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||||
|
type recordUsageOpts struct {
|
||||||
|
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||||
|
ParsedRequest *ParsedRequest
|
||||||
|
|
||||||
|
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||||
|
// - Claude Max 缓存计费策略
|
||||||
|
// - Sora 媒体类型分支(image/video/prompt)
|
||||||
|
// - MediaType 字段写入使用日志
|
||||||
|
EnableClaudePath bool
|
||||||
|
|
||||||
|
// 长上下文计费(仅 Gemini 路径需要)
|
||||||
|
LongContextThreshold int
|
||||||
|
LongContextMultiplier float64
|
||||||
|
}
|
||||||
|
|
||||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||||
|
return s.recordUsageCore(ctx, &recordUsageCoreInput{
|
||||||
|
Result: input.Result,
|
||||||
|
APIKey: input.APIKey,
|
||||||
|
User: input.User,
|
||||||
|
Account: input.Account,
|
||||||
|
Subscription: input.Subscription,
|
||||||
|
InboundEndpoint: input.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: input.UpstreamEndpoint,
|
||||||
|
UserAgent: input.UserAgent,
|
||||||
|
IPAddress: input.IPAddress,
|
||||||
|
RequestPayloadHash: input.RequestPayloadHash,
|
||||||
|
ForceCacheBilling: input.ForceCacheBilling,
|
||||||
|
APIKeyService: input.APIKeyService,
|
||||||
|
ChannelUsageFields: input.ChannelUsageFields,
|
||||||
|
}, &recordUsageOpts{
|
||||||
|
ParsedRequest: input.ParsedRequest,
|
||||||
|
EnableClaudePath: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
|
||||||
|
type RecordUsageLongContextInput struct {
|
||||||
|
Result *ForwardResult
|
||||||
|
APIKey *APIKey
|
||||||
|
User *User
|
||||||
|
Account *Account
|
||||||
|
Subscription *UserSubscription // 可选:订阅信息
|
||||||
|
InboundEndpoint string // 入站端点(客户端请求路径)
|
||||||
|
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
||||||
|
UserAgent string // 请求的 User-Agent
|
||||||
|
IPAddress string // 请求的客户端 IP 地址
|
||||||
|
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||||
|
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||||
|
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||||
|
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||||
|
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||||
|
|
||||||
|
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||||
|
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
|
||||||
|
return s.recordUsageCore(ctx, &recordUsageCoreInput{
|
||||||
|
Result: input.Result,
|
||||||
|
APIKey: input.APIKey,
|
||||||
|
User: input.User,
|
||||||
|
Account: input.Account,
|
||||||
|
Subscription: input.Subscription,
|
||||||
|
InboundEndpoint: input.InboundEndpoint,
|
||||||
|
UpstreamEndpoint: input.UpstreamEndpoint,
|
||||||
|
UserAgent: input.UserAgent,
|
||||||
|
IPAddress: input.IPAddress,
|
||||||
|
RequestPayloadHash: input.RequestPayloadHash,
|
||||||
|
ForceCacheBilling: input.ForceCacheBilling,
|
||||||
|
APIKeyService: input.APIKeyService,
|
||||||
|
ChannelUsageFields: input.ChannelUsageFields,
|
||||||
|
}, &recordUsageOpts{
|
||||||
|
LongContextThreshold: input.LongContextThreshold,
|
||||||
|
LongContextMultiplier: input.LongContextMultiplier,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。
|
||||||
|
type recordUsageCoreInput struct {
|
||||||
|
Result *ForwardResult
|
||||||
|
APIKey *APIKey
|
||||||
|
User *User
|
||||||
|
Account *Account
|
||||||
|
Subscription *UserSubscription
|
||||||
|
InboundEndpoint string
|
||||||
|
UpstreamEndpoint string
|
||||||
|
UserAgent string
|
||||||
|
IPAddress string
|
||||||
|
RequestPayloadHash string
|
||||||
|
ForceCacheBilling bool
|
||||||
|
APIKeyService APIKeyQuotaUpdater
|
||||||
|
ChannelUsageFields
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||||
|
// opts 中的字段控制两者之间的差异行为:
|
||||||
|
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||||
|
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
|
||||||
|
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||||
|
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
apiKey := input.APIKey
|
apiKey := input.APIKey
|
||||||
user := input.User
|
user := input.User
|
||||||
@@ -7723,9 +7824,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
result.Usage.InputTokens = 0
|
result.Usage.InputTokens = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
// Claude Max cache billing policy(仅 Claude 路径启用)
|
||||||
cacheTTLOverridden := false
|
cacheTTLOverridden := false
|
||||||
if account.IsCacheTTLOverrideEnabled() {
|
simulatedClaudeMax := false
|
||||||
|
if opts.EnableClaudePath {
|
||||||
|
var apiKeyGroup *Group
|
||||||
|
if apiKey != nil {
|
||||||
|
apiKeyGroup = apiKey.Group
|
||||||
|
}
|
||||||
|
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
|
||||||
|
simulatedClaudeMax = claudeMaxOutcome.Simulated ||
|
||||||
|
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
||||||
|
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
|
||||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
||||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
||||||
}
|
}
|
||||||
@@ -7740,7 +7853,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cost *CostBreakdown
|
|
||||||
// 确定计费模型
|
// 确定计费模型
|
||||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
||||||
@@ -7756,8 +7868,87 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
requestedModel = input.OriginalModel
|
requestedModel = input.OriginalModel
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
// 计算费用
|
||||||
|
cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||||
|
|
||||||
|
// 判断计费方式:订阅模式 vs 余额模式
|
||||||
|
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||||
|
billingType := BillingTypeBalance
|
||||||
|
if isSubscriptionBilling {
|
||||||
|
billingType = BillingTypeSubscription
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建使用日志
|
||||||
|
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
|
||||||
|
requestedModel, multiplier, billingType, cacheTTLOverridden, cost, opts)
|
||||||
|
|
||||||
|
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||||
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||||
|
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||||
|
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := usageLog.RequestID
|
||||||
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
|
billingErr := func() error {
|
||||||
|
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||||
|
Cost: cost,
|
||||||
|
User: user,
|
||||||
|
APIKey: apiKey,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||||
|
IsSubscriptionBill: isSubscriptionBilling,
|
||||||
|
AccountRateMultiplier: accountRateMultiplier,
|
||||||
|
APIKeyService: input.APIKeyService,
|
||||||
|
}, s.billingDeps(), s.usageBillingRepo)
|
||||||
|
return err
|
||||||
|
}()
|
||||||
|
|
||||||
|
if billingErr != nil {
|
||||||
|
return billingErr
|
||||||
|
}
|
||||||
|
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateRecordUsageCost 根据请求类型和选项计算费用。
|
||||||
|
func (s *GatewayService) calculateRecordUsageCost(
|
||||||
|
ctx context.Context,
|
||||||
|
result *ForwardResult,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
opts *recordUsageOpts,
|
||||||
|
) *CostBreakdown {
|
||||||
|
// Sora 媒体类型分支(仅 Claude 路径启用)
|
||||||
|
if opts.EnableClaudePath {
|
||||||
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
||||||
|
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
|
||||||
|
}
|
||||||
|
if result.MediaType == MediaTypePrompt {
|
||||||
|
return &CostBreakdown{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 图片生成计费
|
||||||
|
if result.ImageCount > 0 {
|
||||||
|
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token 计费
|
||||||
|
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
|
||||||
|
func (s *GatewayService) calculateSoraMediaCost(
|
||||||
|
result *ForwardResult,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
) *CostBreakdown {
|
||||||
var soraConfig *SoraPriceConfig
|
var soraConfig *SoraPriceConfig
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
soraConfig = &SoraPriceConfig{
|
soraConfig = &SoraPriceConfig{
|
||||||
@@ -7768,14 +7959,19 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.MediaType == MediaTypeImage {
|
if result.MediaType == MediaTypeImage {
|
||||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||||
} else {
|
|
||||||
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
|
||||||
}
|
}
|
||||||
} else if result.MediaType == MediaTypePrompt {
|
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||||
cost = &CostBreakdown{}
|
}
|
||||||
} else if result.ImageCount > 0 {
|
|
||||||
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
|
// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。
|
||||||
|
func (s *GatewayService) calculateImageCost(
|
||||||
|
ctx context.Context,
|
||||||
|
result *ForwardResult,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
) *CostBreakdown {
|
||||||
hasChannelPricing := false
|
hasChannelPricing := false
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
@@ -7785,15 +7981,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasChannelPricing {
|
if hasChannelPricing {
|
||||||
// 渠道定价优先 → 由 CalculateCostUnified 按 resolved.Mode 分发计费
|
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||||
}
|
}
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
var err error
|
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Model: billingModel,
|
Model: billingModel,
|
||||||
GroupID: &gid,
|
GroupID: &gid,
|
||||||
@@ -7804,10 +7998,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
return &CostBreakdown{ActualCost: 0}
|
||||||
}
|
}
|
||||||
} else {
|
return cost
|
||||||
// 无渠道定价 → 走按次计费(默认,兼容旧版本)
|
}
|
||||||
|
|
||||||
var groupConfig *ImagePriceConfig
|
var groupConfig *ImagePriceConfig
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
groupConfig = &ImagePriceConfig{
|
groupConfig = &ImagePriceConfig{
|
||||||
@@ -7816,10 +8011,18 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
Price4K: apiKey.Group.ImagePrice4K,
|
Price4K: apiKey.Group.ImagePrice4K,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Token 计费
|
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
|
||||||
|
func (s *GatewayService) calculateTokenCost(
|
||||||
|
ctx context.Context,
|
||||||
|
result *ForwardResult,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
opts *recordUsageOpts,
|
||||||
|
) *CostBreakdown {
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
@@ -7829,43 +8032,69 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cost *CostBreakdown
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
||||||
|
useUnified := false
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
groupID := &gid
|
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||||
|
if resolved.Source == PricingSourceChannel {
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Model: billingModel,
|
Model: billingModel,
|
||||||
GroupID: groupID,
|
GroupID: &gid,
|
||||||
Tokens: tokens,
|
Tokens: tokens,
|
||||||
RequestCount: 1,
|
RequestCount: 1,
|
||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
Resolver: s.resolver,
|
Resolver: s.resolver,
|
||||||
})
|
})
|
||||||
|
useUnified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !useUnified {
|
||||||
|
if opts.LongContextThreshold > 0 {
|
||||||
|
// 长上下文双倍计费(如 Gemini 200K 阈值)
|
||||||
|
cost, err = s.billingService.CalculateCostWithLongContext(
|
||||||
|
billingModel, tokens, multiplier,
|
||||||
|
opts.LongContextThreshold, opts.LongContextMultiplier,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
return &CostBreakdown{ActualCost: 0}
|
||||||
}
|
}
|
||||||
|
return cost
|
||||||
}
|
}
|
||||||
|
|
||||||
// 判断计费方式:订阅模式 vs 余额模式
|
// buildRecordUsageLog 构建使用日志并设置计费模式。
|
||||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
func (s *GatewayService) buildRecordUsageLog(
|
||||||
billingType := BillingTypeBalance
|
ctx context.Context,
|
||||||
if isSubscriptionBilling {
|
input *recordUsageCoreInput,
|
||||||
billingType = BillingTypeSubscription
|
result *ForwardResult,
|
||||||
}
|
apiKey *APIKey,
|
||||||
|
user *User,
|
||||||
// 创建使用日志
|
account *Account,
|
||||||
|
subscription *UserSubscription,
|
||||||
|
requestedModel string,
|
||||||
|
multiplier float64,
|
||||||
|
billingType int8,
|
||||||
|
cacheTTLOverridden bool,
|
||||||
|
cost *CostBreakdown,
|
||||||
|
opts *recordUsageOpts,
|
||||||
|
) *UsageLog {
|
||||||
durationMs := int(result.Duration.Milliseconds())
|
durationMs := int(result.Duration.Milliseconds())
|
||||||
var imageSize *string
|
var imageSize *string
|
||||||
if result.ImageSize != "" {
|
if result.ImageSize != "" {
|
||||||
imageSize = &result.ImageSize
|
imageSize = &result.ImageSize
|
||||||
}
|
}
|
||||||
var mediaType *string
|
var mediaType *string
|
||||||
if strings.TrimSpace(result.MediaType) != "" {
|
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
|
||||||
mediaType = &result.MediaType
|
mediaType = &result.MediaType
|
||||||
}
|
}
|
||||||
accountRateMultiplier := account.BillingRateMultiplier()
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
@@ -7912,8 +8141,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
usageLog.ActualCost = cost.ActualCost
|
usageLog.ActualCost = cost.ActualCost
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置计费模式
|
// 设置计费模式:Sora 媒体类型自身已确定计费模式(由上游处理),跳过
|
||||||
if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
|
isSoraMedia := opts.EnableClaudePath &&
|
||||||
|
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
|
||||||
|
if !isSoraMedia {
|
||||||
if cost != nil && cost.BillingMode != "" {
|
if cost != nil && cost.BillingMode != "" {
|
||||||
billingMode := cost.BillingMode
|
billingMode := cost.BillingMode
|
||||||
usageLog.BillingMode = &billingMode
|
usageLog.BillingMode = &billingMode
|
||||||
@@ -7944,307 +8175,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
usageLog.SubscriptionID = &subscription.ID
|
usageLog.SubscriptionID = &subscription.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
return usageLog
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
|
||||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
billingErr := func() error {
|
|
||||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
|
||||||
Cost: cost,
|
|
||||||
User: user,
|
|
||||||
APIKey: apiKey,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
|
||||||
IsSubscriptionBill: isSubscriptionBilling,
|
|
||||||
AccountRateMultiplier: accountRateMultiplier,
|
|
||||||
APIKeyService: input.APIKeyService,
|
|
||||||
}, s.billingDeps(), s.usageBillingRepo)
|
|
||||||
return err
|
|
||||||
}()
|
|
||||||
|
|
||||||
if billingErr != nil {
|
|
||||||
return billingErr
|
|
||||||
}
|
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
|
|
||||||
type RecordUsageLongContextInput struct {
|
|
||||||
Result *ForwardResult
|
|
||||||
APIKey *APIKey
|
|
||||||
User *User
|
|
||||||
Account *Account
|
|
||||||
Subscription *UserSubscription // 可选:订阅信息
|
|
||||||
InboundEndpoint string // 入站端点(客户端请求路径)
|
|
||||||
UpstreamEndpoint string // 上游端点(标准化后的上游路径)
|
|
||||||
UserAgent string // 请求的 User-Agent
|
|
||||||
IPAddress string // 请求的客户端 IP 地址
|
|
||||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
|
||||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
|
||||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
|
||||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
|
||||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
|
||||||
|
|
||||||
ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
|
||||||
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
|
|
||||||
result := input.Result
|
|
||||||
apiKey := input.APIKey
|
|
||||||
user := input.User
|
|
||||||
account := input.Account
|
|
||||||
subscription := input.Subscription
|
|
||||||
|
|
||||||
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
|
|
||||||
// 用于粘性会话切换时的特殊计费处理
|
|
||||||
if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
|
|
||||||
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
|
|
||||||
result.Usage.InputTokens, account.ID)
|
|
||||||
result.Usage.CacheReadInputTokens += result.Usage.InputTokens
|
|
||||||
result.Usage.InputTokens = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
|
|
||||||
cacheTTLOverridden := false
|
|
||||||
if account.IsCacheTTLOverrideEnabled() {
|
|
||||||
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
|
|
||||||
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
|
||||||
multiplier := 1.0
|
|
||||||
if s.cfg != nil {
|
|
||||||
multiplier = s.cfg.Default.RateMultiplier
|
|
||||||
}
|
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
|
||||||
groupDefault := apiKey.Group.RateMultiplier
|
|
||||||
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cost *CostBreakdown
|
|
||||||
// 确定计费模型
|
|
||||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
|
||||||
if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
|
|
||||||
billingModel = input.ChannelMappedModel
|
|
||||||
}
|
|
||||||
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
|
|
||||||
billingModel = input.OriginalModel
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确定 RequestedModel(渠道映射前的原始模型)
|
|
||||||
requestedModel := result.Model
|
|
||||||
if input.OriginalModel != "" {
|
|
||||||
requestedModel = input.OriginalModel
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
|
||||||
if result.ImageCount > 0 {
|
|
||||||
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
|
|
||||||
hasChannelPricing := false
|
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
|
||||||
gid := apiKey.Group.ID
|
|
||||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
|
||||||
if resolved.Source == PricingSourceChannel {
|
|
||||||
hasChannelPricing = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasChannelPricing {
|
|
||||||
tokens := UsageTokens{
|
|
||||||
InputTokens: result.Usage.InputTokens,
|
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
|
||||||
}
|
|
||||||
gid := apiKey.Group.ID
|
|
||||||
var err error
|
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
|
||||||
Ctx: ctx,
|
|
||||||
Model: billingModel,
|
|
||||||
GroupID: &gid,
|
|
||||||
Tokens: tokens,
|
|
||||||
RequestCount: 1,
|
|
||||||
RateMultiplier: multiplier,
|
|
||||||
Resolver: s.resolver,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
|
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var groupConfig *ImagePriceConfig
|
|
||||||
if apiKey.Group != nil {
|
|
||||||
groupConfig = &ImagePriceConfig{
|
|
||||||
Price1K: apiKey.Group.ImagePrice1K,
|
|
||||||
Price2K: apiKey.Group.ImagePrice2K,
|
|
||||||
Price4K: apiKey.Group.ImagePrice4K,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Token 计费(使用长上下文计费方法)
|
|
||||||
tokens := UsageTokens{
|
|
||||||
InputTokens: result.Usage.InputTokens,
|
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|
||||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
|
||||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
|
||||||
useUnified := false
|
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
|
||||||
gid := apiKey.Group.ID
|
|
||||||
resolved := s.resolver.Resolve(ctx, PricingInput{
|
|
||||||
Model: billingModel,
|
|
||||||
GroupID: &gid,
|
|
||||||
})
|
|
||||||
if resolved.Source == PricingSourceChannel {
|
|
||||||
// 有渠道定价,渠道区间已包含上下文分层
|
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
|
||||||
Ctx: ctx,
|
|
||||||
Model: billingModel,
|
|
||||||
GroupID: &gid,
|
|
||||||
Tokens: tokens,
|
|
||||||
RequestCount: 1,
|
|
||||||
RateMultiplier: multiplier,
|
|
||||||
Resolver: s.resolver,
|
|
||||||
})
|
|
||||||
useUnified = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !useUnified {
|
|
||||||
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
|
|
||||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 判断计费方式:订阅模式 vs 余额模式
|
|
||||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
|
||||||
billingType := BillingTypeBalance
|
|
||||||
if isSubscriptionBilling {
|
|
||||||
billingType = BillingTypeSubscription
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建使用日志
|
|
||||||
durationMs := int(result.Duration.Milliseconds())
|
|
||||||
var imageSize *string
|
|
||||||
if result.ImageSize != "" {
|
|
||||||
imageSize = &result.ImageSize
|
|
||||||
}
|
|
||||||
accountRateMultiplier := account.BillingRateMultiplier()
|
|
||||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
|
||||||
usageLog := &UsageLog{
|
|
||||||
UserID: user.ID,
|
|
||||||
APIKeyID: apiKey.ID,
|
|
||||||
AccountID: account.ID,
|
|
||||||
RequestID: requestID,
|
|
||||||
Model: result.Model,
|
|
||||||
RequestedModel: requestedModel,
|
|
||||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
|
||||||
ReasoningEffort: result.ReasoningEffort,
|
|
||||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
|
||||||
UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint),
|
|
||||||
InputTokens: result.Usage.InputTokens,
|
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
|
||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
|
||||||
CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
|
|
||||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
|
||||||
RateMultiplier: multiplier,
|
|
||||||
AccountRateMultiplier: &accountRateMultiplier,
|
|
||||||
BillingType: billingType,
|
|
||||||
Stream: result.Stream,
|
|
||||||
DurationMs: &durationMs,
|
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
|
||||||
ImageCount: result.ImageCount,
|
|
||||||
ImageSize: imageSize,
|
|
||||||
CacheTTLOverridden: cacheTTLOverridden,
|
|
||||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
|
||||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
if cost != nil {
|
|
||||||
usageLog.InputCost = cost.InputCost
|
|
||||||
usageLog.OutputCost = cost.OutputCost
|
|
||||||
usageLog.ImageOutputCost = cost.ImageOutputCost
|
|
||||||
usageLog.CacheCreationCost = cost.CacheCreationCost
|
|
||||||
usageLog.CacheReadCost = cost.CacheReadCost
|
|
||||||
usageLog.TotalCost = cost.TotalCost
|
|
||||||
usageLog.ActualCost = cost.ActualCost
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置计费模式
|
|
||||||
if cost != nil && cost.BillingMode != "" {
|
|
||||||
billingMode := cost.BillingMode
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
} else if result.ImageCount > 0 {
|
|
||||||
billingMode := string(BillingModeImage)
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
} else {
|
|
||||||
billingMode := string(BillingModeToken)
|
|
||||||
usageLog.BillingMode = &billingMode
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加 UserAgent
|
|
||||||
if input.UserAgent != "" {
|
|
||||||
usageLog.UserAgent = &input.UserAgent
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加 IPAddress
|
|
||||||
if input.IPAddress != "" {
|
|
||||||
usageLog.IPAddress = &input.IPAddress
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加分组和订阅关联
|
|
||||||
if apiKey.GroupID != nil {
|
|
||||||
usageLog.GroupID = apiKey.GroupID
|
|
||||||
}
|
|
||||||
if subscription != nil {
|
|
||||||
usageLog.SubscriptionID = &subscription.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
|
||||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
|
||||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
billingErr := func() error {
|
|
||||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
|
||||||
Cost: cost,
|
|
||||||
User: user,
|
|
||||||
APIKey: apiKey,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
|
||||||
IsSubscriptionBill: isSubscriptionBilling,
|
|
||||||
AccountRateMultiplier: accountRateMultiplier,
|
|
||||||
APIKeyService: input.APIKeyService,
|
|
||||||
}, s.billingDeps(), s.usageBillingRepo)
|
|
||||||
return err
|
|
||||||
}()
|
|
||||||
|
|
||||||
if billingErr != nil {
|
|
||||||
return billingErr
|
|
||||||
}
|
|
||||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveChannelMapping 委托渠道服务解析模型映射
|
// ResolveChannelMapping 委托渠道服务解析模型映射
|
||||||
|
|||||||
Reference in New Issue
Block a user