refactor(billing): 简化 CalculateCostWithLongContext 逻辑
将 token 直接拆分为范围内和范围外两部分,分别调用 CalculateCost:
- 范围内:正常计费 (rateMultiplier)
- 范围外:双倍计费 (rateMultiplier × extraMultiplier)

代码更直观,便于理解和维护。
This commit is contained in:
@@ -244,60 +244,71 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
|
|||||||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||||
|
//
|
||||||
|
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
|
||||||
|
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
|
||||||
|
// 范围内正常计费,范围外 × 2 计费
|
||||||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||||
// 1. 先正常计算全部 token 的成本
|
// 未启用长上下文计费,直接走正常计费
|
||||||
cost, err := s.CalculateCost(model, tokens, rateMultiplier)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 如果未启用长上下文计费或未超过阈值,直接返回
|
|
||||||
if threshold <= 0 || extraMultiplier <= 1 {
|
if threshold <= 0 || extraMultiplier <= 1 {
|
||||||
return cost, nil
|
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算总输入 token(缓存读取 + 新输入)
|
// 计算总输入 token(缓存读取 + 新输入)
|
||||||
total := tokens.CacheReadTokens + tokens.InputTokens
|
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||||
if total <= threshold {
|
if total <= threshold {
|
||||||
return cost, nil
|
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 拆分超出部分的 token
|
// 拆分成范围内和范围外
|
||||||
extra := total - threshold
|
var inRangeCacheTokens, inRangeInputTokens int
|
||||||
var extraCacheTokens, extraInputTokens int
|
var outRangeCacheTokens, outRangeInputTokens int
|
||||||
|
|
||||||
if tokens.CacheReadTokens >= threshold {
|
if tokens.CacheReadTokens >= threshold {
|
||||||
// 缓存已超过阈值:超出的缓存 + 全部输入
|
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
|
||||||
extraCacheTokens = tokens.CacheReadTokens - threshold
|
inRangeCacheTokens = threshold
|
||||||
extraInputTokens = tokens.InputTokens
|
inRangeInputTokens = 0
|
||||||
|
outRangeCacheTokens = tokens.CacheReadTokens - threshold
|
||||||
|
outRangeInputTokens = tokens.InputTokens
|
||||||
} else {
|
} else {
|
||||||
// 缓存未超过阈值:只有输入超出部分
|
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
|
||||||
extraCacheTokens = 0
|
inRangeCacheTokens = tokens.CacheReadTokens
|
||||||
extraInputTokens = extra
|
inRangeInputTokens = threshold - tokens.CacheReadTokens
|
||||||
|
outRangeCacheTokens = 0
|
||||||
|
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 计算超出部分的成本(只算输入和缓存读取)
|
// 范围内部分:正常计费
|
||||||
extraTokens := UsageTokens{
|
inRangeTokens := UsageTokens{
|
||||||
InputTokens: extraInputTokens,
|
InputTokens: inRangeInputTokens,
|
||||||
CacheReadTokens: extraCacheTokens,
|
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||||
|
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||||
|
CacheReadTokens: inRangeCacheTokens,
|
||||||
}
|
}
|
||||||
extraCost, err := s.CalculateCost(model, extraTokens, 1.0) // 先按 1 倍算
|
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cost, nil // 出错时返回正常成本
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 额外成本 = 超出部分成本 × (倍率 - 1)
|
// 范围外部分:× extraMultiplier 计费
|
||||||
extraRate := extraMultiplier - 1
|
outRangeTokens := UsageTokens{
|
||||||
additionalInputCost := extraCost.InputCost * extraRate
|
InputTokens: outRangeInputTokens,
|
||||||
additionalCacheCost := extraCost.CacheReadCost * extraRate
|
CacheReadTokens: outRangeCacheTokens,
|
||||||
|
}
|
||||||
|
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||||
|
if err != nil {
|
||||||
|
return inRangeCost, nil // 出错时返回范围内成本
|
||||||
|
}
|
||||||
|
|
||||||
// 6. 累加到总成本
|
// 合并成本
|
||||||
cost.InputCost += additionalInputCost
|
return &CostBreakdown{
|
||||||
cost.CacheReadCost += additionalCacheCost
|
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||||
cost.TotalCost += additionalInputCost + additionalCacheCost
|
OutputCost: inRangeCost.OutputCost,
|
||||||
cost.ActualCost = cost.TotalCost * rateMultiplier
|
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||||
|
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||||
return cost, nil
|
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||||
|
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||||
|
|||||||
Reference in New Issue
Block a user