feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析
Cherry-picked from release/custom-0.1.106: a9117600
This commit is contained in:
@@ -371,13 +371,193 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
|
||||
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
|
||||
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if channelPricing == nil {
|
||||
return pricing, nil
|
||||
}
|
||||
if channelPricing.InputPrice != nil {
|
||||
pricing.InputPricePerToken = *channelPricing.InputPrice
|
||||
pricing.InputPricePerTokenPriority = *channelPricing.InputPrice
|
||||
}
|
||||
if channelPricing.OutputPrice != nil {
|
||||
pricing.OutputPricePerToken = *channelPricing.OutputPrice
|
||||
pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice
|
||||
}
|
||||
if channelPricing.CacheWritePrice != nil {
|
||||
pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice
|
||||
pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice
|
||||
}
|
||||
if channelPricing.CacheReadPrice != nil {
|
||||
pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice
|
||||
pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice
|
||||
}
|
||||
return pricing, nil
|
||||
}
|
||||
|
||||
// CalculateCostWithChannel 使用渠道定价计算费用
|
||||
// Deprecated: 使用 CalculateCostUnified 代替
|
||||
func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing)
|
||||
}
|
||||
|
||||
// --- 统一计费入口 ---
|
||||
|
||||
// CostInput 统一计费输入
|
||||
type CostInput struct {
|
||||
Ctx context.Context
|
||||
Model string
|
||||
GroupID *int64 // 用于渠道定价查找
|
||||
Tokens UsageTokens
|
||||
RequestCount int // 按次计费时使用
|
||||
SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等)
|
||||
RateMultiplier float64
|
||||
ServiceTier string // "priority","flex","" 等
|
||||
Resolver *ModelPricingResolver // 定价解析器
|
||||
}
|
||||
|
||||
// CalculateCostUnified 统一计费入口,支持三种计费模式。
|
||||
// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。
|
||||
func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) {
|
||||
if input.Resolver == nil {
|
||||
// 无 Resolver,回退到旧路径
|
||||
return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil)
|
||||
}
|
||||
|
||||
resolved := input.Resolver.Resolve(input.Ctx, PricingInput{
|
||||
Model: input.Model,
|
||||
GroupID: input.GroupID,
|
||||
})
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
}
|
||||
|
||||
switch resolved.Mode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return s.calculatePerRequestCost(resolved, input)
|
||||
default: // BillingModeToken
|
||||
return s.calculateTokenCost(resolved, input)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateTokenCost 按 token 区间计费
|
||||
func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
|
||||
pricing := input.Resolver.GetIntervalPricing(resolved, totalContext)
|
||||
if pricing == nil {
|
||||
return nil, fmt.Errorf("no pricing available for model: %s", input.Model)
|
||||
}
|
||||
|
||||
pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing)
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
|
||||
if usePriorityServiceTierPricing(input.ServiceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(input.ServiceTier)
|
||||
}
|
||||
|
||||
// 长上下文定价(仅在无区间定价时应用,区间定价已包含上下文分层)
|
||||
if len(resolved.Intervals) == 0 && s.shouldApplySessionLongContextPricing(input.Tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken
|
||||
breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(input.Tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(input.Tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
// calculatePerRequestCost 按次/图片计费
|
||||
func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) {
|
||||
count := input.RequestCount
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
|
||||
var unitPrice float64
|
||||
|
||||
if input.SizeTier != "" {
|
||||
unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier)
|
||||
}
|
||||
|
||||
if unitPrice == 0 {
|
||||
totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens
|
||||
unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext)
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(count)
|
||||
actualCost := totalCost * input.RateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil)
|
||||
}
|
||||
|
||||
func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
||||
var pricing *ModelPricing
|
||||
var err error
|
||||
if channelPricing != nil {
|
||||
pricing, err = s.GetModelPricingWithChannel(model, channelPricing)
|
||||
} else {
|
||||
pricing, err = s.GetModelPricing(model)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user