feat(billing): 网关计费迁移到 CalculateCostUnified + 模型限制错误统一

- GatewayService/OpenAIGatewayService 注入 ModelPricingResolver
- RecordUsage 从旧路径迁移到 CalculateCostUnified(支持 per_request/image 模式)
- 无渠道时自动回退旧路径,保持原有行为
- 长上下文双倍计费仅在无渠道定价时生效
- CostBreakdown 新增 BillingMode 字段,使用日志记录实际计费模式
- 模型限制错误改为与"无可用账号"相同的 503 响应
This commit is contained in:
erio
2026-03-30 22:58:28 +08:00
parent a51e0047b7
commit 632035aabd
11 changed files with 96 additions and 30 deletions

View File

@@ -164,14 +164,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
// 渠道模型限制检查:使用原始请求模型名,因为定价列表中注册的是用户请求的模型名
if apiKey.GroupID != nil {
checkModel := reqModel
if channelMapping.Mapped {
checkModel = channelMapping.MappedModel
}
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) {
h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel)
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
}

View File

@@ -162,6 +162,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@@ -2224,7 +2224,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
)
}

View File

@@ -466,6 +466,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}

View File

@@ -104,6 +104,7 @@ type CostBreakdown struct {
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
}
// BillingService 计费服务
@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
input.RateMultiplier = 1.0
}
var breakdown *CostBreakdown
var err error
switch resolved.Mode {
case BillingModePerRequest, BillingModeImage:
return s.calculatePerRequestCost(resolved, input)
breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken
return s.calculateTokenCost(resolved, input)
breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
}
return breakdown, err
}
// calculateTokenCost 按 token 区间计费

View File

@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}

View File

@@ -569,6 +569,7 @@ type GatewayService struct {
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
}
@@ -599,6 +600,7 @@ func NewGatewayService(
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
@@ -632,6 +634,7 @@ func NewGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
}
if chPricing != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
if s.resolver != nil && apiKey.Group != nil {
var groupID *int64
if apiKey.Group != nil {
gid := apiKey.Group.ID
groupID = &gid
}
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: groupID,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing2 *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
// 优先尝试 Resolver + CalculateCostUnified仅在有渠道定价时使用
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == "channel" {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
}
if chPricing2 != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
} else {
if !useUnified {
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
}
if err != nil {
@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode

View File

@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
&DeferredService{},
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,

View File

@@ -322,6 +322,7 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -357,6 +358,7 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -384,6 +386,7 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 设置计费模式OpenAI 网关都是 token 计费)
{
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
}

View File

@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)