Merge pull request #531 from touwaeriol/fix/gemini-error-policy-before-retry

fix: Gemini error policy check should precede retry logic
This commit is contained in:
Wesley Liddick
2026-02-09 20:52:32 +08:00
committed by GitHub
4 changed files with 169 additions and 12 deletions

View File

@@ -371,12 +371,12 @@ urlFallbackLoop:
_ = resp.Body.Close()
// ★ 统一入口:自定义错误码 + 临时不可调度
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil {
return nil, policyErr
}
resp = &http.Response{
StatusCode: resp.StatusCode,
StatusCode: outStatus,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
@@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
}
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
// ErrorPolicySkipped 时 outStatus 为 500前端约定未命中的错误返回 500
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
case ErrorPolicySkipped:
return true, nil
return true, http.StatusInternalServerError, nil
case ErrorPolicyMatched:
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return true, nil
return true, statusCode, nil
case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
}
return false, nil
return false, statusCode, nil
}
// mapAntigravityModel 获取映射后的模型名
@@ -2242,6 +2243,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
requestedModel string,
groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return nil
}
// 模型级限流处理(优先)
result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx,

View File

@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
customCodes: []any{float64(500)},
expectHandleError: 0,
expectUpstream: 1,
expectStatusCode: 429,
expectStatusCode: 500,
},
{
name: "500_in_custom_codes_matched",
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
}
// ---------------------------------------------------------------------------
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
// ---------------------------------------------------------------------------
type epTrackingRepo struct {
mockAccountRepoForGemini
rateLimitedCalls int
rateLimitedID int64
setErrCalls int
setErrID int64
tempCalls int
}
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
r.rateLimitedCalls++
r.rateLimitedID = id
return nil
}
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
r.setErrCalls++
r.setErrID = id
return nil
}
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
// ---------------------------------------------------------------------------
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
//
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
// 当上游返回 429/500/503/401 时:
// - 返回给客户端的状态码必须是 500而不是透传原始状态码
// - 不调用 SetRateLimited不进入限流状态
// - 不调用 SetError不停止调度
// - 不调用 handleError
// ---------------------------------------------------------------------------
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
errorCodes := []int{429, 500, 503, 401, 403}
for _, upstreamStatus := range errorCodes {
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{
statusCode: upstreamStatus,
body: `{"error":"some upstream error"}`,
}
repo := &epTrackingRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := &Account{
ID: 500,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(599)},
},
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
// 不应返回 errorSkipped 不触发账号切换)
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "result should not be nil")
require.NotNil(t, result.resp, "response should not be nil")
defer func() { _ = result.resp.Body.Close() }()
// 状态码必须是 500不透传原始状态码
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
"skipped error should return 500, not %d", upstreamStatus)
// 不调用 handleError
require.Equal(t, 0, handleErrorCount,
"handleError should NOT be called for skipped errors")
// 不标记限流
require.Equal(t, 0, repo.rateLimitedCalls,
"SetRateLimited should NOT be called for skipped errors")
// 不停止调度
require.Equal(t, 0, repo.setErrCalls,
"SetError should NOT be called for skipped errors")
// 只调用一次上游(不重试)
require.Equal(t, 1, upstream.calls,
"should call upstream exactly once (no retry)")
})
}
}

View File

@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode int
body []byte
expectedHandled bool
expectedStatus int // expected outStatus
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int
}{
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: false,
expectedStatus: 500, // passthrough
handleErrorCalls: 0,
},
{
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500, // not in custom codes
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: http.StatusInternalServerError, // skipped → 500
handleErrorCalls: 0,
},
{
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500,
body: []byte(`"error"`),
expectedHandled: true,
expectedStatus: 500, // matched → original status
handleErrorCalls: 1,
},
{
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 503,
body: []byte(`overloaded`),
expectedHandled: true,
expectedStatus: 503, // temp_unscheduled → original status
expectedSwitchErr: true,
handleErrorCalls: 0,
},
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
isStickySession: true,
}
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr {

View File

@@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
break
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -1176,6 +1184,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
}
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
@@ -1283,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
c.Data(http.StatusInternalServerError, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
@@ -1425,6 +1441,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
// 返回 true 表示策略已匹配(调用者应 breakresp 已重建可直接使用。
// 返回 false 表示 ErrorPolicyNoneresp 已重建,调用者继续走重试逻辑。
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
ctx context.Context, account *Account, resp *http.Response,
) (matched bool, rebuilt *http.Response) {
if resp.StatusCode < 400 || s.rateLimitService == nil {
return false, resp
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
rebuilt = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(body)),
}
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
return policy != ErrorPolicyNone, rebuilt
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
@@ -2597,6 +2633,10 @@ func asInt(v any) (int, bool) {
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return
}
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return