Files
new-api-oiss/service/quota.go
CaIon 116004fd44 refactor: 抽象统一计费会话 BillingSession
将散落在多个文件中的预扣费/结算/退款逻辑抽象为统一的 BillingSession 生命周期管理:

- 新增 BillingSettler 接口 (relay/common/billing.go) 避免循环引用
- 新增 FundingSource 接口 + WalletFunding / SubscriptionFunding 实现 (service/funding_source.go)
- 新增 BillingSession 封装预扣/结算/退款原子操作 (service/billing_session.go)
- 新增 SettleBilling 统一结算辅助函数,替换各 handler 中的 quotaDelta 模式
- 重写 PreConsumeBilling 为 BillingSession 工厂入口
- controller/relay.go 退款守卫改用 BillingSession.Refund()

修复的 Bug:
- 令牌额度泄漏:PreConsumeTokenQuota 成功但 DecreaseUserQuota 失败时未回滚
- 订阅退款遗漏:FinalPreConsumedQuota=0 但 SubscriptionPreConsumed>0 时跳过退款
- 订阅多扣费:subConsume 强制为 1 但 FinalPreConsumedQuota 不同步
- 退款路径不统一:钱包/订阅退款逻辑现统一由 FundingSource.Refund 分派
2026-02-06 23:14:25 +08:00

559 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"errors"
"fmt"
"log"
"math"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
)
type (
	// TokenDetails splits a token count into its text part and its audio part.
	// QuotaInfo uses one TokenDetails for the input side and one for the
	// output side of a request.
	TokenDetails struct {
		TextTokens  int
		AudioTokens int
	}

	// QuotaInfo carries everything calculateAudioQuota needs to price a
	// request: the per-side token breakdown, the model name used for ratio
	// lookups, and either a fixed price (UsePrice/ModelPrice) or a model
	// ratio, combined with a group ratio.
	QuotaInfo struct {
		InputDetails  TokenDetails
		OutputDetails TokenDetails
		ModelName     string
		UsePrice      bool
		ModelPrice    float64
		ModelRatio    float64
		GroupRatio    float64
	}
)
// hasCustomModelRatio reports whether currentRatio deviates from the built-in
// default model ratio for modelName. A model with no entry in the default
// ratio map is always treated as customized.
func hasCustomModelRatio(modelName string, currentRatio float64) bool {
	if defaultRatio, known := ratio_setting.GetDefaultModelRatioMap()[modelName]; known {
		return currentRatio != defaultRatio
	}
	return true
}
// calculateAudioQuota converts a text/audio token breakdown into a quota amount.
//
// Fixed pricing (info.UsePrice): ModelPrice * common.QuotaPerUnit * GroupRatio,
// truncated to its integer part.
//
// Ratio pricing: the token classes are weighted as
//
//	input text   × 1
//	output text  × completion ratio
//	input audio  × audio ratio
//	output audio × audio ratio × audio completion ratio
//
// and the sum is multiplied by GroupRatio * ModelRatio. The per-model ratios
// are looked up by info.ModelName. When that combined ratio is non-zero and the
// total is <= 0, the total is set to 1. The result is then rounded with
// decimal.Round(0).
//
// NOTE(review): the minimum-of-1 guard runs before rounding, so a positive
// total below 0.5 still rounds to 0 — confirm this is intended.
func calculateAudioQuota(info QuotaInfo) int {
	groupRatio := decimal.NewFromFloat(info.GroupRatio)

	if info.UsePrice {
		perCall := decimal.NewFromFloat(info.ModelPrice).
			Mul(decimal.NewFromFloat(common.QuotaPerUnit)).
			Mul(groupRatio)
		return int(perCall.IntPart())
	}

	name := info.ModelName
	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(name))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(name))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(name))
	combinedRatio := groupRatio.Mul(decimal.NewFromFloat(info.ModelRatio))

	asDecimal := func(n int) decimal.Decimal { return decimal.NewFromInt(int64(n)) }

	weighted := asDecimal(info.InputDetails.TextTokens).
		Add(asDecimal(info.OutputDetails.TextTokens).Mul(completionRatio)).
		Add(asDecimal(info.InputDetails.AudioTokens).Mul(audioRatio)).
		Add(asDecimal(info.OutputDetails.AudioTokens).Mul(audioRatio).Mul(audioCompletionRatio))
	total := weighted.Mul(combinedRatio)

	// A priced (non-zero ratio) request never produces a non-positive quota.
	if !combinedRatio.IsZero() && !total.IsPositive() {
		total = decimal.NewFromInt(1)
	}
	return int(total.Round(0).IntPart())
}
// PreWssConsumeQuota charges the quota for one realtime (WebSocket) usage
// update right away, during the session.
//
// Fixed-price requests (relayInfo.UsePrice) are a no-op. Otherwise the
// audio-aware quota for usage is computed with calculateAudioQuota. The
// function then checks that the user's wallet balance and the token's
// remaining quota (unless the token is unlimited) both cover it, and charges
// it through PostConsumeQuota.
//
// Group ratio selection: if the auto-group context key holds a group name,
// that group's ratio is used and relayInfo.UsingGroup is switched to it. A
// user-group specific ratio (GetGroupGroupRatio) then overrides the group
// ratio when one is configured.
//
// It returns an error when a lookup fails, when either balance is
// insufficient, or when PostConsumeQuota fails.
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
	if relayInfo.UsePrice {
		return nil
	}
	// NOTE(review): this balance check always uses the wallet, even though
	// PostConsumeQuota charges a subscription when BillingSource is
	// subscription — confirm that is intended.
	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
	if err != nil {
		return err
	}
	// NOTE(review): the key is looked up here with an "sk-" prefix stripped,
	// while PreConsumeTokenQuota passes relayInfo.TokenKey unchanged —
	// confirm which form TokenKey is stored in.
	token, err := model.GetTokenByKey(strings.TrimPrefix(relayInfo.TokenKey, "sk-"), false)
	if err != nil {
		return err
	}
	modelName := relayInfo.OriginModelName
	textInputTokens := usage.InputTokenDetails.TextTokens
	textOutTokens := usage.OutputTokenDetails.TextTokens
	audioInputTokens := usage.InputTokenDetails.AudioTokens
	audioOutTokens := usage.OutputTokenDetails.AudioTokens
	groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
	modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
	if autoGroup, exists := common.GetContextKey(ctx, constant.ContextKeyAutoGroup); exists {
		// Checked assertion: an unexpected value type in the context must not
		// panic the request; keep the current group in that case.
		if groupName, isString := autoGroup.(string); isString {
			groupRatio = ratio_setting.GetGroupRatio(groupName)
			log.Printf("final group ratio: %f", groupRatio)
			relayInfo.UsingGroup = groupName
		} else {
			logger.LogError(ctx, fmt.Sprintf("unexpected auto group type %T, keeping group %s", autoGroup, relayInfo.UsingGroup))
		}
	}
	actualGroupRatio := groupRatio
	userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
	if ok {
		actualGroupRatio = userGroupRatio
	}
	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  textInputTokens,
			AudioTokens: audioInputTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  textOutTokens,
			AudioTokens: audioOutTokens,
		},
		ModelName:  modelName,
		UsePrice:   relayInfo.UsePrice,
		ModelRatio: modelRatio,
		GroupRatio: actualGroupRatio,
	}
	quota := calculateAudioQuota(quotaInfo)
	if userQuota < quota {
		return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
	}
	if !token.UnlimitedQuota && token.RemainQuota < quota {
		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
	}
	err = PostConsumeQuota(relayInfo, quota, 0, false)
	if err != nil {
		return err
	}
	logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
	return nil
}
// PostWssConsumeQuota records the final consume log for a realtime (WebSocket)
// session and updates the user's and channel's used-quota statistics.
//
// The quota is recomputed from usage with calculateAudioQuota. It uses the
// model ratio, group ratio and price stored in relayInfo.PriceData and ratios
// looked up by modelName. When usage.TotalTokens is 0, the quota is forced to
// 0, an error is logged, the log entry is marked as a possible upstream
// timeout, and the used-quota counters are not updated.
//
// This function does not call PostConsumeQuota or SettleBilling. extraContent,
// when non-empty, is appended to the log content.
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
	usage *dto.RealtimeUsage, extraContent string) {
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	textInputTokens := usage.InputTokenDetails.TextTokens
	textOutTokens := usage.OutputTokenDetails.TextTokens
	audioInputTokens := usage.InputTokenDetails.AudioTokens
	audioOutTokens := usage.OutputTokenDetails.AudioTokens
	tokenName := ctx.GetString("token_name")
	// All three ratios are looked up by modelName, the same name
	// calculateAudioQuota uses (via quotaInfo.ModelName). The ratios written to
	// the log therefore match the ones that were actually billed.
	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(modelName))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	usePrice := relayInfo.PriceData.UsePrice
	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  textInputTokens,
			AudioTokens: audioInputTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  textOutTokens,
			AudioTokens: audioOutTokens,
		},
		ModelName:  modelName,
		UsePrice:   usePrice,
		ModelRatio: modelRatio,
		GroupRatio: groupRatio,
	}
	quota := calculateAudioQuota(quotaInfo)
	totalTokens := usage.TotalTokens
	var logContent string
	if !usePrice {
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
	} else {
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}
	// record all the consume log even if quota is 0
	if totalTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
		logContent += "(可能是上游超时)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}
	logModel := modelName
	if extraContent != "" {
		logContent += ", " + extraContent
	}
	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        relayInfo.ChannelId,
		PromptTokens:     usage.InputTokens,
		CompletionTokens: usage.OutputTokens,
		ModelName:        logModel,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}
// PostClaudeConsumeQuota computes the final quota for a Claude-style request
// with prompt-cache read/creation pricing. It then passes the quota to
// SettleBilling, updates used-quota statistics and records a consume log.
//
// Ratio pricing:
//
//	(prompt + cacheRead*cacheRatio + create5m*ratio5m + create1h*ratio1h
//	 + remainingCreate*cacheCreationRatio + completion*completionRatio)
//	* groupRatio * modelRatio
//
// Here remainingCreate is the cache-creation tokens not attributed to the 5m
// or 1h buckets. Fixed pricing (UsePrice) is modelPrice * common.QuotaPerUnit
// * groupRatio. If modelRatio is non-zero and the result is <= 0, it is raised
// to 1. The value is then truncated to int.
//
// When prompt+completion tokens is 0, the quota is forced to 0 (so
// SettleBilling receives 0) and an error is logged. Settlement errors are
// logged and do not prevent the consume log from being written.
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	modelName := relayInfo.OriginModelName
	tokenName := ctx.GetString("token_name")
	completionRatio := relayInfo.PriceData.CompletionRatio
	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	cacheRatio := relayInfo.PriceData.CacheRatio
	cacheTokens := usage.PromptTokensDetails.CachedTokens
	cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
	cacheCreationRatio5m := relayInfo.PriceData.CacheCreation5mRatio
	cacheCreationRatio1h := relayInfo.PriceData.CacheCreation1hRatio
	cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
	cacheCreationTokens5m := usage.ClaudeCacheCreation5mTokens
	cacheCreationTokens1h := usage.ClaudeCacheCreation1hTokens
	if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
		// Cache read (and, below, cache creation) tokens are removed from the
		// prompt count so they are billed only at their cache ratios
		// (presumably OpenRouter counts them inside prompt_tokens — confirm).
		promptTokens -= cacheTokens
		isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
		// usage.Cost is an interface value. Compare the extracted float64:
		// `usage.Cost != 0` compares against int(0), so it is also true for a
		// float64 zero cost or a nil cost.
		reportedCost, _ := usage.Cost.(float64)
		if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && reportedCost != 0 && !isUsingCustomSettings {
			// No cache-creation count was reported: try to back it out of the
			// reported cost, accepting only a plausible estimate.
			maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
			if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
				cacheCreationTokens = maybeCacheCreationTokens
			}
		}
		promptTokens -= cacheCreationTokens
	}
	calculateQuota := 0.0
	if !relayInfo.PriceData.UsePrice {
		calculateQuota = float64(promptTokens)
		calculateQuota += float64(cacheTokens) * cacheRatio
		calculateQuota += float64(cacheCreationTokens5m) * cacheCreationRatio5m
		calculateQuota += float64(cacheCreationTokens1h) * cacheCreationRatio1h
		// Creation tokens outside the 5m/1h buckets use the generic creation ratio.
		remainingCacheCreationTokens := cacheCreationTokens - cacheCreationTokens5m - cacheCreationTokens1h
		if remainingCacheCreationTokens > 0 {
			calculateQuota += float64(remainingCacheCreationTokens) * cacheCreationRatio
		}
		calculateQuota += float64(completionTokens) * completionRatio
		calculateQuota = calculateQuota * groupRatio * modelRatio
	} else {
		calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
	}
	if modelRatio != 0 && calculateQuota <= 0 {
		calculateQuota = 1
	}
	quota := int(calculateQuota)
	totalTokens := promptTokens + completionTokens
	var logContent string
	// record all the consume log even if quota is 0
	if totalTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
		logContent += "(可能是上游出错)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}
	if err := SettleBilling(ctx, relayInfo, quota); err != nil {
		logger.LogError(ctx, "error settling billing: "+err.Error())
	}
	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
		cacheTokens, cacheRatio,
		cacheCreationTokens, cacheCreationRatio,
		cacheCreationTokens5m, cacheCreationRatio5m,
		cacheCreationTokens1h, cacheCreationRatio1h,
		modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        relayInfo.ChannelId,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		ModelName:        modelName,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}
