fix(usage): preserve requested model in gateway billing paths
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
mappedModel := "claude-sonnet-4-20250514"
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_models_split",
|
||||
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
|
||||
Model: "claude-sonnet-4",
|
||||
UpstreamModel: mappedModel,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
|
||||
@@ -482,10 +482,12 @@ type ClaudeUsage struct {
|
||||
|
||||
// ForwardResult 转发结果
|
||||
type ForwardResult struct {
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
|
||||
RequestID string
|
||||
Usage ClaudeUsage
|
||||
Model string
|
||||
// UpstreamModel is the actual upstream model after mapping.
|
||||
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
|
||||
UpstreamModel string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
@@ -7516,6 +7518,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.MediaType == "image" || result.MediaType == "video" {
|
||||
@@ -7531,7 +7534,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if result.MediaType == "image" {
|
||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
} else {
|
||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
||||
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
} else if result.MediaType == "prompt" {
|
||||
cost = &CostBreakdown{}
|
||||
@@ -7545,7 +7548,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
@@ -7557,7 +7560,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@@ -7589,6 +7592,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
@@ -7719,6 +7723,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
@@ -7731,7 +7736,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费(使用长上下文计费方法)
|
||||
tokens := UsageTokens{
|
||||
@@ -7743,7 +7748,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
@@ -7771,6 +7776,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||
|
||||
@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
|
||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
|
||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||
@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||
|
||||
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
|
||||
InputTokens: 20,
|
||||
OutputTokens: 10,
|
||||
}, 1.1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_upstream_model_billing_fallback",
|
||||
Model: "gpt-5.1",
|
||||
UpstreamModel: "gpt-5.1-codex",
|
||||
Usage: usage,
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10},
|
||||
User: &User{ID: 20},
|
||||
Account: &Account{ID: 30},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
|
||||
require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost)
|
||||
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
|
||||
@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
billingModel := result.Model
|
||||
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||
if result.BillingModel != "" {
|
||||
billingModel = result.BillingModel
|
||||
billingModel = strings.TrimSpace(result.BillingModel)
|
||||
}
|
||||
serviceTier := ""
|
||||
if result.ServiceTier != nil {
|
||||
@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
RequestedModel: result.Model,
|
||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
|
||||
Reference in New Issue
Block a user