Fix the priority order used for account-stats cost.

- Before: custom rules → LiteLLM (when ApplyPricingToAccountStats) → nil
- After: custom rules → totalCost (when ApplyPricingToAccountStats) → LiteLLM → nil

When ApplyPricingToAccountStats is enabled, use the request's actual client billing cost (before the rate multiplier) as account_stats_cost, instead of recalculating it from LiteLLM per-token prices, which produced incorrect values for the per-request billing mode. LiteLLM model pricing is now the final fallback (priority 3), used only when neither custom rules nor ApplyPricingToAccountStats apply.
215 lines
6.5 KiB
Go
215 lines
6.5 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"sort"
|
||
"strings"
|
||
)
|
||
|
||
// resolveAccountStatsCost 计算账号统计定价费用。
|
||
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||
//
|
||
// 优先级(先命中为准):
|
||
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
|
||
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
|
||
// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
|
||
// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
|
||
//
|
||
// upstreamModel 是最终发往上游的模型 ID。
|
||
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
|
||
func resolveAccountStatsCost(
|
||
ctx context.Context,
|
||
channelService *ChannelService,
|
||
billingService *BillingService,
|
||
accountID int64,
|
||
groupID int64,
|
||
upstreamModel string,
|
||
tokens UsageTokens,
|
||
requestCount int,
|
||
totalCost float64,
|
||
) *float64 {
|
||
if channelService == nil || upstreamModel == "" {
|
||
return nil
|
||
}
|
||
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||
if err != nil || channel == nil {
|
||
return nil
|
||
}
|
||
|
||
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||
|
||
// 优先级 1:自定义规则(始终尝试)
|
||
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
|
||
return cost
|
||
}
|
||
|
||
// 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
|
||
if channel.ApplyPricingToAccountStats {
|
||
cost := totalCost
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|
||
|
||
// 优先级 3:模型定价文件(LiteLLM)默认价格
|
||
if billingService != nil {
|
||
return tryModelFilePricing(billingService, upstreamModel, tokens)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
|
||
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
|
||
pricing, err := billingService.GetModelPricing(model)
|
||
if err != nil || pricing == nil {
|
||
return nil
|
||
}
|
||
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
|
||
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
|
||
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
|
||
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
|
||
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|
||
|
||
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||
func tryCustomRules(
|
||
channel *Channel, accountID, groupID int64,
|
||
platform, model string, tokens UsageTokens, requestCount int,
|
||
) *float64 {
|
||
modelLower := strings.ToLower(model)
|
||
for _, rule := range channel.AccountStatsPricingRules {
|
||
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||
continue
|
||
}
|
||
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||
if pricing == nil {
|
||
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||
}
|
||
return calculateStatsCost(pricing, tokens, requestCount)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||
return false
|
||
}
|
||
for _, id := range rule.AccountIDs {
|
||
if id == accountID {
|
||
return true
|
||
}
|
||
}
|
||
for _, id := range rule.GroupIDs {
|
||
if id == groupID {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// wildcardMatch 通配符匹配候选项(用于排序)
|
||
type wildcardMatch struct {
|
||
prefixLen int
|
||
pricing *ChannelModelPricing
|
||
}
|
||
|
||
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||
// 先精确匹配,再通配符匹配(前缀越长优先级越高)。
|
||
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||
// 精确匹配优先
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
if strings.ToLower(m) == modelLower {
|
||
return p
|
||
}
|
||
}
|
||
}
|
||
// 通配符匹配:收集所有匹配项,按前缀长度降序取最长
|
||
var matches []wildcardMatch
|
||
for i := range pricingList {
|
||
p := &pricingList[i]
|
||
if !isPlatformMatch(platform, p.Platform) {
|
||
continue
|
||
}
|
||
for _, m := range p.Models {
|
||
ml := strings.ToLower(m)
|
||
if !strings.HasSuffix(ml, "*") {
|
||
continue
|
||
}
|
||
prefix := strings.TrimSuffix(ml, "*")
|
||
if strings.HasPrefix(modelLower, prefix) {
|
||
matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p})
|
||
}
|
||
}
|
||
}
|
||
if len(matches) == 0 {
|
||
return nil
|
||
}
|
||
sort.Slice(matches, func(i, j int) bool {
|
||
return matches[i].prefixLen > matches[j].prefixLen
|
||
})
|
||
return matches[0].pricing
|
||
}
|
||
|
||
// isPlatformMatch reports whether a pricing entry scoped to pricingPlatform
// applies to a query for queryPlatform. An empty value on either side acts as
// "any platform"; otherwise the two strings must be exactly equal
// (case-sensitive).
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
	anyPlatform := queryPlatform == "" || pricingPlatform == ""
	return anyPlatform || queryPlatform == pricingPlatform
}
|
||
|
||
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||
if pricing == nil {
|
||
return nil
|
||
}
|
||
switch pricing.BillingMode {
|
||
case BillingModePerRequest, BillingModeImage:
|
||
return calculatePerRequestStatsCost(pricing, requestCount)
|
||
default:
|
||
return calculateTokenStatsCost(pricing, tokens)
|
||
}
|
||
}
|
||
|
||
// calculatePerRequestStatsCost 按次/图片计费。
|
||
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||
return nil
|
||
}
|
||
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||
return &cost
|
||
}
|
||
|
||
// calculateTokenStatsCost Token 计费。
|
||
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||
deref := func(p *float64) float64 {
|
||
if p == nil {
|
||
return 0
|
||
}
|
||
return *p
|
||
}
|
||
cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) +
|
||
float64(tokens.OutputTokens)*deref(pricing.OutputPrice) +
|
||
float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) +
|
||
float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) +
|
||
float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice)
|
||
if cost <= 0 {
|
||
return nil
|
||
}
|
||
return &cost
|
||
}
|