// CalcOpenRouterCacheCreateTokens estimates the number of prompt-cache
// creation tokens for an OpenRouter request from the reported usage.Cost.
//
// Let q = ModelRatio / common.QuotaPerUnit be the base per-token price. The
// function solves
//
//	cost = (P - R - C)*q + R*q*CacheRatio + C*q*CacheCreationRatio + K*q*CompletionRatio
//
// for C. P is usage.PromptTokens, R is the cached (read) prompt tokens and K
// is the completion tokens. This assumes P includes both R and C — TODO
// confirm against OpenRouter's usage semantics. A usage.Cost that is not a
// float64 is treated as 0.
//
// It returns 0 in three cases:
//   - CacheCreationRatio is 1, so C cannot be isolated.
//   - The equation is degenerate, e.g. ModelRatio is 0.
//   - The estimate is not a finite value representable as an int.
//
// The estimate may still be negative; callers must sanity-check it.
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
	if priceData.CacheCreationRatio == 1 {
		return 0
	}
	quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
	promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
	promptCacheReadPrice := quotaPrice * priceData.CacheRatio
	completionPrice := quotaPrice * priceData.CompletionRatio
	cost, _ := usage.Cost.(float64)
	totalPromptTokens := float64(usage.PromptTokens)
	completionTokens := float64(usage.CompletionTokens)
	promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
	denominator := promptCacheCreatePrice - quotaPrice
	if denominator == 0 {
		return 0
	}
	estimate := math.Round((cost -
		totalPromptTokens*quotaPrice +
		promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
		completionTokens*completionPrice) /
		denominator)
	// Converting NaN, ±Inf or an out-of-range float to int is
	// implementation-dependent in Go, so reject such estimates explicitly.
	if math.IsNaN(estimate) || estimate >= math.MaxInt || estimate <= math.MinInt {
		return 0
	}
	return int(estimate)
}
// PostAudioConsumeQuota computes the final quota for an audio request from its
// text/audio token breakdown. It then passes the quota to SettleBilling,
// updates used-quota statistics and records a consume log.
//
// The completion, audio and audio-completion ratios are looked up by
// relayInfo.OriginModelName. The model ratio, group ratio and price come from
// relayInfo.PriceData.
//
// When usage.TotalTokens is 0, the quota is forced to 0 (so SettleBilling
// receives 0), an error is logged and the log entry is marked as a possible
// upstream timeout. Settlement errors are logged and do not prevent the
// consume log from being written. extraContent, when non-empty, is appended to
// the log content.
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	textInputTokens := usage.PromptTokensDetails.TextTokens
	textOutTokens := usage.CompletionTokenDetails.TextTokens
	audioInputTokens := usage.PromptTokensDetails.AudioTokens
	audioOutTokens := usage.CompletionTokenDetails.AudioTokens
	tokenName := ctx.GetString("token_name")
	completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	usePrice := relayInfo.PriceData.UsePrice
	quotaInfo := QuotaInfo{
		InputDetails: TokenDetails{
			TextTokens:  textInputTokens,
			AudioTokens: audioInputTokens,
		},
		OutputDetails: TokenDetails{
			TextTokens:  textOutTokens,
			AudioTokens: audioOutTokens,
		},
		ModelName:  relayInfo.OriginModelName,
		UsePrice:   usePrice,
		ModelRatio: modelRatio,
		GroupRatio: groupRatio,
	}
	quota := calculateAudioQuota(quotaInfo)
	totalTokens := usage.TotalTokens
	var logContent string
	if !usePrice {
		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
			modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
	} else {
		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
	}
	// record all the consume log even if quota is 0
	if totalTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
		logContent += "(可能是上游超时)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
	} else {
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}
	if err := SettleBilling(ctx, relayInfo, quota); err != nil {
		logger.LogError(ctx, "error settling billing: "+err.Error())
	}
	logModel := relayInfo.OriginModelName
	if extraContent != "" {
		logContent += ", " + extraContent
	}
	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        relayInfo.ChannelId,
		PromptTokens:     usage.PromptTokens,
		CompletionTokens: usage.CompletionTokens,
		ModelName:        logModel,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}
// PreConsumeTokenQuota deducts quota from the request's API token before the
// upstream call is made.
//
// Negative amounts are rejected and playground requests are skipped. The token
// is always looked up. Unless relayInfo.TokenUnlimited is set, its remaining
// quota must cover the amount. The amount is then deducted with
// model.DecreaseTokenQuota; unlimited tokens are deducted as well.
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
	switch {
	case quota < 0:
		return errors.New("quota 不能为负数!")
	case relayInfo.IsPlayground:
		return nil
	}

	token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
	if err != nil {
		return err
	}
	if covered := relayInfo.TokenUnlimited || token.RemainQuota >= quota; !covered {
		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
	}
	return model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
}
// PostConsumeQuota applies a quota delta for a request to its funding source
// and then to the request's API token.
//
// A positive quota charges additional quota; a negative quota gives quota back.
//
// When relayInfo.BillingSource is BillingSourceSubscription, the delta goes to
// relayInfo.SubscriptionId via model.PostConsumeUserSubscriptionDelta (skipped
// for a zero delta) and is accumulated in relayInfo.SubscriptionPostDelta.
// Otherwise it is applied to the user's wallet.
//
// Unless the request comes from the playground, the same delta is then applied
// to the token. If sendEmail is set and quota+preConsumedQuota is non-zero, a
// low-quota notification check runs.
//
// It returns an error if relayInfo is nil, if a subscription-billed request has
// no subscription id, or if any underlying update fails.
//
// NOTE(review): the funding-source update and the token update are not atomic;
// a failed token update leaves the wallet/subscription change in place.
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
	// Every step below dereferences relayInfo. Reject nil up front instead of
	// panicking in the wallet or token step.
	if relayInfo == nil {
		return errors.New("relay info is nil")
	}
	// 1) Consume from wallet quota OR subscription item
	if relayInfo.BillingSource == BillingSourceSubscription {
		if relayInfo.SubscriptionId == 0 {
			return errors.New("subscription id is missing")
		}
		delta := int64(quota)
		if delta != 0 {
			if err := model.PostConsumeUserSubscriptionDelta(relayInfo.SubscriptionId, delta); err != nil {
				return err
			}
			relayInfo.SubscriptionPostDelta += delta
		}
	} else {
		// Wallet
		if quota > 0 {
			err = model.DecreaseUserQuota(relayInfo.UserId, quota)
		} else {
			err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
		}
		if err != nil {
			return err
		}
	}
	// 2) Apply the same delta to the API token.
	if !relayInfo.IsPlayground {
		if quota > 0 {
			err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
		} else {
			err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
		}
		if err != nil {
			return err
		}
	}
	if sendEmail {
		if (quota + preConsumedQuota) != 0 {
			checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
		}
	}
	return nil
}
// checkAndSendQuotaNotify notifies the user when their balance after this
// request would fall below their warning threshold.
//
// The threshold is the user's QuotaWarningThreshold when non-zero, otherwise
// common.QuotaRemindThreshold. The remaining balance is estimated as
// relayInfo.UserQuota - (quota + preConsumedQuota). That estimate is both the
// trigger condition and the value shown in the message.
//
// The message format follows the user's NotifyType: short text for Bark and
// Gotify, and HTML with a top-up link otherwise. An empty NotifyType defaults
// to email. The notification is sent asynchronously via gopool, and send
// failures are logged with common.SysError.
func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
	// Snapshot what the goroutine needs so it does not read relayInfo after
	// the request path has moved on, and decide synchronously whether a
	// notification is needed at all.
	userID := relayInfo.UserId
	userEmail := relayInfo.UserEmail
	userSetting := relayInfo.UserSetting

	threshold := common.QuotaRemindThreshold
	if userSetting.QuotaWarningThreshold != 0 {
		threshold = int(userSetting.QuotaWarningThreshold)
	}
	consumeQuota := quota + preConsumedQuota
	remainQuota := relayInfo.UserQuota - consumeQuota
	if remainQuota >= threshold {
		return
	}

	gopool.Go(func() {
		prompt := "您的额度即将用尽"
		topUpLink := fmt.Sprintf("%s/console/topup", system_setting.ServerAddress)
		remain := logger.FormatQuota(remainQuota)

		notifyType := userSetting.NotifyType
		if notifyType == "" {
			notifyType = dto.NotifyTypeEmail
		}
		// Bark and Gotify get short plain text. Everything else (email,
		// webhook) gets the HTML format with a top-up link.
		var content string
		var values []interface{}
		switch notifyType {
		case dto.NotifyTypeBark:
			content = "{{value}},剩余额度:{{value}},请及时充值"
			values = []interface{}{prompt, remain}
		case dto.NotifyTypeGotify:
			content = "{{value}},当前剩余额度为 {{value}},请及时充值。"
			values = []interface{}{prompt, remain}
		default:
			content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
			values = []interface{}{prompt, remain, topUpLink, topUpLink}
		}
		err := NotifyUser(userID, userEmail, userSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values))
		if err != nil {
			common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userID, err.Error()))
		}
	})
}