- 新增 CalculateCostWithLongContext 方法支持阈值双倍计费 - 新增 RecordUsageWithLongContext 方法专用于 Gemini 计费 - Gemini 超过 200K token 的部分按 2 倍费率计算 - 其他平台(Claude/OpenAI)完全不受影响
442 lines
14 KiB
Go
442 lines
14 KiB
Go
package service
|
||
|
||
import (
	"context"
	"fmt"
	"log"
	"sort"
	"strings"

	"github.com/Wei-Shaw/sub2api/internal/config"
)
|
||
|
||
// BillingCache defines cache operations for billing service
|
||
type BillingCache interface {
|
||
// Balance operations
|
||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||
|
||
// Subscription operations
|
||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||
}
|
||
|
||
// ModelPricing is the price configuration of a single model.
//
// Units differ between field groups:
//   - The *PricePerToken fields are USD per single token (LiteLLM format).
//   - CacheCreation5mPrice / CacheCreation1hPrice are USD per MILLION tokens.
//     They are only meant for hard-coded fallback pricing. They are consulted
//     only when SupportsCacheBreakdown is true and at least one of them is
//     non-zero.
type ModelPricing struct {
	InputPricePerToken         float64 // USD per input token
	OutputPricePerToken        float64 // USD per output token
	CacheCreationPricePerToken float64 // USD per cache-creation (write) token
	CacheReadPricePerToken     float64 // USD per cache-read token

	CacheCreation5mPrice float64 // USD per 1M tokens, 5-minute cache creation (fallback only)
	CacheCreation1hPrice float64 // USD per 1M tokens, 1-hour cache creation (fallback only)

	SupportsCacheBreakdown bool // enables the 5m/1h cache-creation split
}
|
||
|
||
// UsageTokens holds the token counts of one request, split by category.
//
// CacheCreation5mTokens and CacheCreation1hTokens are only priced when the
// model's pricing supports the cache breakdown. Otherwise the aggregate
// CacheCreationTokens is used.
type UsageTokens struct {
	InputTokens           int // fresh (non-cached) prompt tokens
	OutputTokens          int // generated tokens
	CacheCreationTokens   int // tokens written to the prompt cache (aggregate)
	CacheReadTokens       int // tokens served from the prompt cache
	CacheCreation5mTokens int // cache-creation tokens with a 5-minute TTL
	CacheCreation1hTokens int // cache-creation tokens with a 1-hour TTL
}
|
||
|
||
// CostBreakdown itemizes the cost of a request in USD.
//
// The four per-category costs are raw (un-multiplied) amounts, and TotalCost
// is their sum. ActualCost is TotalCost after the rate multiplier has been
// applied; this is the amount actually charged.
type CostBreakdown struct {
	InputCost         float64 // fresh input tokens
	OutputCost        float64 // output tokens
	CacheCreationCost float64 // cache writes
	CacheReadCost     float64 // cache reads
	TotalCost         float64 // sum of the above, before the multiplier
	ActualCost        float64 // TotalCost × rate multiplier
}
|
||
|
||
// BillingService 计费服务
|
||
type BillingService struct {
|
||
cfg *config.Config
|
||
pricingService *PricingService
|
||
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
|
||
}
|
||
|
||
// NewBillingService 创建计费服务实例
|
||
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
|
||
s := &BillingService{
|
||
cfg: cfg,
|
||
pricingService: pricingService,
|
||
fallbackPrices: make(map[string]*ModelPricing),
|
||
}
|
||
|
||
// 初始化硬编码回退价格(当动态价格不可用时使用)
|
||
s.initFallbackPricing()
|
||
|
||
return s
|
||
}
|
||
|
||
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
|
||
// 价格单位:USD per token(与LiteLLM格式一致)
|
||
func (s *BillingService) initFallbackPricing() {
|
||
// Claude 4.5 Opus
|
||
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
|
||
InputPricePerToken: 5e-6, // $5 per MTok
|
||
OutputPricePerToken: 25e-6, // $25 per MTok
|
||
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
|
||
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
|
||
// Claude 4 Sonnet
|
||
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||
InputPricePerToken: 3e-6, // $3 per MTok
|
||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
|
||
// Claude 3.5 Sonnet
|
||
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
|
||
InputPricePerToken: 3e-6, // $3 per MTok
|
||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
|
||
// Claude 3.5 Haiku
|
||
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
|
||
InputPricePerToken: 1e-6, // $1 per MTok
|
||
OutputPricePerToken: 5e-6, // $5 per MTok
|
||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
|
||
// Claude 3 Opus
|
||
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
|
||
InputPricePerToken: 15e-6, // $15 per MTok
|
||
OutputPricePerToken: 75e-6, // $75 per MTok
|
||
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
|
||
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
|
||
// Claude 3 Haiku
|
||
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
|
||
InputPricePerToken: 0.25e-6, // $0.25 per MTok
|
||
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
|
||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||
SupportsCacheBreakdown: false,
|
||
}
|
||
}
|
||
|
||
// getFallbackPricing 根据模型系列获取回退价格
|
||
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||
modelLower := strings.ToLower(model)
|
||
|
||
// 按模型系列匹配
|
||
if strings.Contains(modelLower, "opus") {
|
||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||
return s.fallbackPrices["claude-opus-4.5"]
|
||
}
|
||
return s.fallbackPrices["claude-3-opus"]
|
||
}
|
||
if strings.Contains(modelLower, "sonnet") {
|
||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||
return s.fallbackPrices["claude-sonnet-4"]
|
||
}
|
||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||
}
|
||
if strings.Contains(modelLower, "haiku") {
|
||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||
return s.fallbackPrices["claude-3-5-haiku"]
|
||
}
|
||
return s.fallbackPrices["claude-3-haiku"]
|
||
}
|
||
|
||
// 默认使用Sonnet价格
|
||
return s.fallbackPrices["claude-sonnet-4"]
|
||
}
|
||
|
||
// GetModelPricing 获取模型价格配置
|
||
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||
// 标准化模型名称(转小写)
|
||
model = strings.ToLower(model)
|
||
|
||
// 1. 优先从动态价格服务获取
|
||
if s.pricingService != nil {
|
||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||
if litellmPricing != nil {
|
||
return &ModelPricing{
|
||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||
SupportsCacheBreakdown: false,
|
||
}, nil
|
||
}
|
||
}
|
||
|
||
// 2. 使用硬编码回退价格
|
||
fallback := s.getFallbackPricing(model)
|
||
if fallback != nil {
|
||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||
return fallback, nil
|
||
}
|
||
|
||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||
}
|
||
|
||
// CalculateCost 计算使用费用
|
||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||
pricing, err := s.GetModelPricing(model)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
breakdown := &CostBreakdown{}
|
||
|
||
// 计算输入token费用(使用per-token价格)
|
||
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
|
||
|
||
// 计算输出token费用
|
||
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
|
||
|
||
// 计算缓存费用
|
||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||
} else {
|
||
// 标准缓存创建价格(per-token)
|
||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||
}
|
||
|
||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||
|
||
// 计算总费用
|
||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||
|
||
// 应用倍率计算实际费用
|
||
if rateMultiplier <= 0 {
|
||
rateMultiplier = 1.0
|
||
}
|
||
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
|
||
|
||
return breakdown, nil
|
||
}
|
||
|
||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||
multiplier := s.cfg.Default.RateMultiplier
|
||
if multiplier <= 0 {
|
||
multiplier = 1.0
|
||
}
|
||
return s.CalculateCost(model, tokens, multiplier)
|
||
}
|
||
|
||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||
// 1. 先正常计算全部 token 的成本
|
||
cost, err := s.CalculateCost(model, tokens, rateMultiplier)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 2. 如果未启用长上下文计费或未超过阈值,直接返回
|
||
if threshold <= 0 || extraMultiplier <= 1 {
|
||
return cost, nil
|
||
}
|
||
|
||
// 计算总输入 token(缓存读取 + 新输入)
|
||
total := tokens.CacheReadTokens + tokens.InputTokens
|
||
if total <= threshold {
|
||
return cost, nil
|
||
}
|
||
|
||
// 3. 拆分超出部分的 token
|
||
extra := total - threshold
|
||
var extraCacheTokens, extraInputTokens int
|
||
|
||
if tokens.CacheReadTokens >= threshold {
|
||
// 缓存已超过阈值:超出的缓存 + 全部输入
|
||
extraCacheTokens = tokens.CacheReadTokens - threshold
|
||
extraInputTokens = tokens.InputTokens
|
||
} else {
|
||
// 缓存未超过阈值:只有输入超出部分
|
||
extraCacheTokens = 0
|
||
extraInputTokens = extra
|
||
}
|
||
|
||
// 4. 计算超出部分的成本(只算输入和缓存读取)
|
||
extraTokens := UsageTokens{
|
||
InputTokens: extraInputTokens,
|
||
CacheReadTokens: extraCacheTokens,
|
||
}
|
||
extraCost, err := s.CalculateCost(model, extraTokens, 1.0) // 先按 1 倍算
|
||
if err != nil {
|
||
return cost, nil // 出错时返回正常成本
|
||
}
|
||
|
||
// 5. 额外成本 = 超出部分成本 × (倍率 - 1)
|
||
extraRate := extraMultiplier - 1
|
||
additionalInputCost := extraCost.InputCost * extraRate
|
||
additionalCacheCost := extraCost.CacheReadCost * extraRate
|
||
|
||
// 6. 累加到总成本
|
||
cost.InputCost += additionalInputCost
|
||
cost.CacheReadCost += additionalCacheCost
|
||
cost.TotalCost += additionalInputCost + additionalCacheCost
|
||
cost.ActualCost = cost.TotalCost * rateMultiplier
|
||
|
||
return cost, nil
|
||
}
|
||
|
||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||
func (s *BillingService) ListSupportedModels() []string {
|
||
models := make([]string, 0)
|
||
// 返回回退价格支持的模型系列
|
||
for model := range s.fallbackPrices {
|
||
models = append(models, model)
|
||
}
|
||
return models
|
||
}
|
||
|
||
// IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退)
|
||
func (s *BillingService) IsModelSupported(model string) bool {
|
||
// 所有Claude模型都有回退价格支持
|
||
modelLower := strings.ToLower(model)
|
||
return strings.Contains(modelLower, "claude") ||
|
||
strings.Contains(modelLower, "opus") ||
|
||
strings.Contains(modelLower, "sonnet") ||
|
||
strings.Contains(modelLower, "haiku")
|
||
}
|
||
|
||
// GetEstimatedCost 估算费用(用于前端展示)
|
||
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
|
||
tokens := UsageTokens{
|
||
InputTokens: estimatedInputTokens,
|
||
OutputTokens: estimatedOutputTokens,
|
||
}
|
||
|
||
breakdown, err := s.CalculateCostWithConfig(model, tokens)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
return breakdown.ActualCost, nil
|
||
}
|
||
|
||
// GetPricingServiceStatus 获取价格服务状态
|
||
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||
if s.pricingService != nil {
|
||
return s.pricingService.GetStatus()
|
||
}
|
||
return map[string]any{
|
||
"model_count": len(s.fallbackPrices),
|
||
"last_updated": "using fallback",
|
||
"local_hash": "N/A",
|
||
}
|
||
}
|
||
|
||
// ForceUpdatePricing 强制更新价格数据
|
||
func (s *BillingService) ForceUpdatePricing() error {
|
||
if s.pricingService != nil {
|
||
return s.pricingService.ForceUpdate()
|
||
}
|
||
return fmt.Errorf("pricing service not initialized")
|
||
}
|
||
|
||
// ImagePriceConfig carries per-group image prices (USD per image) for each
// supported size. A nil field means "no override": the default price for that
// size is used instead.
type ImagePriceConfig struct {
	Price1K *float64 // price per 1K image, or nil for the default
	Price2K *float64 // price per 2K image, or nil for the default
	Price4K *float64 // price per 4K image, or nil for the default
}
|
||
|
||
// CalculateImageCost 计算图片生成费用
|
||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||
// imageCount: 生成的图片数量
|
||
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
|
||
// rateMultiplier: 费率倍数
|
||
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
|
||
if imageCount <= 0 {
|
||
return &CostBreakdown{}
|
||
}
|
||
|
||
// 获取单价
|
||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||
|
||
// 计算总费用
|
||
totalCost := unitPrice * float64(imageCount)
|
||
|
||
// 应用倍率
|
||
if rateMultiplier <= 0 {
|
||
rateMultiplier = 1.0
|
||
}
|
||
actualCost := totalCost * rateMultiplier
|
||
|
||
return &CostBreakdown{
|
||
TotalCost: totalCost,
|
||
ActualCost: actualCost,
|
||
}
|
||
}
|
||
|
||
// getImageUnitPrice 获取图片单价
|
||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||
// 优先使用分组配置的价格
|
||
if groupConfig != nil {
|
||
switch imageSize {
|
||
case "1K":
|
||
if groupConfig.Price1K != nil {
|
||
return *groupConfig.Price1K
|
||
}
|
||
case "2K":
|
||
if groupConfig.Price2K != nil {
|
||
return *groupConfig.Price2K
|
||
}
|
||
case "4K":
|
||
if groupConfig.Price4K != nil {
|
||
return *groupConfig.Price4K
|
||
}
|
||
}
|
||
}
|
||
|
||
// 回退到 LiteLLM 默认价格
|
||
return s.getDefaultImagePrice(model, imageSize)
|
||
}
|
||
|
||
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
|
||
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
|
||
basePrice := 0.0
|
||
|
||
// 从 PricingService 获取 output_cost_per_image
|
||
if s.pricingService != nil {
|
||
pricing := s.pricingService.GetModelPricing(model)
|
||
if pricing != nil && pricing.OutputCostPerImage > 0 {
|
||
basePrice = pricing.OutputCostPerImage
|
||
}
|
||
}
|
||
|
||
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
|
||
if basePrice <= 0 {
|
||
basePrice = 0.134
|
||
}
|
||
|
||
// 4K 尺寸翻倍
|
||
if imageSize == "4K" {
|
||
return basePrice * 2
|
||
}
|
||
|
||
return basePrice
|
||
}
|