fix(usage): preserve requested model in gateway billing paths
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -162,6 +162,32 @@ func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID
|
|||||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
mappedModel := "claude-sonnet-4-20250514"
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_models_split",
|
||||||
|
Usage: ClaudeUsage{InputTokens: 10, OutputTokens: 6},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
UpstreamModel: mappedModel,
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||||
|
User: &User{ID: 601},
|
||||||
|
Account: &Account{ID: 701},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, "claude-sonnet-4", usageRepo.lastLog.RequestedModel)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||||
|
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
|||||||
@@ -482,10 +482,12 @@ type ClaudeUsage struct {
|
|||||||
|
|
||||||
// ForwardResult 转发结果
|
// ForwardResult 转发结果
|
||||||
type ForwardResult struct {
|
type ForwardResult struct {
|
||||||
RequestID string
|
RequestID string
|
||||||
Usage ClaudeUsage
|
Usage ClaudeUsage
|
||||||
Model string
|
Model string
|
||||||
UpstreamModel string // Actual upstream model after mapping (empty = no mapping)
|
// UpstreamModel is the actual upstream model after mapping.
|
||||||
|
// Prefer empty when it is identical to Model; persistence normalizes equal values away as no-op mappings.
|
||||||
|
UpstreamModel string
|
||||||
Stream bool
|
Stream bool
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
FirstTokenMs *int // 首字时间(流式请求)
|
FirstTokenMs *int // 首字时间(流式请求)
|
||||||
@@ -7516,6 +7518,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
|
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
// 根据请求类型选择计费方式
|
||||||
if result.MediaType == "image" || result.MediaType == "video" {
|
if result.MediaType == "image" || result.MediaType == "video" {
|
||||||
@@ -7531,7 +7534,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
if result.MediaType == "image" {
|
if result.MediaType == "image" {
|
||||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||||
} else {
|
} else {
|
||||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||||
}
|
}
|
||||||
} else if result.MediaType == "prompt" {
|
} else if result.MediaType == "prompt" {
|
||||||
cost = &CostBreakdown{}
|
cost = &CostBreakdown{}
|
||||||
@@ -7545,7 +7548,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
Price4K: apiKey.Group.ImagePrice4K,
|
Price4K: apiKey.Group.ImagePrice4K,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||||
} else {
|
} else {
|
||||||
// Token 计费
|
// Token 计费
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
@@ -7557,7 +7560,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
cost = &CostBreakdown{ActualCost: 0}
|
||||||
@@ -7589,6 +7592,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
RequestedModel: result.Model,
|
||||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||||
ReasoningEffort: result.ReasoningEffort,
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||||
@@ -7719,6 +7723,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
}
|
}
|
||||||
|
|
||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
|
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
// 根据请求类型选择计费方式
|
||||||
if result.ImageCount > 0 {
|
if result.ImageCount > 0 {
|
||||||
@@ -7731,7 +7736,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
Price4K: apiKey.Group.ImagePrice4K,
|
Price4K: apiKey.Group.ImagePrice4K,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||||
} else {
|
} else {
|
||||||
// Token 计费(使用长上下文计费方法)
|
// Token 计费(使用长上下文计费方法)
|
||||||
tokens := UsageTokens{
|
tokens := UsageTokens{
|
||||||
@@ -7743,7 +7748,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err)
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
cost = &CostBreakdown{ActualCost: 0}
|
||||||
@@ -7771,6 +7776,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
RequestedModel: result.Model,
|
||||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||||
ReasoningEffort: result.ReasoningEffort,
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint),
|
||||||
|
|||||||
@@ -879,6 +879,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, usageRepo.lastLog)
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, "gpt-5.1", usageRepo.lastLog.RequestedModel)
|
||||||
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
require.NotNil(t, usageRepo.lastLog.UpstreamModel)
|
||||||
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
|
require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel)
|
||||||
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
require.NotNil(t, usageRepo.lastLog.ServiceTier)
|
||||||
@@ -894,6 +895,40 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad
|
|||||||
require.Equal(t, 1, userRepo.deductCalls)
|
require.Equal(t, 1, userRepo.deductCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
|
||||||
|
|
||||||
|
expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{
|
||||||
|
InputTokens: 20,
|
||||||
|
OutputTokens: 10,
|
||||||
|
}, 1.1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_upstream_model_billing_fallback",
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
UpstreamModel: "gpt-5.1-codex",
|
||||||
|
Usage: usage,
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10},
|
||||||
|
User: &User{ID: 20},
|
||||||
|
Account: &Account{ID: 30},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model)
|
||||||
|
require.Equal(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost)
|
||||||
|
require.Equal(t, expectedCost.TotalCost, usageRepo.lastLog.TotalCost)
|
||||||
|
require.Equal(t, expectedCost.ActualCost, userRepo.lastAmount)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
|||||||
@@ -4110,9 +4110,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||||
}
|
}
|
||||||
|
|
||||||
billingModel := result.Model
|
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
|
||||||
if result.BillingModel != "" {
|
if result.BillingModel != "" {
|
||||||
billingModel = result.BillingModel
|
billingModel = strings.TrimSpace(result.BillingModel)
|
||||||
}
|
}
|
||||||
serviceTier := ""
|
serviceTier := ""
|
||||||
if result.ServiceTier != nil {
|
if result.ServiceTier != nil {
|
||||||
@@ -4140,6 +4140,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
|
RequestedModel: result.Model,
|
||||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||||
ServiceTier: result.ServiceTier,
|
ServiceTier: result.ServiceTier,
|
||||||
ReasoningEffort: result.ReasoningEffort,
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
|
|||||||
Reference in New Issue
Block a user