diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 95e16c4e..db5a9708 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -244,60 +244,71 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken // CalculateCostWithLongContext 计算费用,支持长上下文双倍计费 // threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费 // extraMultiplier: 超出部分的倍率(如 2.0 表示双倍) +// +// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0 +// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k) +// 范围内正常计费,范围外 × 2 计费 func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) { - // 1. 先正常计算全部 token 的成本 - cost, err := s.CalculateCost(model, tokens, rateMultiplier) - if err != nil { - return nil, err - } - - // 2. 如果未启用长上下文计费或未超过阈值,直接返回 + // 未启用长上下文计费,直接走正常计费 if threshold <= 0 || extraMultiplier <= 1 { - return cost, nil + return s.CalculateCost(model, tokens, rateMultiplier) } // 计算总输入 token(缓存读取 + 新输入) total := tokens.CacheReadTokens + tokens.InputTokens if total <= threshold { - return cost, nil + return s.CalculateCost(model, tokens, rateMultiplier) } - // 3. 拆分超出部分的 token - extra := total - threshold - var extraCacheTokens, extraInputTokens int + // 拆分成范围内和范围外 + var inRangeCacheTokens, inRangeInputTokens int + var outRangeCacheTokens, outRangeInputTokens int if tokens.CacheReadTokens >= threshold { - // 缓存已超过阈值:超出的缓存 + 全部输入 - extraCacheTokens = tokens.CacheReadTokens - threshold - extraInputTokens = tokens.InputTokens + // 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入 + inRangeCacheTokens = threshold + inRangeInputTokens = 0 + outRangeCacheTokens = tokens.CacheReadTokens - threshold + outRangeInputTokens = tokens.InputTokens } else { - // 缓存未超过阈值:只有输入超出部分 - extraCacheTokens = 0 - extraInputTokens = extra + // 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入 + inRangeCacheTokens = tokens.CacheReadTokens + inRangeInputTokens = threshold - tokens.CacheReadTokens + outRangeCacheTokens = 0 + outRangeInputTokens = tokens.InputTokens - inRangeInputTokens } - // 4. 计算超出部分的成本(只算输入和缓存读取) - extraTokens := UsageTokens{ - InputTokens: extraInputTokens, - CacheReadTokens: extraCacheTokens, + // 范围内部分:正常计费 + inRangeTokens := UsageTokens{ + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, } - extraCost, err := s.CalculateCost(model, extraTokens, 1.0) // 先按 1 倍算 + inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { - return cost, nil // 出错时返回正常成本 + return nil, err } - // 5. 额外成本 = 超出部分成本 × (倍率 - 1) - extraRate := extraMultiplier - 1 - additionalInputCost := extraCost.InputCost * extraRate - additionalCacheCost := extraCost.CacheReadCost * extraRate + // 范围外部分:× extraMultiplier 计费 + outRangeTokens := UsageTokens{ + InputTokens: outRangeInputTokens, + CacheReadTokens: outRangeCacheTokens, + } + outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) + if err != nil { + return inRangeCost, nil // 出错时返回范围内成本 + } - // 6. 累加到总成本 - cost.InputCost += additionalInputCost - cost.CacheReadCost += additionalCacheCost - cost.TotalCost += additionalInputCost + additionalCacheCost - cost.ActualCost = cost.TotalCost * rateMultiplier - - return cost, nil + // 合并成本 + return &CostBreakdown{ + InputCost: inRangeCost.InputCost + outRangeCost.InputCost, + OutputCost: inRangeCost.OutputCost, + CacheCreationCost: inRangeCost.CacheCreationCost, + CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, + TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, + ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost, + }, nil } // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)