diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 50fa78f2..6ee8280c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, // 使用映射模型用于计费和日志 + Model: originalModel, + UpstreamModel: billingModel, Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2435,7 +2436,8 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: billingModel, + Model: originalModel, + UpstreamModel: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 6e0a7305..f5f9434c 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { result, err := svc.Forward(context.Background(), c, account, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "claude-sonnet-4-5", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel @@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) } func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { @@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) require.NoError(t, err) require.NotNil(t, result) - require.Equal(t, mappedModel, result.Model) + require.Equal(t, originalModel, result.Model) + require.Equal(t, mappedModel, result.UpstreamModel) require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") firstReq := string(upstream.requestBodies[0]) diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ab6871bb..e9d325f4 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) return nil, errors.New("model is required") } + originalModel := reqModel mappedModel := account.GetMappedModel(reqModel) + var upstreamModel string if mappedModel != "" && mappedModel != reqModel { reqModel = mappedModel + upstreamModel = mappedModel } modelCfg, ok := GetSoraModelConfig(reqModel) @@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } @@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun c.JSON(http.StatusOK, resp) } return &ForwardResult{ - RequestID: "", - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: "prompt", + RequestID: "", + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", }, nil } if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { @@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } return &ForwardResult{ - RequestID: taskID, - Model: reqModel, - Stream: clientStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{}, - MediaType: mediaType, - MediaURL: firstMediaURL(finalURLs), - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: taskID, + Model: originalModel, + UpstreamModel: upstreamModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, }, nil } diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 206636ff..2fef600c 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { ID: 1, Platform: PlatformSora, Status: StatusActive, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "prompt-enhance-short-10s": "prompt-enhance-short-15s", + }, + }, } body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) @@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { require.NotNil(t, result) require.Equal(t, "prompt", result.MediaType) require.Equal(t, "prompt-enhance-short-10s", result.Model) + require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel) } func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {