diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 32f83013..d1b19ede 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -366,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 6) record usage async + // 6) record usage async (Gemini 使用长上下文双倍计费) go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + + if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + LongContextThreshold: 200000, // Gemini 200K 阈值 + LongContextMultiplier: 2.0, // 超出部分双倍计费 }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f2afc343..95e16c4e 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -241,6 +241,65 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken return s.CalculateCost(model, tokens, multiplier) } +// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费 +// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费 +// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍) +func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) { + // 1. 先正常计算全部 token 的成本 + cost, err := s.CalculateCost(model, tokens, rateMultiplier) + if err != nil { + return nil, err + } + + // 2. 如果未启用长上下文计费或未超过阈值,直接返回 + if threshold <= 0 || extraMultiplier <= 1 { + return cost, nil + } + + // 计算总输入 token(缓存读取 + 新输入) + total := tokens.CacheReadTokens + tokens.InputTokens + if total <= threshold { + return cost, nil + } + + // 3. 拆分超出部分的 token + extra := total - threshold + var extraCacheTokens, extraInputTokens int + + if tokens.CacheReadTokens >= threshold { + // 缓存已超过阈值:超出的缓存 + 全部输入 + extraCacheTokens = tokens.CacheReadTokens - threshold + extraInputTokens = tokens.InputTokens + } else { + // 缓存未超过阈值:只有输入超出部分 + extraCacheTokens = 0 + extraInputTokens = extra + } + + // 4. 计算超出部分的成本(只算输入和缓存读取) + extraTokens := UsageTokens{ + InputTokens: extraInputTokens, + CacheReadTokens: extraCacheTokens, + } + extraCost, err := s.CalculateCost(model, extraTokens, 1.0) // 先按 1 倍算 + if err != nil { + return cost, nil // 出错时返回正常成本 + } + + // 5. 额外成本 = 超出部分成本 × (倍率 - 1) + extraRate := extraMultiplier - 1 + additionalInputCost := extraCost.InputCost * extraRate + additionalCacheCost := extraCost.CacheReadCost * extraRate + + // 6. 累加到总成本 + cost.InputCost += additionalInputCost + cost.CacheReadCost += additionalCacheCost + cost.TotalCost += additionalInputCost + additionalCacheCost + cost.ActualCost = cost.TotalCost * rateMultiplier + + return cost, nil +} + // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配) func (s *BillingService) ListSupportedModels() []string { models := make([]string, 0) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 7a901907..9125163a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3606,6 +3606,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + result := input.Result + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 获取费率倍数 + multiplier := s.cfg.Default.RateMultiplier + if apiKey.GroupID != nil && apiKey.Group != nil { + multiplier = apiKey.Group.RateMultiplier + } + + var cost *CostBreakdown + + // 根据请求类型选择计费方式 + if result.ImageCount > 0 { + // 图片生成计费 + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + } else { + // Token 计费(使用长上下文计费方法) + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + } + var err error + cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + if err != nil { + log.Printf("Calculate cost failed: %v", err) + cost = &CostBreakdown{ActualCost: 0} + } + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + var imageSize *string + if result.ImageSize != "" { + imageSize = &result.ImageSize + } + accountRateMultiplier := account.BillingRateMultiplier() + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: result.RequestID, + Model: result.Model, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + ImageCount: result.ImageCount, + ImageSize: imageSize, + CreatedAt: time.Now(), + } + + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { + log.Printf("Create usage log failed: %v", err) + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + shouldBill := inserted || err != nil + + // 根据计费类型执行扣费 + if isSubscriptionBilling { + // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) + if shouldBill && cost.TotalCost > 0 { + if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { + log.Printf("Increment subscription usage failed: %v", err) + } + // 异步更新订阅缓存 + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + } + } else { + // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) + if shouldBill && cost.ActualCost > 0 { + if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { + log.Printf("Deduct balance failed: %v", err) + } + // 异步更新余额缓存 + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + } + } + + // Schedule batch update for account last_used_at + s.deferredService.ScheduleLastUsedUpdate(account.ID) + + return nil +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {