From 4573868c08e7c0b82fda86596ea762d987a324e1 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:08:19 +0800 Subject: [PATCH] fix(antigravity): bill with mapped model and use final model key for rate limiting - Use mapped model (billingModel) instead of original request model for billing - Use resolveFinalAntigravityModelKey for 429 rate limit model key, ensuring rate limit records match the actual upstream model - Add regression tests for both fixes --- .../service/antigravity_gateway_service.go | 23 ++- .../antigravity_gateway_service_test.go | 142 +++++++++++++++++- .../service/antigravity_rate_limit_test.go | 16 ++ 3 files changed, 172 insertions(+), 9 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 26b14e68..108ff9ab 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -87,7 +87,6 @@ var ( ) const ( - antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" ) @@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -1622,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 + Model: billingModel, // 使用映射模型用于计费和日志 Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1976,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if mappedModel == "" { return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) } + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -2205,7 +2206,7 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, + Model: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2650,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError( defaultDur := s.getDefaultRateLimitDuration() // 尝试解析模型 key 并设置模型级限流 - modelKey := resolveAntigravityModelKey(requestedModel) + // + // 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6), + // 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。 + // 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。 + modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel) + if strings.TrimSpace(modelKey) == "" { + // 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过), + // 保持旧行为作为兜底,避免完全丢失模型级限流记录。 + modelKey = resolveAntigravityModelKey(requestedModel) + } if modelKey != "" { ra := s.resolveResetTime(resetAt, defaultDur) if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { @@ -3889,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, fmt.Errorf("missing model") } originalModel := claudeReq.Model - billingModel := originalModel // 构建上游请求 URL upstreamURL := baseURL + "/v1/messages" @@ -3942,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. _, _ = c.Writer.Write(respBody) return &ForwardResult{ - Model: billingModel, + Model: originalModel, }, nil } @@ -3983,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) return &ForwardResult{ - Model: billingModel, + Model: originalModel, Stream: claudeReq.Stream, Duration: duration, FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index abe7b75d..84b65adc 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type antigravitySettingRepoStub struct{} + +func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", ErrSettingNotFound +} + +func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { } svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: &httpUpstreamStub{resp: resp}, + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, } account := &Account{ @@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } +// TestAntigravityGatewayService_Forward_BillsWithMappedModel +// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "max_tokens": 16, + "stream": true, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-1"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 5, + Name: "acc-forward-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "claude-sonnet-4-5": mappedModel, + }, + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel +// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-2"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 6, + Name: "acc-gemini-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "gemini-2.5-flash": mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + // TestStreamUpstreamResponse_UsageAndFirstToken // 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 6a486ebc..dd8dd83f 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } +// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景 +// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking) +func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("5s") + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false) + + require.Nil(t, result) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey) +} + // TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景 // MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号 func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {