diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 014b3c86..81a1c149 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -371,12 +371,12 @@ urlFallbackLoop: _ = resp.Body.Close() // ★ 统一入口:自定义错误码 + 临时不可调度 - if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { if policyErr != nil { return nil, policyErr } resp = &http.Response{ - StatusCode: resp.StatusCode, + StatusCode: outStatus, Header: resp.Header.Clone(), Body: io.NopCloser(bytes.NewReader(respBody)), } @@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) } -// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 -func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { +// applyErrorPolicy applies the error-policy result and reports whether the current loop should stop, along with the status code to return. +// For ErrorPolicySkipped, outStatus is 500 (frontend convention: errors that miss the custom codes are returned as 500); every other case returns the original statusCode. +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) { switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { case ErrorPolicySkipped: - return true, nil + return true, http.StatusInternalServerError, nil case ErrorPolicyMatched: _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - return true, nil + return true, statusCode, nil case ErrorPolicyTempUnscheduled: slog.Info("temp_unschedulable_matched", "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) - 
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} } - return false, nil + return false, statusCode, nil } // mapAntigravityModel 获取映射后的模型名 @@ -2242,6 +2243,10 @@ func (s *AntigravityGatewayService) handleUpstreamError( requestedModel string, groupID int64, sessionHash string, isStickySession bool, ) *handleModelRateLimitResult { + // Honor the custom error-code policy: when the account should not handle this status code, skip all rate-limit handling. + if !account.ShouldHandleErrorCode(statusCode) { + return nil + } // 模型级限流处理(优先) result := s.handleModelRateLimit(&handleModelRateLimitParams{ ctx: ctx, diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go index 9f8ad938..a8b42a2c 100644 --- a/backend/internal/service/error_policy_integration_test.go +++ b/backend/internal/service/error_policy_integration_test.go @@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { customCodes: []any{float64(500)}, expectHandleError: 0, expectUpstream: 1, - expectStatusCode: 429, + expectStatusCode: 500, }, { name: "500_in_custom_codes_matched", @@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") } + +// --------------------------------------------------------------------------- +// epTrackingRepo — records SetRateLimited / SetError / SetTempUnschedulable calls for verification. 
+// --------------------------------------------------------------------------- + +type epTrackingRepo struct { + mockAccountRepoForGemini + rateLimitedCalls int + rateLimitedID int64 + setErrCalls int + setErrID int64 + tempCalls int +} + +func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedCalls++ + r.rateLimitedID = id + return nil +} + +func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error { + r.setErrCalls++ + r.setErrID = id + return nil +} + +func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit +// +// Core scenario: the custom error codes are set to [599] (a code that never actually occurs). +// When the upstream returns 429/500/503/401/403: +// - the status code returned to the client must be 500 (the original status is not passed through) +// - SetRateLimited is not called (the account does not enter the rate-limited state) +// - SetError is not called (scheduling is not stopped) +// - handleError is not called +// --------------------------------------------------------------------------- + +func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) { + errorCodes := []int{429, 500, 503, 401, 403} + + for _, upstreamStatus := range errorCodes { + t.Run(http.StatusText(upstreamStatus), func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{ + statusCode: upstreamStatus, + body: `{"error":"some upstream error"}`, + } + repo := &epTrackingRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := &Account{ + ID: 500, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(599)}, + }, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ 
string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + // No error is expected (Skipped does not trigger an account switch). + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "result should not be nil") + require.NotNil(t, result.resp, "response should not be nil") + defer func() { _ = result.resp.Body.Close() }() + + // The status code must be 500 (the original status is not passed through). + require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode, + "skipped error should return 500, not %d", upstreamStatus) + + // handleError must not be called. + require.Equal(t, 0, handleErrorCount, + "handleError should NOT be called for skipped errors") + + // SetRateLimited must not be called (no rate-limit mark). + require.Equal(t, 0, repo.rateLimitedCalls, + "SetRateLimited should NOT be called for skipped errors") + + // SetError must not be called (scheduling is not stopped). + require.Equal(t, 0, repo.setErrCalls, + "SetError should NOT be called for skipped errors") + + // The upstream is called exactly once (no retry). + require.Equal(t, 1, upstream.calls, + "should call upstream exactly once (no retry)") + }) + } +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index a8b69c22..9d7d025e 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode int body []byte expectedHandled bool + expectedStatus int // expected outStatus expectedSwitchErr bool // expect *AntigravityAccountSwitchError handleErrorCalls int }{ @@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: false, + expectedStatus: 500, // passthrough handleErrorCalls: 0, }, { @@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, // not in custom codes body: []byte(`"error"`), expectedHandled: true, + expectedStatus: http.StatusInternalServerError, // skipped → 500 
handleErrorCalls: 0, }, { @@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: true, + expectedStatus: 500, // matched → original status handleErrorCalls: 1, }, { @@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 503, body: []byte(`overloaded`), expectedHandled: true, + expectedStatus: 503, // temp_unscheduled → original status expectedSwitchErr: true, handleErrorCalls: 0, }, @@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) { isStickySession: true, } - handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch") require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") if tt.expectedSwitchErr { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index d77f6f92..792c8f4b 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex break } + // Error policy takes precedence: if a policy applies, skip the retry logic and handle the response directly; either way resp is replaced by the value checkErrorPolicyInLoop returns. + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if upstreamReqID == "" { upstreamReqID = resp.Header.Get("x-goog-request-id") } - return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, 
respBody) + return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) upstreamReqID := resp.Header.Get(requestIDHeader) @@ -1176,6 +1184,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr) } + // Error policy takes precedence: if a policy applies, skip the retry logic and handle the response directly; either way resp is replaced by the value checkErrorPolicyInLoop returns. + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -1283,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if contentType == "" { contentType = "application/json" } - c.Data(resp.StatusCode, contentType, respBody) + c.Data(http.StatusInternalServerError, contentType, respBody) return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) @@ -1425,6 +1441,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
}, nil } +// checkErrorPolicyInLoop pre-checks the error policy inside the retry loop; for status >= 400 it reads up to 2 MiB of the body, closes it, and rebuilds resp over the bytes read. +// It returns true when a policy other than ErrorPolicyNone applies (the caller should break); resp has been rebuilt and can be used directly. +// It returns false for ErrorPolicyNone (resp rebuilt) or when status < 400 or rateLimitService is nil (resp returned unchanged); the caller continues with the retry logic. +func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop( + ctx context.Context, account *Account, resp *http.Response, +) (matched bool, rebuilt *http.Response) { + if resp.StatusCode < 400 || s.rateLimitService == nil { + return false, resp + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + rebuilt = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(body)), + } + policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body) + return policy != ErrorPolicyNone, rebuilt +} + func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { switch statusCode { case 429, 500, 502, 503, 504, 529: @@ -2597,6 +2633,10 @@ func asInt(v any) (int, bool) { } func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { + // Honor the custom error-code policy: when the account should not handle this status code, skip all rate-limit handling. + if !account.ShouldHandleErrorCode(statusCode) { + return + } if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) return