feat(billing): 添加 Gemini 200K 长上下文双倍计费功能
- 新增 CalculateCostWithLongContext 方法支持阈值双倍计费 - 新增 RecordUsageWithLongContext 方法专用于 Gemini 计费 - Gemini 超过 200K token 的部分按 2 倍费率计算 - 其他平台(Claude/OpenAI)完全不受影响
This commit is contained in:
@@ -366,18 +366,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 6) record usage async
|
||||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -241,6 +241,65 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||
// 1. 先正常计算全部 token 的成本
|
||||
cost, err := s.CalculateCost(model, tokens, rateMultiplier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 如果未启用长上下文计费或未超过阈值,直接返回
|
||||
if threshold <= 0 || extraMultiplier <= 1 {
|
||||
return cost, nil
|
||||
}
|
||||
|
||||
// 计算总输入 token(缓存读取 + 新输入)
|
||||
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||
if total <= threshold {
|
||||
return cost, nil
|
||||
}
|
||||
|
||||
// 3. 拆分超出部分的 token
|
||||
extra := total - threshold
|
||||
var extraCacheTokens, extraInputTokens int
|
||||
|
||||
if tokens.CacheReadTokens >= threshold {
|
||||
// 缓存已超过阈值:超出的缓存 + 全部输入
|
||||
extraCacheTokens = tokens.CacheReadTokens - threshold
|
||||
extraInputTokens = tokens.InputTokens
|
||||
} else {
|
||||
// 缓存未超过阈值:只有输入超出部分
|
||||
extraCacheTokens = 0
|
||||
extraInputTokens = extra
|
||||
}
|
||||
|
||||
// 4. 计算超出部分的成本(只算输入和缓存读取)
|
||||
extraTokens := UsageTokens{
|
||||
InputTokens: extraInputTokens,
|
||||
CacheReadTokens: extraCacheTokens,
|
||||
}
|
||||
extraCost, err := s.CalculateCost(model, extraTokens, 1.0) // 先按 1 倍算
|
||||
if err != nil {
|
||||
return cost, nil // 出错时返回正常成本
|
||||
}
|
||||
|
||||
// 5. 额外成本 = 超出部分成本 × (倍率 - 1)
|
||||
extraRate := extraMultiplier - 1
|
||||
additionalInputCost := extraCost.InputCost * extraRate
|
||||
additionalCacheCost := extraCost.CacheReadCost * extraRate
|
||||
|
||||
// 6. 累加到总成本
|
||||
cost.InputCost += additionalInputCost
|
||||
cost.CacheReadCost += additionalCacheCost
|
||||
cost.TotalCost += additionalInputCost + additionalCacheCost
|
||||
cost.ActualCost = cost.TotalCost * rateMultiplier
|
||||
|
||||
return cost, nil
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
|
||||
@@ -3606,6 +3606,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费)
|
||||
type RecordUsageLongContextInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = BillingTypeSubscription
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
if input.UserAgent != "" {
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
// 添加分组和订阅关联
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user