diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 85590c12..3beaef33 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -178,10 +178,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { digestSessionStore := service.NewDigestSessionStore() channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - _ = modelPricingResolver // Phase 4: 已注册,后续 Gateway 迁移时使用 - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 2ad3bb76..b46c86a3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -164,14 +164,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) } - // 渠道模型限制检查 + // 渠道模型限制检查:使用原始请求模型名,因为定价列表中注册的是用户请求的模型名 if apiKey.GroupID != nil { - checkModel := reqModel - if channelMapping.Mapped { - checkModel = channelMapping.MappedModel - } - if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) { - h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel) + if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") return } } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 7dc062df..4caef955 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -162,6 +162,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // settingService nil, // tlsFPProfileService nil, // channelService + nil, // resolver ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 78e2d24b..57055786 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2224,7 +2224,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 18e6e929..e053b668 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -466,6 +466,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, // settingService nil, // tlsFPProfileService nil, // channelService + nil, // resolver ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 7deb1cf9..d256102c 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -104,6 +104,7 @@ type CostBreakdown struct { CacheReadCost float64 TotalCost float64 ActualCost float64 // 应用倍率后的实际费用 + BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充 } // BillingService 计费服务 @@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, input.RateMultiplier = 1.0 } + var breakdown *CostBreakdown + var err error switch resolved.Mode { case BillingModePerRequest, BillingModeImage: - return s.calculatePerRequestCost(resolved, input) + breakdown, err = s.calculatePerRequestCost(resolved, input) default: // BillingModeToken - return s.calculateTokenCost(resolved, input) + breakdown, err = s.calculateTokenCost(resolved, input) } + if err == nil && breakdown != nil { + breakdown.BillingMode = string(resolved.Mode) + if breakdown.BillingMode == "" { + breakdown.BillingMode = string(BillingModeToken) + } + } + return breakdown, err } // calculateTokenCost 按 token 区间计费 diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 5df0b58c..97703a9d 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 7aeffa16..c7d8403e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -569,6 +569,7 @@ type GatewayService struct { debugModelRouting atomic.Bool debugClaudeMimic atomic.Bool channelService *ChannelService + resolver *ModelPricingResolver debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService } @@ -599,6 +600,7 @@ func NewGatewayService( settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, channelService *ChannelService, + resolver *ModelPricingResolver, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -632,6 +634,7 @@ func NewGatewayService( responseHeaderFilter: compileResponseHeaderFilter(cfg), tlsFPProfileService: tlsFPProfileService, channelService: channelService, + resolver: resolver, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - // 渠道定价覆盖 - var chPricing *ChannelModelPricing - if s.channelService != nil && apiKey.Group != nil { - chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) - } - if chPricing != nil { - cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing) + if s.resolver != nil && apiKey.Group != nil { + var groupID *int64 + if apiKey.Group != nil { + gid := apiKey.Group.ID + groupID = &gid + } + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: groupID, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + }) } else { cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) } @@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageCount > 0 { billingMode := "image" usageLog.BillingMode = &billingMode + } else if cost != nil && cost.BillingMode != "" { + billingMode := cost.BillingMode + usageLog.BillingMode = &billingMode } else { billingMode := "token" usageLog.BillingMode = &billingMode @@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - // 渠道定价覆盖 - var chPricing2 *ChannelModelPricing - if s.channelService != nil && apiKey.Group != nil { - chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) + // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用) + useUnified := false + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{ + Model: billingModel, + GroupID: &gid, + }) + if resolved.Source == "channel" { + // 有渠道定价,渠道区间已包含上下文分层 + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + }) + useUnified = true + } } - if chPricing2 != nil { - cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2) - } else { + if !useUnified { + // 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值) cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) } if err != nil { @@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * if result.ImageCount > 0 { billingMode := "image" usageLog.BillingMode = &billingMode + } else if cost != nil && cost.BillingMode != "" { + billingMode := cost.BillingMode + usageLog.BillingMode = &billingMode } else { billingMode := "token" usageLog.BillingMode = &billingMode diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 7a636afa..bfcff444 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, &DeferredService{}, nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ddef3d31..06c36b0f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -322,6 +322,7 @@ type OpenAIGatewayService struct { openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver + resolver *ModelPricingResolver openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -357,6 +358,7 @@ func NewOpenAIGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, + resolver *ModelPricingResolver, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -384,6 +386,7 @@ func NewOpenAIGatewayService( openAITokenProvider: openAITokenProvider, toolCorrector: NewCodexToolCorrector(), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + resolver: resolver, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } + var cost *CostBreakdown + var err error billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } else { + cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + } if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), } - // 设置计费模式(OpenAI 网关都是 token 计费) - { + // 设置计费模式 + if cost != nil && cost.BillingMode != "" { + billingMode := cost.BillingMode + usageLog.BillingMode = &billingMode + } else { billingMode := "token" usageLog.BillingMode = &billingMode } diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 8c5c9368..e8d9f8f7 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)