feat(billing): 网关计费迁移到 CalculateCostUnified + 模型限制错误统一
- GatewayService/OpenAIGatewayService 注入 ModelPricingResolver - RecordUsage 从旧路径迁移到 CalculateCostUnified(支持 per_request/image 模式) - 无渠道时自动回退旧路径,保持原有行为 - 长上下文双倍计费仅在无渠道定价时生效 - CostBreakdown 新增 BillingMode 字段,使用日志记录实际计费模式 - 模型限制错误改为与"无可用账号"相同的 503 响应
This commit is contained in:
@@ -104,6 +104,7 @@ type CostBreakdown struct {
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64 // 应用倍率后的实际费用
|
||||
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
|
||||
}
|
||||
|
||||
// BillingService 计费服务
|
||||
@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
|
||||
input.RateMultiplier = 1.0
|
||||
}
|
||||
|
||||
var breakdown *CostBreakdown
|
||||
var err error
|
||||
switch resolved.Mode {
|
||||
case BillingModePerRequest, BillingModeImage:
|
||||
return s.calculatePerRequestCost(resolved, input)
|
||||
breakdown, err = s.calculatePerRequestCost(resolved, input)
|
||||
default: // BillingModeToken
|
||||
return s.calculateTokenCost(resolved, input)
|
||||
breakdown, err = s.calculateTokenCost(resolved, input)
|
||||
}
|
||||
if err == nil && breakdown != nil {
|
||||
breakdown.BillingMode = string(resolved.Mode)
|
||||
if breakdown.BillingMode == "" {
|
||||
breakdown.BillingMode = string(BillingModeToken)
|
||||
}
|
||||
}
|
||||
return breakdown, err
|
||||
}
|
||||
|
||||
// calculateTokenCost 按 token 区间计费
|
||||
|
||||
@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -569,6 +569,7 @@ type GatewayService struct {
|
||||
debugModelRouting atomic.Bool
|
||||
debugClaudeMimic atomic.Bool
|
||||
channelService *ChannelService
|
||||
resolver *ModelPricingResolver
|
||||
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
|
||||
tlsFPProfileService *TLSFingerprintProfileService
|
||||
}
|
||||
@@ -599,6 +600,7 @@ func NewGatewayService(
|
||||
settingService *SettingService,
|
||||
tlsFPProfileService *TLSFingerprintProfileService,
|
||||
channelService *ChannelService,
|
||||
resolver *ModelPricingResolver,
|
||||
) *GatewayService {
|
||||
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||
@@ -632,6 +634,7 @@ func NewGatewayService(
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
tlsFPProfileService: tlsFPProfileService,
|
||||
channelService: channelService,
|
||||
resolver: resolver,
|
||||
}
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
userGroupRateRepo,
|
||||
@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
// 渠道定价覆盖
|
||||
var chPricing *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
}
|
||||
if chPricing != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
var groupID *int64
|
||||
if apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
groupID = &gid
|
||||
}
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: groupID,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
}
|
||||
@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if result.ImageCount > 0 {
|
||||
billingMode := "image"
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
usageLog.BillingMode = &billingMode
|
||||
@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
// 渠道定价覆盖
|
||||
var chPricing2 *ChannelModelPricing
|
||||
if s.channelService != nil && apiKey.Group != nil {
|
||||
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
|
||||
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
|
||||
useUnified := false
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
})
|
||||
if resolved.Source == "channel" {
|
||||
// 有渠道定价,渠道区间已包含上下文分层
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
useUnified = true
|
||||
}
|
||||
}
|
||||
if chPricing2 != nil {
|
||||
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
|
||||
} else {
|
||||
if !useUnified {
|
||||
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
if result.ImageCount > 0 {
|
||||
billingMode := "image"
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
usageLog.BillingMode = &billingMode
|
||||
|
||||
@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
|
||||
@@ -322,6 +322,7 @@ type OpenAIGatewayService struct {
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
openaiWSResolver OpenAIWSProtocolResolver
|
||||
resolver *ModelPricingResolver
|
||||
|
||||
openaiWSPoolOnce sync.Once
|
||||
openaiWSStateStoreOnce sync.Once
|
||||
@@ -357,6 +358,7 @@ func NewOpenAIGatewayService(
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
resolver *ModelPricingResolver,
|
||||
) *OpenAIGatewayService {
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -384,6 +386,7 @@ func NewOpenAIGatewayService(
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
resolver: resolver,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
var err error
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
ServiceTier: serviceTier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
}
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
// 设置计费模式(OpenAI 网关都是 token 计费)
|
||||
{
|
||||
// 设置计费模式
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
|
||||
@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
|
||||
|
||||
Reference in New Issue
Block a user