package service import ( "context" "fmt" "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" ) // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations GetUserBalance(ctx context.Context, userID int64) (float64, error) SetUserBalance(ctx context.Context, userID int64, balance float64) error DeductUserBalance(ctx context.Context, userID int64, amount float64) error InvalidateUserBalance(ctx context.Context, userID int64) error // Subscription operations GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) type ModelPricing struct { InputPricePerToken float64 // 每token输入价格 (USD) OutputPricePerToken float64 // 每token输出价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 SupportsCacheBreakdown bool // 是否支持详细的缓存分类 } // UsageTokens 使用的token数量 type UsageTokens struct { InputTokens int OutputTokens int CacheCreationTokens int CacheReadTokens int CacheCreation5mTokens int CacheCreation1hTokens int } // CostBreakdown 费用明细 type CostBreakdown struct { InputCost float64 OutputCost float64 CacheCreationCost float64 CacheReadCost float64 TotalCost float64 ActualCost float64 // 应用倍率后的实际费用 } // BillingService 计费服务 type BillingService struct { cfg *config.Config pricingService *PricingService fallbackPrices map[string]*ModelPricing // 硬编码回退价格 } // NewBillingService 创建计费服务实例 func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService { s := &BillingService{ cfg: cfg, pricingService: pricingService, fallbackPrices: make(map[string]*ModelPricing), } // 初始化硬编码回退价格(当动态价格不可用时使用) s.initFallbackPricing() return s } // initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用) // 价格单位:USD per token(与LiteLLM格式一致) func (s *BillingService) initFallbackPricing() { // Claude 4.5 Opus s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{ InputPricePerToken: 5e-6, // $5 per MTok OutputPricePerToken: 25e-6, // $25 per MTok CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok SupportsCacheBreakdown: false, } // Claude 4 Sonnet s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{ InputPricePerToken: 3e-6, // $3 per MTok OutputPricePerToken: 15e-6, // $15 per MTok CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok SupportsCacheBreakdown: false, } // Claude 3.5 Sonnet s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{ InputPricePerToken: 3e-6, // $3 per MTok OutputPricePerToken: 15e-6, // $15 per MTok CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok SupportsCacheBreakdown: false, } // Claude 3.5 Haiku s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{ InputPricePerToken: 1e-6, // $1 per MTok OutputPricePerToken: 5e-6, // $5 per MTok CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok SupportsCacheBreakdown: false, } // Claude 3 Opus s.fallbackPrices["claude-3-opus"] = &ModelPricing{ InputPricePerToken: 15e-6, // $15 per MTok OutputPricePerToken: 75e-6, // $75 per MTok CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok SupportsCacheBreakdown: false, } // Claude 3 Haiku s.fallbackPrices["claude-3-haiku"] = &ModelPricing{ InputPricePerToken: 0.25e-6, // $0.25 per MTok OutputPricePerToken: 1.25e-6, // $1.25 per MTok CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok SupportsCacheBreakdown: false, } } // getFallbackPricing 根据模型系列获取回退价格 func (s *BillingService) getFallbackPricing(model string) *ModelPricing { modelLower := strings.ToLower(model) // 按模型系列匹配 if strings.Contains(modelLower, "opus") { if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") { return s.fallbackPrices["claude-opus-4.5"] } return s.fallbackPrices["claude-3-opus"] } if strings.Contains(modelLower, "sonnet") { if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") { return s.fallbackPrices["claude-sonnet-4"] } return s.fallbackPrices["claude-3-5-sonnet"] } if strings.Contains(modelLower, "haiku") { if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") { return s.fallbackPrices["claude-3-5-haiku"] } return s.fallbackPrices["claude-3-haiku"] } // 默认使用Sonnet价格 return s.fallbackPrices["claude-sonnet-4"] } // GetModelPricing 获取模型价格配置 func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { // 标准化模型名称(转小写) model = strings.ToLower(model) // 1. 优先从动态价格服务获取 if s.pricingService != nil { litellmPricing := s.pricingService.GetModelPricing(model) if litellmPricing != nil { return &ModelPricing{ InputPricePerToken: litellmPricing.InputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, SupportsCacheBreakdown: false, }, nil } } // 2. 使用硬编码回退价格 fallback := s.getFallbackPricing(model) if fallback != nil { log.Printf("[Billing] Using fallback pricing for model: %s", model) return fallback, nil } return nil, fmt.Errorf("pricing not found for model: %s", model) } // CalculateCost 计算使用费用 func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { pricing, err := s.GetModelPricing(model) if err != nil { return nil, err } breakdown := &CostBreakdown{} // 计算输入token费用(使用per-token价格) breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken // 计算输出token费用 breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { // 支持详细缓存分类的模型(5分钟/1小时缓存) breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice } else { // 标准缓存创建价格(per-token) breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken } breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken // 计算总费用 breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + breakdown.CacheCreationCost + breakdown.CacheReadCost // 应用倍率计算实际费用 if rateMultiplier <= 0 { rateMultiplier = 1.0 } breakdown.ActualCost = breakdown.TotalCost * rateMultiplier return breakdown, nil } // CalculateCostWithConfig 使用配置中的默认倍率计算费用 func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) { multiplier := s.cfg.Default.RateMultiplier if multiplier <= 0 { multiplier = 1.0 } return s.CalculateCost(model, tokens, multiplier) } // CalculateCostWithLongContext 计算费用,支持长上下文双倍计费 // threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费 // extraMultiplier: 超出部分的倍率(如 2.0 表示双倍) // // 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0 // 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k) // 范围内正常计费,范围外 × 2 计费 func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) { // 未启用长上下文计费,直接走正常计费 if threshold <= 0 || extraMultiplier <= 1 { return s.CalculateCost(model, tokens, rateMultiplier) } // 计算总输入 token(缓存读取 + 新输入) total := tokens.CacheReadTokens + tokens.InputTokens if total <= threshold { return s.CalculateCost(model, tokens, rateMultiplier) } // 拆分成范围内和范围外 var inRangeCacheTokens, inRangeInputTokens int var outRangeCacheTokens, outRangeInputTokens int if tokens.CacheReadTokens >= threshold { // 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入 inRangeCacheTokens = threshold inRangeInputTokens = 0 outRangeCacheTokens = tokens.CacheReadTokens - threshold outRangeInputTokens = tokens.InputTokens } else { // 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入 inRangeCacheTokens = tokens.CacheReadTokens inRangeInputTokens = threshold - tokens.CacheReadTokens outRangeCacheTokens = 0 outRangeInputTokens = tokens.InputTokens - inRangeInputTokens } // 范围内部分:正常计费 inRangeTokens := UsageTokens{ InputTokens: inRangeInputTokens, OutputTokens: tokens.OutputTokens, // 输出只算一次 CacheCreationTokens: tokens.CacheCreationTokens, CacheReadTokens: inRangeCacheTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { return nil, err } // 范围外部分:× extraMultiplier 计费 outRangeTokens := UsageTokens{ InputTokens: outRangeInputTokens, CacheReadTokens: outRangeCacheTokens, } outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) if err != nil { return inRangeCost, nil // 出错时返回范围内成本 } // 合并成本 return &CostBreakdown{ InputCost: inRangeCost.InputCost + outRangeCost.InputCost, OutputCost: inRangeCost.OutputCost, CacheCreationCost: inRangeCost.CacheCreationCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost, }, nil } // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配) func (s *BillingService) ListSupportedModels() []string { models := make([]string, 0) // 返回回退价格支持的模型系列 for model := range s.fallbackPrices { models = append(models, model) } return models } // IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退) func (s *BillingService) IsModelSupported(model string) bool { // 所有Claude模型都有回退价格支持 modelLower := strings.ToLower(model) return strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "sonnet") || strings.Contains(modelLower, "haiku") } // GetEstimatedCost 估算费用(用于前端展示) func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) { tokens := UsageTokens{ InputTokens: estimatedInputTokens, OutputTokens: estimatedOutputTokens, } breakdown, err := s.CalculateCostWithConfig(model, tokens) if err != nil { return 0, err } return breakdown.ActualCost, nil } // GetPricingServiceStatus 获取价格服务状态 func (s *BillingService) GetPricingServiceStatus() map[string]any { if s.pricingService != nil { return s.pricingService.GetStatus() } return map[string]any{ "model_count": len(s.fallbackPrices), "last_updated": "using fallback", "local_hash": "N/A", } } // ForceUpdatePricing 强制更新价格数据 func (s *BillingService) ForceUpdatePricing() error { if s.pricingService != nil { return s.pricingService.ForceUpdate() } return fmt.Errorf("pricing service not initialized") } // ImagePriceConfig 图片计费配置 type ImagePriceConfig struct { Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值) Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值) Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" // imageCount: 生成的图片数量 // groupConfig: 分组配置的价格(可能为 nil,表示使用默认值) // rateMultiplier: 费率倍数 func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown { if imageCount <= 0 { return &CostBreakdown{} } // 获取单价 unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig) // 计算总费用 totalCost := unitPrice * float64(imageCount) // 应用倍率 if rateMultiplier <= 0 { rateMultiplier = 1.0 } actualCost := totalCost * rateMultiplier return &CostBreakdown{ TotalCost: totalCost, ActualCost: actualCost, } } // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 if groupConfig != nil { switch imageSize { case "1K": if groupConfig.Price1K != nil { return *groupConfig.Price1K } case "2K": if groupConfig.Price2K != nil { return *groupConfig.Price2K } case "4K": if groupConfig.Price4K != nil { return *groupConfig.Price4K } } } // 回退到 LiteLLM 默认价格 return s.getDefaultImagePrice(model, imageSize) } // getDefaultImagePrice 获取 LiteLLM 默认图片价格 func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 { basePrice := 0.0 // 从 PricingService 获取 output_cost_per_image if s.pricingService != nil { pricing := s.pricingService.GetModelPricing(model) if pricing != nil && pricing.OutputCostPerImage > 0 { basePrice = pricing.OutputCostPerImage } } // 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview) if basePrice <= 0 { basePrice = 0.134 } // 4K 尺寸翻倍 if imageSize == "4K" { return basePrice * 2 } return basePrice }