refactor: 抽象统一计费会话 BillingSession

将散落在多个文件中的预扣费/结算/退款逻辑抽象为统一的 BillingSession 生命周期管理:

- 新增 BillingSettler 接口 (relay/common/billing.go) 避免循环引用
- 新增 FundingSource 接口 + WalletFunding / SubscriptionFunding 实现 (service/funding_source.go)
- 新增 BillingSession 封装预扣/结算/退款原子操作 (service/billing_session.go)
- 新增 SettleBilling 统一结算辅助函数,替换各 handler 中的 quotaDelta 模式
- 重写 PreConsumeBilling 为 BillingSession 工厂入口
- controller/relay.go 退款守卫改用 BillingSession.Refund()

修复的 Bug:
- 令牌额度泄漏:PreConsumeTokenQuota 成功但 DecreaseUserQuota 失败时未回滚
- 订阅退款遗漏:FinalPreConsumedQuota=0 但 SubscriptionPreConsumed>0 时跳过退款
- 订阅多扣费:subConsume 强制为 1 但 FinalPreConsumedQuota 不同步
- 退款路径不统一:钱包/订阅退款逻辑现统一由 FundingSource.Refund 分派
This commit is contained in:
CaIon
2026-02-06 23:14:25 +08:00
parent d814d62e2f
commit 0c0ccf510b
9 changed files with 545 additions and 278 deletions

321
service/billing_session.go Normal file
View File

