diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 32f83013..d1b19ede 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -366,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 6) record usage async + // 6) record usage async (Gemini 使用长上下文双倍计费) go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + + if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + LongContextThreshold: 200000, // Gemini 200K 阈值 + LongContextMultiplier: 2.0, // 超出部分双倍计费 }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f2afc343..db5a9708 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken return s.CalculateCost(model, tokens, multiplier) } +// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费 +// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费 +// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍) +// +// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0 +// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k) +// 范围内正常计费,范围外 × 2 计费 +func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) { + // 未启用长上下文计费,直接走正常计费 + if threshold <= 0 || extraMultiplier <= 1 { + return s.CalculateCost(model, tokens, rateMultiplier) + } + + // 计算总输入 token(缓存读取 + 新输入) + total := tokens.CacheReadTokens + tokens.InputTokens + if total <= threshold { + return s.CalculateCost(model, tokens, rateMultiplier) + } + + // 拆分成范围内和范围外 + var inRangeCacheTokens, inRangeInputTokens int + var outRangeCacheTokens, outRangeInputTokens int + + if tokens.CacheReadTokens >= threshold { + // 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入 + inRangeCacheTokens = threshold + inRangeInputTokens = 0 + outRangeCacheTokens = tokens.CacheReadTokens - threshold + outRangeInputTokens = tokens.InputTokens + } else { + // 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入 + inRangeCacheTokens = tokens.CacheReadTokens + inRangeInputTokens = threshold - tokens.CacheReadTokens + outRangeCacheTokens = 0 + outRangeInputTokens = tokens.InputTokens - inRangeInputTokens + } + + // 范围内部分:正常计费 + inRangeTokens := UsageTokens{ + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + } + inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) + if err != nil { + return nil, err + } + + // 范围外部分:× extraMultiplier 计费 + outRangeTokens := UsageTokens{ + InputTokens: outRangeInputTokens, + CacheReadTokens: outRangeCacheTokens, + } + outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) + if err != nil { + return inRangeCost, nil // 出错时返回范围内成本 + } + + // 合并成本 + return &CostBreakdown{ + InputCost: inRangeCost.InputCost + outRangeCost.InputCost, + OutputCost: inRangeCost.OutputCost, + CacheCreationCost: inRangeCost.CacheCreationCost, + CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, + TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, + ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost, + }, nil +} + // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配) func (s *BillingService) ListSupportedModels() []string { models := make([]string, 0) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 7a901907..9125163a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3606,6 +3606,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + result := input.Result + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 获取费率倍数 + multiplier := s.cfg.Default.RateMultiplier + if apiKey.GroupID != nil && apiKey.Group != nil { + multiplier = apiKey.Group.RateMultiplier + } + + var cost *CostBreakdown + + // 根据请求类型选择计费方式 + if result.ImageCount > 0 { + // 图片生成计费 + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + } else { + // Token 计费(使用长上下文计费方法) + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + } + var err error + cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + if err != nil { + log.Printf("Calculate cost failed: %v", err) + cost = &CostBreakdown{ActualCost: 0} + } + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + var imageSize *string + if result.ImageSize != "" { + imageSize = &result.ImageSize + } + accountRateMultiplier := account.BillingRateMultiplier() + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: result.RequestID, + Model: result.Model, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + ImageCount: result.ImageCount, + ImageSize: imageSize, + CreatedAt: time.Now(), + } + + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { + log.Printf("Create usage log failed: %v", err) + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + shouldBill := inserted || err != nil + + // 根据计费类型执行扣费 + if isSubscriptionBilling { + // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) + if shouldBill && cost.TotalCost > 0 { + if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { + log.Printf("Increment subscription usage failed: %v", err) + } + // 异步更新订阅缓存 + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + } + } else { + // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) + if shouldBill && cost.ActualCost > 0 { + if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { + log.Printf("Deduct balance failed: %v", err) + } + // 异步更新余额缓存 + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + } + } + + // Schedule batch update for account last_used_at + s.deferredService.ScheduleLastUsedUpdate(account.ID) + + return nil +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {