Merge pull request #531 from touwaeriol/fix/gemini-error-policy-before-retry
fix: Gemini error policy check should precede retry logic
This commit is contained in:
@@ -371,12 +371,12 @@ urlFallbackLoop:
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||
if policyErr != nil {
|
||||
return nil, policyErr
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
StatusCode: outStatus,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
@@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
|
||||
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
||||
}
|
||||
|
||||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
|
||||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
|
||||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
|
||||
// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。
|
||||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
|
||||
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
||||
case ErrorPolicySkipped:
|
||||
return true, nil
|
||||
return true, http.StatusInternalServerError, nil
|
||||
case ErrorPolicyMatched:
|
||||
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
||||
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||
return true, nil
|
||||
return true, statusCode, nil
|
||||
case ErrorPolicyTempUnscheduled:
|
||||
slog.Info("temp_unschedulable_matched",
|
||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||
}
|
||||
return false, nil
|
||||
return false, statusCode, nil
|
||||
}
|
||||
|
||||
// mapAntigravityModel 获取映射后的模型名
|
||||
@@ -2242,6 +2243,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
||||
requestedModel string,
|
||||
groupID int64, sessionHash string, isStickySession bool,
|
||||
) *handleModelRateLimitResult {
|
||||
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
return nil
|
||||
}
|
||||
// 模型级限流处理(优先)
|
||||
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_in_custom_codes_matched",
|
||||
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type epTrackingRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedCalls int
|
||||
rateLimitedID int64
|
||||
setErrCalls int
|
||||
setErrID int64
|
||||
tempCalls int
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedCalls++
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
r.setErrID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||
//
|
||||
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||
// 当上游返回 429/500/503/401 时:
|
||||
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||
// - 不调用 SetRateLimited(不进入限流状态)
|
||||
// - 不调用 SetError(不停止调度)
|
||||
// - 不调用 handleError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||
errorCodes := []int{429, 500, 503, 401, 403}
|
||||
|
||||
for _, upstreamStatus := range errorCodes {
|
||||
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{
|
||||
statusCode: upstreamStatus,
|
||||
body: `{"error":"some upstream error"}`,
|
||||
}
|
||||
repo := &epTrackingRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := &Account{
|
||||
ID: 500,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(599)},
|
||||
},
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// 不应返回 error(Skipped 不触发账号切换)
|
||||
require.NoError(t, err, "should not return error")
|
||||
require.NotNil(t, result, "result should not be nil")
|
||||
require.NotNil(t, result.resp, "response should not be nil")
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
// 状态码必须是 500(不透传原始状态码)
|
||||
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||
"skipped error should return 500, not %d", upstreamStatus)
|
||||
|
||||
// 不调用 handleError
|
||||
require.Equal(t, 0, handleErrorCount,
|
||||
"handleError should NOT be called for skipped errors")
|
||||
|
||||
// 不标记限流
|
||||
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||
"SetRateLimited should NOT be called for skipped errors")
|
||||
|
||||
// 不停止调度
|
||||
require.Equal(t, 0, repo.setErrCalls,
|
||||
"SetError should NOT be called for skipped errors")
|
||||
|
||||
// 只调用一次上游(不重试)
|
||||
require.Equal(t, 1, upstream.calls,
|
||||
"should call upstream exactly once (no retry)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
statusCode int
|
||||
body []byte
|
||||
expectedHandled bool
|
||||
expectedStatus int // expected outStatus
|
||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||
handleErrorCalls int
|
||||
}{
|
||||
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: false,
|
||||
expectedStatus: 500, // passthrough
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
statusCode: 500, // not in custom codes
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: 500, // matched → original status
|
||||
handleErrorCalls: 1,
|
||||
},
|
||||
{
|
||||
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expectedHandled: true,
|
||||
expectedStatus: 503, // temp_unscheduled → original status
|
||||
expectedSwitchErr: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
|
||||
isStickySession: true,
|
||||
}
|
||||
|
||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||
|
||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||
|
||||
if tt.expectedSwitchErr {
|
||||
|
||||
@@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
break
|
||||
}
|
||||
|
||||
// 错误策略优先:匹配则跳过重试直接处理。
|
||||
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||
resp = rebuilt
|
||||
break
|
||||
} else {
|
||||
resp = rebuilt
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
@@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
||||
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
@@ -1176,6 +1184,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||
}
|
||||
|
||||
// 错误策略优先:匹配则跳过重试直接处理。
|
||||
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||
resp = rebuilt
|
||||
break
|
||||
} else {
|
||||
resp = rebuilt
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
@@ -1283,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, respBody)
|
||||
c.Data(http.StatusInternalServerError, contentType, respBody)
|
||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
@@ -1425,6 +1441,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
|
||||
// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。
|
||||
// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。
|
||||
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
|
||||
ctx context.Context, account *Account, resp *http.Response,
|
||||
) (matched bool, rebuilt *http.Response) {
|
||||
if resp.StatusCode < 400 || s.rateLimitService == nil {
|
||||
return false, resp
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
rebuilt = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
|
||||
return policy != ErrorPolicyNone, rebuilt
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
@@ -2597,6 +2633,10 @@ func asInt(v any) (int, bool) {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
return
|
||||
}
|
||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user