diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index e2b164c0..38b97b11 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -933,6 +933,89 @@ func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel( require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount) } +func TestOpenAIGatewayServiceRecordUsage_ChannelMappedDoesNotOverrideBillingModelWhenUnmapped(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + // When channel did NOT map the model (ChannelMappedModel == OriginalModel), + // billing should use result.BillingModel (the actual model used after group + // DefaultMappedModel resolution), not the unmapped original model. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_channel_unmapped_billing", + Model: "glm", + BillingModel: "gpt-5.1", + UpstreamModel: "gpt-5.1", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + ChannelUsageFields: ChannelUsageFields{ + ChannelID: 1, + OriginalModel: "glm", + ChannelMappedModel: "glm", // channel did NOT map + BillingModelSource: BillingModelSourceChannelMapped, + }, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenMapped(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + // When channel DID map the model (ChannelMappedModel != OriginalModel), + // billing should use the channel-mapped model, honoring admin intent. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_channel_mapped_billing", + Model: "glm", + BillingModel: "gpt-5.1-codex", + UpstreamModel: "gpt-5.1-codex", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + ChannelUsageFields: ChannelUsageFields{ + ChannelID: 1, + OriginalModel: "glm", + ChannelMappedModel: "gpt-5.1", // channel mapped glm → gpt-5.1 + BillingModelSource: BillingModelSourceChannelMapped, + }, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 28c4b1f4..ef8198a1 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4277,7 +4277,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if result.BillingModel != "" { billingModel = strings.TrimSpace(result.BillingModel) } - if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { + if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" && input.ChannelMappedModel != input.OriginalModel { billingModel = input.ChannelMappedModel } if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {