From 73f455745c9a4ce7372a501ffa60d00fec0c82dc Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 9 Feb 2026 19:54:54 +0800 Subject: [PATCH] feat: ErrorPolicySkipped returns 500 instead of upstream status code When custom error codes are enabled and the upstream error code is NOT in the configured list, return HTTP 500 to the client instead of transparently forwarding the original status code. Also adds integration test TestCustomErrorCode599 verifying that 429, 500, 503, 401, 403 all return 500 without triggering SetRateLimited or SetError. --- .../service/antigravity_gateway_service.go | 17 +-- .../service/error_policy_integration_test.go | 108 +++++++++++++++++- backend/internal/service/error_policy_test.go | 8 +- .../service/gemini_messages_compat_service.go | 4 +- 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 014b3c86..c295627e 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -371,12 +371,12 @@ urlFallbackLoop: _ = resp.Body.Close() // ★ 统一入口:自定义错误码 + 临时不可调度 - if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { if policyErr != nil { return nil, policyErr } resp = &http.Response{ - StatusCode: resp.StatusCode, + StatusCode: outStatus, Header: resp.Header.Clone(), Body: io.NopCloser(bytes.NewReader(respBody)), } @@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) } -// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 -func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { +// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。 +// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) { switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { case ErrorPolicySkipped: - return true, nil + return true, http.StatusInternalServerError, nil case ErrorPolicyMatched: _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - return true, nil + return true, statusCode, nil case ErrorPolicyTempUnscheduled: slog.Info("temp_unschedulable_matched", "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) - return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} } - return false, nil + return false, statusCode, nil } // mapAntigravityModel 获取映射后的模型名 diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go index 9f8ad938..a8b42a2c 100644 --- a/backend/internal/service/error_policy_integration_test.go +++ b/backend/internal/service/error_policy_integration_test.go @@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { customCodes: []any{float64(500)}, expectHandleError: 0, expectUpstream: 1, - expectStatusCode: 429, + expectStatusCode: 500, }, { name: "500_in_custom_codes_matched", @@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") } + +// --------------------------------------------------------------------------- +// epTrackingRepo — records SetRateLimited / SetError calls for verification. +// --------------------------------------------------------------------------- + +type epTrackingRepo struct { + mockAccountRepoForGemini + rateLimitedCalls int + rateLimitedID int64 + setErrCalls int + setErrID int64 + tempCalls int +} + +func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedCalls++ + r.rateLimitedID = id + return nil +} + +func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error { + r.setErrCalls++ + r.setErrID = id + return nil +} + +func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit +// +// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码), +// 当上游返回 429/500/503/401 时: +// - 返回给客户端的状态码必须是 500(而不是透传原始状态码) +// - 不调用 SetRateLimited(不进入限流状态) +// - 不调用 SetError(不停止调度) +// - 不调用 handleError +// --------------------------------------------------------------------------- + +func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) { + errorCodes := []int{429, 500, 503, 401, 403} + + for _, upstreamStatus := range errorCodes { + t.Run(http.StatusText(upstreamStatus), func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{ + statusCode: upstreamStatus, + body: `{"error":"some upstream error"}`, + } + repo := &epTrackingRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := &Account{ + ID: 500, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(599)}, + }, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + // 不应返回 error(Skipped 不触发账号切换) + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "result should not be nil") + require.NotNil(t, result.resp, "response should not be nil") + defer func() { _ = result.resp.Body.Close() }() + + // 状态码必须是 500(不透传原始状态码) + require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode, + "skipped error should return 500, not %d", upstreamStatus) + + // 不调用 handleError + require.Equal(t, 0, handleErrorCount, + "handleError should NOT be called for skipped errors") + + // 不标记限流 + require.Equal(t, 0, repo.rateLimitedCalls, + "SetRateLimited should NOT be called for skipped errors") + + // 不停止调度 + require.Equal(t, 0, repo.setErrCalls, + "SetError should NOT be called for skipped errors") + + // 只调用一次上游(不重试) + require.Equal(t, 1, upstream.calls, + "should call upstream exactly once (no retry)") + }) + } +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index a8b69c22..9d7d025e 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode int body []byte expectedHandled bool + expectedStatus int // expected outStatus expectedSwitchErr bool // expect *AntigravityAccountSwitchError handleErrorCalls int }{ @@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: false, + expectedStatus: 500, // passthrough handleErrorCalls: 0, }, { @@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, // not in custom codes body: []byte(`"error"`), expectedHandled: true, + expectedStatus: http.StatusInternalServerError, // skipped → 500 handleErrorCalls: 0, }, { @@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: true, + expectedStatus: 500, // matched → original status handleErrorCalls: 1, }, { @@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 503, body: []byte(`overloaded`), expectedHandled: true, + expectedStatus: 503, // temp_unscheduled → original status expectedSwitchErr: true, handleErrorCalls: 0, }, @@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) { isStickySession: true, } - handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch") require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") if tt.expectedSwitchErr { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index d77f6f92..9197021f 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -839,7 +839,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if upstreamReqID == "" { upstreamReqID = resp.Header.Get("x-goog-request-id") } - return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) upstreamReqID := resp.Header.Get(requestIDHeader) @@ -1283,7 +1283,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if contentType == "" { contentType = "application/json" } - c.Data(resp.StatusCode, contentType, respBody) + c.Data(http.StatusInternalServerError, contentType, respBody) return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)