@@ -0,0 +1,321 @@
package service
import (
"fmt"
"net/http"
"strings"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
// ---------------------------------------------------------------------------
// BillingSession — 统一计费会话
// ---------------------------------------------------------------------------
// BillingSession wraps the pre-consume / settle / refund lifecycle of a single request.
// Per the original note it implements the relaycommon.BillingSettler interface
// (that interface is declared outside this file).
type BillingSession struct {
relayInfo *relaycommon.RelayInfo
funding FundingSource
preConsumedQuota int // quota actually pre-consumed (may be 0 for trusted users)
tokenConsumed int // amount actually deducted from the token quota
settled bool // set by Settle; once true, Settle and Refund become no-ops
refunded bool // set by Refund; once true, further Refund calls are no-ops
mu sync.Mutex // held by Settle, Refund and NeedsRefund while they inspect or change the flags
}
// Settle reconciles the pre-consumed quota with the quota the request actually
// used.
//
// delta = actualQuota - preConsumedQuota is applied to the funding source first
// and to the token quota second (the token step is skipped for playground
// requests). Once a settlement has been committed, further calls are no-ops
// that return nil.
//
// Error handling:
//   - If the funding source rejects the adjustment nothing has been committed;
//     the error is returned and the session stays unsettled, so Refund can
//     still release the pre-consumed quota.
//   - If the funding source accepted the adjustment but the token quota update
//     fails, the session is nevertheless marked settled before the error is
//     returned. The funding source has already been moved by delta at that
//     point, so leaving the session unsettled would let a later Refund run on
//     top of the committed settlement and would let a repeated Settle apply
//     delta a second time.
func (s *BillingSession) Settle(actualQuota int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.settled {
		return nil
	}

	delta := actualQuota - s.preConsumedQuota
	if delta == 0 {
		s.settled = true
		return nil
	}

	// 1) Adjust the funding source.
	if err := s.funding.Settle(delta); err != nil {
		return err
	}

	// The funding source is committed from here on: lock the session against
	// Refund / repeated Settle and record the subscription post-delta (used
	// for logging) regardless of how the token adjustment below turns out.
	s.settled = true
	if s.funding.Source() == BillingSourceSubscription {
		s.relayInfo.SubscriptionPostDelta += int64(delta)
	}

	// 2) Adjust the token quota by the same delta.
	if s.relayInfo.IsPlayground {
		return nil
	}
	if delta > 0 {
		return model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta)
	}
	return model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta)
}
// Refund gives back everything that was pre-consumed for this request.
//
// It is idempotent: the first call that finds refundable state flips the
// refunded flag under the mutex, and every later call (or any call after a
// settlement) returns immediately. The actual refund work runs asynchronously
// via gopool.Go; failures there are only logged.
func (s *BillingSession) Refund(c *gin.Context) {
	// Atomically decide whether this call owns the refund.
	// needsRefundLocked already reports false for settled/refunded sessions.
	claimed := func() bool {
		s.mu.Lock()
		defer s.mu.Unlock()
		if !s.needsRefundLocked() {
			return false
		}
		s.refunded = true
		return true
	}()
	if !claimed {
		return
	}

	// Snapshot everything the goroutine needs so it never touches the session.
	var (
		funding       = s.funding
		tokenConsumed = s.tokenConsumed
		tokenId       = s.relayInfo.TokenId
		tokenKey      = s.relayInfo.TokenKey
		isPlayground  = s.relayInfo.IsPlayground
	)

	msg := fmt.Sprintf("用户 %d 请求失败, 返还预扣费token_quota=%s, funding=%s",
		s.relayInfo.UserId,
		logger.FormatQuota(tokenConsumed),
		funding.Source(),
	)
	logger.LogInfo(c, msg)

	gopool.Go(func() {
		// 1) Refund the funding source.
		if err := funding.Refund(); err != nil {
			common.SysLog("error refunding billing source: " + err.Error())
		}

		// 2) Refund the token quota; nothing to do when none was taken or
		// for playground requests.
		if isPlayground || tokenConsumed <= 0 {
			return
		}
		if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
			common.SysLog("error refunding token quota: " + err.Error())
		}
	})
}
// NeedsRefund reports whether the session still holds pre-consumed state that
// has been neither settled nor refunded. It takes the session mutex.
func (s *BillingSession) NeedsRefund() bool {
	s.mu.Lock()
	pending := s.needsRefundLocked()
	s.mu.Unlock()
	return pending
}
// needsRefundLocked is the body of NeedsRefund without locking; callers in this
// file invoke it while holding s.mu.
func (s *BillingSession) needsRefundLocked() bool {
	switch {
	case s.settled, s.refunded:
		// Already finalized one way or the other.
		return false
	case s.tokenConsumed > 0:
		return true
	}

	// A subscription may have pre-consumed quota even when tokenConsumed is 0.
	sub, isSubscription := s.funding.(*SubscriptionFunding)
	return isSubscription && sub.preConsumed > 0
}
// GetPreConsumedQuota returns the quota that was actually pre-consumed for this
// session. The field is read directly, without taking the session mutex.
func (s *BillingSession) GetPreConsumedQuota() int {
	preConsumed := s.preConsumedQuota
	return preConsumed
}
// ---------------------------------------------------------------------------
// PreConsume — 统一预扣费入口(含信任额度旁路)
// ---------------------------------------------------------------------------
// preConsume performs the pre-consumption for this session in three stages:
// trust check -> token quota pre-consume -> funding source pre-consume.
//
// When shouldTrust reports true the effective quota becomes 0, so no token
// quota is taken and the funding source is asked to pre-consume 0. If the
// funding source fails after token quota was taken, the token quota is given
// back before the error is returned; a failure of that rollback is logged
// rather than silently dropped, since it means token quota has leaked.
//
// On success the pre-consumed amount is stored on the session and mirrored
// onto the RelayInfo compatibility fields via syncRelayInfo.
//
// Returned errors:
//   - ErrorCodePreConsumeTokenQuotaFailed (403) when the token pre-consume fails;
//   - ErrorCodeInsufficientUserQuota (403) when the funding error message
//     indicates missing/insufficient subscription or wallet quota;
//   - ErrorCodeUpdateDataError for any other funding error.
func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError {
	effectiveQuota := quota

	// ---- Trust bypass ----
	if s.shouldTrust(c) {
		effectiveQuota = 0
		logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
	} else if effectiveQuota > 0 {
		logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source()))
	}

	// ---- 1) Pre-consume token quota ----
	if effectiveQuota > 0 {
		if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		}
		s.tokenConsumed = effectiveQuota
	}

	// ---- 2) Pre-consume from the funding source ----
	if err := s.funding.PreConsume(effectiveQuota); err != nil {
		// Roll back the token quota taken in step 1.
		if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground {
			if rollbackErr := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed); rollbackErr != nil {
				// The request is failing anyway; make the leaked quota visible in the logs.
				common.SysLog("error rolling back token quota: " + rollbackErr.Error())
			}
			s.tokenConsumed = 0
		}

		// The funding sources report their failure reason only as text, so the
		// error is classified by matching on the message.
		errMsg := err.Error()
		switch {
		case strings.Contains(errMsg, "no active subscription"),
			strings.Contains(errMsg, "subscription quota insufficient"):
			return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		case strings.Contains(errMsg, "用户额度不足"),
			strings.Contains(errMsg, "预扣费额度失败"):
			return types.NewErrorWithStatusCode(err, types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		default:
			return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
		}
	}

	s.preConsumedQuota = effectiveQuota

	// ---- Mirror the result onto the RelayInfo compatibility fields ----
	s.syncRelayInfo()
	return nil
}
// shouldTrust decides whether pre-consumption can be skipped for this request.
//
// All of the following must hold:
//   - the configured trust quota (common.GetTrustQuota) is positive;
//   - the token is unlimited, or the "token_quota" value on the gin context
//     exceeds the trust quota;
//   - the funding source is the wallet and the user's quota exceeds the trust
//     quota.
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
	threshold := common.GetTrustQuota()
	if threshold <= 0 {
		return false
	}

	// Token side: unlimited tokens always pass, otherwise the token quota on
	// the context has to be above the threshold.
	if !s.relayInfo.TokenUnlimited && c.GetInt("token_quota") <= threshold {
		return false
	}

	// Funding side: only the wallet supports the bypass. Subscriptions do not
	// yet (their remaining quota would need an extra query and pre-consuming
	// is cheap); subscription trust logic could be added here later.
	if s.funding.Source() != BillingSourceWallet {
		return false
	}
	return s.relayInfo.UserQuota > threshold
}
// syncRelayInfo copies the session's billing state onto the compatibility
// fields of its RelayInfo.
//
// FinalPreConsumedQuota and BillingSource are always written. For any funding
// source other than *SubscriptionFunding only SubscriptionId and
// SubscriptionPreConsumed are reset to 0; for a subscription all
// subscription-related fields are filled in and SubscriptionPostDelta is
// reset to 0.
func (s *BillingSession) syncRelayInfo() {
	ri := s.relayInfo
	ri.FinalPreConsumedQuota = s.preConsumedQuota
	ri.BillingSource = s.funding.Source()

	sub, isSubscription := s.funding.(*SubscriptionFunding)
	if !isSubscription {
		ri.SubscriptionId = 0
		ri.SubscriptionPreConsumed = 0
		return
	}

	ri.SubscriptionId = sub.subscriptionId
	ri.SubscriptionPreConsumed = sub.preConsumed
	ri.SubscriptionPostDelta = 0
	ri.SubscriptionAmountTotal = sub.AmountTotal
	ri.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter
	ri.SubscriptionPlanId = sub.PlanId
	ri.SubscriptionPlanTitle = sub.PlanTitle
}
// ---------------------------------------------------------------------------
// NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退
// ---------------------------------------------------------------------------
// NewBillingSession creates a BillingSession according to the user's billing
// preference and performs the pre-consumption for it.
//
// Preferences (after common.NormalizeBillingPreference):
//   - "subscription_only": subscription funding, no fallback;
//   - "wallet_only": wallet funding, no fallback;
//   - "wallet_first": wallet funding, falling back to the subscription when the
//     wallet attempt fails with ErrorCodeInsufficientUserQuota;
//   - "subscription_first" and any other value: subscription funding, falling
//     back to the wallet on ErrorCodeInsufficientUserQuota.
//
// Any other error from the first attempt is returned as is. A nil relayInfo
// yields ErrorCodeInvalidRequest.
func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) {
	if relayInfo == nil {
		return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	type sessionFactory func() (*BillingSession, *types.NewAPIError)

	// start runs the pre-consumption on a freshly built session and hands the
	// session out only when that succeeded.
	start := func(session *BillingSession) (*BillingSession, *types.NewAPIError) {
		if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
			return nil, apiErr
		}
		return session, nil
	}

	// Wallet path: the user's quota is checked up front before anything is
	// pre-consumed.
	fromWallet := func() (*BillingSession, *types.NewAPIError) {
		balance, err := model.GetUserQuota(relayInfo.UserId, false)
		switch {
		case err != nil:
			return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
		case balance <= 0:
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(balance)),
				types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
				types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		case balance-preConsumedQuota < 0:
			return nil, types.NewErrorWithStatusCode(
				fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(balance), logger.FormatQuota(preConsumedQuota)),
				types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
				types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
		}

		relayInfo.UserQuota = balance
		return start(&BillingSession{
			relayInfo: relayInfo,
			funding:   &WalletFunding{userId: relayInfo.UserId},
		})
	}

	// Subscription path: the subscription amount is never below 1.
	fromSubscription := func() (*BillingSession, *types.NewAPIError) {
		amount := int64(preConsumedQuota)
		if amount <= 0 {
			amount = 1
		}
		return start(&BillingSession{
			relayInfo: relayInfo,
			funding: &SubscriptionFunding{
				requestId: relayInfo.RequestId,
				userId:    relayInfo.UserId,
				modelName: relayInfo.OriginModelName,
				amount:    amount,
			},
		})
	}

	// withFallback tries primary and switches to secondary only when primary
	// failed because of insufficient quota.
	withFallback := func(primary, secondary sessionFactory) (*BillingSession, *types.NewAPIError) {
		session, apiErr := primary()
		if apiErr == nil {
			return session, nil
		}
		if apiErr.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
			return secondary()
		}
		return nil, apiErr
	}

	switch common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference) {
	case "subscription_only":
		return fromSubscription()
	case "wallet_only":
		return fromWallet()
	case "wallet_first":
		return withFallback(fromWallet, fromSubscription)
	default:
		// "subscription_first" and every unrecognized preference.
		return withFallback(fromSubscription, fromWallet)
	}
}