Merge pull request #531 from touwaeriol/fix/gemini-error-policy-before-retry

fix: Gemini error policy check should precede retry logic
This commit is contained in:
Wesley Liddick
2026-02-09 20:52:32 +08:00
committed by GitHub
4 changed files with 169 additions and 12 deletions

View File

@@ -371,12 +371,12 @@ urlFallbackLoop:
_ = resp.Body.Close() _ = resp.Body.Close()
// ★ 统一入口:自定义错误码 + 临时不可调度 // ★ 统一入口:自定义错误码 + 临时不可调度
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
if policyErr != nil { if policyErr != nil {
return nil, policyErr return nil, policyErr
} }
resp = &http.Response{ resp = &http.Response{
StatusCode: resp.StatusCode, StatusCode: outStatus,
Header: resp.Header.Clone(), Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)), Body: io.NopCloser(bytes.NewReader(respBody)),
} }
@@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
} }
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 // applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { // ErrorPolicySkipped 时 outStatus 为 500前端约定未命中的错误返回 500
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
case ErrorPolicySkipped: case ErrorPolicySkipped:
return true, nil return true, http.StatusInternalServerError, nil
case ErrorPolicyMatched: case ErrorPolicyMatched:
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
return true, nil return true, statusCode, nil
case ErrorPolicyTempUnscheduled: case ErrorPolicyTempUnscheduled:
slog.Info("temp_unschedulable_matched", slog.Info("temp_unschedulable_matched",
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
} }
return false, nil return false, statusCode, nil
} }
// mapAntigravityModel 获取映射后的模型名 // mapAntigravityModel 获取映射后的模型名
@@ -2242,6 +2243,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
requestedModel string, requestedModel string,
groupID int64, sessionHash string, isStickySession bool, groupID int64, sessionHash string, isStickySession bool,
) *handleModelRateLimitResult { ) *handleModelRateLimitResult {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return nil
}
// 模型级限流处理(优先) // 模型级限流处理(优先)
result := s.handleModelRateLimit(&handleModelRateLimitParams{ result := s.handleModelRateLimit(&handleModelRateLimitParams{
ctx: ctx, ctx: ctx,

View File

@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
customCodes: []any{float64(500)}, customCodes: []any{float64(500)},
expectHandleError: 0, expectHandleError: 0,
expectUpstream: 1, expectUpstream: 1,
expectStatusCode: 429, expectStatusCode: 500,
}, },
{ {
name: "500_in_custom_codes_matched", name: "500_in_custom_codes_matched",
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
} }
// ---------------------------------------------------------------------------
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
// ---------------------------------------------------------------------------
type epTrackingRepo struct {
mockAccountRepoForGemini
rateLimitedCalls int
rateLimitedID int64
setErrCalls int
setErrID int64
tempCalls int
}
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
r.rateLimitedCalls++
r.rateLimitedID = id
return nil
}
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
r.setErrCalls++
r.setErrID = id
return nil
}
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
r.tempCalls++
return nil
}
// ---------------------------------------------------------------------------
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
//
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
// 当上游返回 429/500/503/401 时:
// - 返回给客户端的状态码必须是 500而不是透传原始状态码
// - 不调用 SetRateLimited不进入限流状态
// - 不调用 SetError不停止调度
// - 不调用 handleError
// ---------------------------------------------------------------------------
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
errorCodes := []int{429, 500, 503, 401, 403}
for _, upstreamStatus := range errorCodes {
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
saveAndSetBaseURLs(t)
upstream := &epFixedUpstream{
statusCode: upstreamStatus,
body: `{"error":"some upstream error"}`,
}
repo := &epTrackingRepo{}
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
account := &Account{
ID: 500,
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Credentials: map[string]any{
"custom_error_codes_enabled": true,
"custom_error_codes": []any{float64(599)},
},
}
var handleErrorCount int
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
handleErrorCount++
return nil
})
result, err := svc.antigravityRetryLoop(p)
// 不应返回 errorSkipped 不触发账号切换)
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "result should not be nil")
require.NotNil(t, result.resp, "response should not be nil")
defer func() { _ = result.resp.Body.Close() }()
// 状态码必须是 500不透传原始状态码
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
"skipped error should return 500, not %d", upstreamStatus)
// 不调用 handleError
require.Equal(t, 0, handleErrorCount,
"handleError should NOT be called for skipped errors")
// 不标记限流
require.Equal(t, 0, repo.rateLimitedCalls,
"SetRateLimited should NOT be called for skipped errors")
// 不停止调度
require.Equal(t, 0, repo.setErrCalls,
"SetError should NOT be called for skipped errors")
// 只调用一次上游(不重试)
require.Equal(t, 1, upstream.calls,
"should call upstream exactly once (no retry)")
})
}
}

View File

@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode int statusCode int
body []byte body []byte
expectedHandled bool expectedHandled bool
expectedStatus int // expected outStatus
expectedSwitchErr bool // expect *AntigravityAccountSwitchError expectedSwitchErr bool // expect *AntigravityAccountSwitchError
handleErrorCalls int handleErrorCalls int
}{ }{
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500, statusCode: 500,
body: []byte(`"error"`), body: []byte(`"error"`),
expectedHandled: false, expectedHandled: false,
expectedStatus: 500, // passthrough
handleErrorCalls: 0, handleErrorCalls: 0,
}, },
{ {
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500, // not in custom codes statusCode: 500, // not in custom codes
body: []byte(`"error"`), body: []byte(`"error"`),
expectedHandled: true, expectedHandled: true,
expectedStatus: http.StatusInternalServerError, // skipped → 500
handleErrorCalls: 0, handleErrorCalls: 0,
}, },
{ {
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 500, statusCode: 500,
body: []byte(`"error"`), body: []byte(`"error"`),
expectedHandled: true, expectedHandled: true,
expectedStatus: 500, // matched → original status
handleErrorCalls: 1, handleErrorCalls: 1,
}, },
{ {
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
statusCode: 503, statusCode: 503,
body: []byte(`overloaded`), body: []byte(`overloaded`),
expectedHandled: true, expectedHandled: true,
expectedStatus: 503, // temp_unscheduled → original status
expectedSwitchErr: true, expectedSwitchErr: true,
handleErrorCalls: 0, handleErrorCalls: 0,
}, },
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
isStickySession: true, isStickySession: true,
} }
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
require.Equal(t, tt.expectedHandled, handled, "handled mismatch") require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
if tt.expectedSwitchErr { if tt.expectedSwitchErr {

View File

@@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
break break
} }
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() _ = resp.Body.Close()
@@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if upstreamReqID == "" { if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id") upstreamReqID = resp.Header.Get("x-goog-request-id")
} }
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader) upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -1176,6 +1184,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr) return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
} }
// 错误策略优先:匹配则跳过重试直接处理。
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
resp = rebuilt
break
} else {
resp = rebuilt
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close() _ = resp.Body.Close()
@@ -1283,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if contentType == "" { if contentType == "" {
contentType = "application/json" contentType = "application/json"
} }
c.Data(resp.StatusCode, contentType, respBody) c.Data(http.StatusInternalServerError, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
@@ -1425,6 +1441,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
// 返回 true 表示策略已匹配(调用者应 breakresp 已重建可直接使用。
// 返回 false 表示 ErrorPolicyNoneresp 已重建,调用者继续走重试逻辑。
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
ctx context.Context, account *Account, resp *http.Response,
) (matched bool, rebuilt *http.Response) {
if resp.StatusCode < 400 || s.rateLimitService == nil {
return false, resp
}
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
rebuilt = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(body)),
}
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
return policy != ErrorPolicyNone, rebuilt
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode { switch statusCode {
case 429, 500, 502, 503, 504, 529: case 429, 500, 502, 503, 504, 529:
@@ -2597,6 +2633,10 @@ func asInt(v any) (int, bool) {
} }
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
// 遵守自定义错误码策略:未命中则跳过所有限流处理
if !account.ShouldHandleErrorCode(statusCode) {
return
}
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return return