feat(billing): 网关计费迁移到 CalculateCostUnified + 模型限制错误统一

- GatewayService/OpenAIGatewayService 注入 ModelPricingResolver
- RecordUsage 从旧路径迁移到 CalculateCostUnified(支持 per_request/image 模式)
- 无渠道时自动回退旧路径,保持原有行为
- 长上下文双倍计费仅在无渠道定价时生效
- CostBreakdown 新增 BillingMode 字段,使用日志记录实际计费模式
- 模型限制错误改为与"无可用账号"相同的 503 响应
This commit is contained in:
erio
2026-03-30 22:58:28 +08:00
parent a51e0047b7
commit 632035aabd
11 changed files with 96 additions and 30 deletions

View File

@@ -104,6 +104,7 @@ type CostBreakdown struct {
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
}
// BillingService 计费服务
@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
input.RateMultiplier = 1.0
}
var breakdown *CostBreakdown
var err error
switch resolved.Mode {
case BillingModePerRequest, BillingModeImage:
return s.calculatePerRequestCost(resolved, input)
breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken
return s.calculateTokenCost(resolved, input)
breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
}
return breakdown, err
}
// calculateTokenCost 按 token 区间计费

View File

@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}

View File

@@ -569,6 +569,7 @@ type GatewayService struct {
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
}
@@ -599,6 +600,7 @@ func NewGatewayService(
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@@ -632,6 +634,7 @@ func NewGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
}
if chPricing != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
if s.resolver != nil && apiKey.Group != nil {
var groupID *int64
if apiKey.Group != nil {
gid := apiKey.Group.ID
groupID = &gid
}
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: groupID,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing2 *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
// 优先尝试 Resolver + CalculateCostUnified仅在有渠道定价时使用
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == "channel" {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
}
if chPricing2 != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
} else {
if !useUnified {
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
}
if err != nil {
@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode

View File

@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
&DeferredService{},
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,

View File

@@ -322,6 +322,7 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -357,6 +358,7 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -384,6 +386,7 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 设置计费模式OpenAI 网关都是 token 计费)
{
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
}

View File

@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)