diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 4c1f0317..1b2f5f51 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) } +func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + mappedModel := "claude-sonnet-4-20250514" + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_models_split", + Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6}, + Model: "claude-sonnet-4", + UpstreamModel: mappedModel, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model) + require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel) +} + func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e23d24de..5b194d57 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -482,10 +482,12 @@ type ClaudeUsage struct { // ForwardResult 转发结果 type ForwardResult struct { - RequestID string - Usage ClaudeUsage - Model string - UpstreamModel string // Actual upstream model after mapping (empty = no mapping) + RequestID string + Usage ClaudeUsage + Model string + // UpstreamModel is the actual upstream model after mapping. + // Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings. + UpstreamModel string Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -7516,6 +7518,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) // 根据请求类型选择计费方式 if result.MediaType == "image" || result.MediaType == "video" { @@ -7531,7 +7534,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.MediaType == "image" { cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) } else { - cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) } } else if result.MediaType == "prompt" { cost = &CostBreakdown{} @@ -7545,7 +7548,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费 tokens := UsageTokens{ @@ -7557,7 +7560,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7589,6 +7592,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -7719,6 +7723,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } var cost *CostBreakdown + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) // 根据请求类型选择计费方式 if result.ImageCount > 0 { @@ -7731,7 +7736,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * Price4K: apiKey.Group.ImagePrice4K, } } - cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ @@ -7743,7 +7748,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7771,6 +7776,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index a35f9127..5aa4db8a 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel) require.NotNil(t, usageRepo.lastLog.UpstreamModel) require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) @@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad require.Equal(t, 1, userRepo.deductCalls) } +func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_upstream_model_billing_fallback", + Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost) + require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount) +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index cf902c20..4e96cf05 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } - billingModel := result.Model + billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) if result.BillingModel != "" { - billingModel = result.BillingModel + billingModel = strings.TrimSpace(result.BillingModel) } serviceTier := "" if result.ServiceTier != nil { @@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountID: account.ID, RequestID: requestID, Model: result.Model, + RequestedModel: result.Model, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort,