Merge pull request #531 from touwaeriol/fix/gemini-error-policy-before-retry
fix: Gemini error policy check should precede retry logic
This commit is contained in:
@@ -371,12 +371,12 @@ urlFallbackLoop:
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||||
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return nil, policyErr
|
return nil, policyErr
|
||||||
}
|
}
|
||||||
resp = &http.Response{
|
resp = &http.Response{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: outStatus,
|
||||||
Header: resp.Header.Clone(),
|
Header: resp.Header.Clone(),
|
||||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
}
|
}
|
||||||
@@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
|
|||||||
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
|
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
|
||||||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
|
// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。
|
||||||
|
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
|
||||||
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
||||||
case ErrorPolicySkipped:
|
case ErrorPolicySkipped:
|
||||||
return true, nil
|
return true, http.StatusInternalServerError, nil
|
||||||
case ErrorPolicyMatched:
|
case ErrorPolicyMatched:
|
||||||
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
||||||
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||||
return true, nil
|
return true, statusCode, nil
|
||||||
case ErrorPolicyTempUnscheduled:
|
case ErrorPolicyTempUnscheduled:
|
||||||
slog.Info("temp_unschedulable_matched",
|
slog.Info("temp_unschedulable_matched",
|
||||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||||
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, statusCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapAntigravityModel 获取映射后的模型名
|
// mapAntigravityModel 获取映射后的模型名
|
||||||
@@ -2242,6 +2243,10 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
|||||||
requestedModel string,
|
requestedModel string,
|
||||||
groupID int64, sessionHash string, isStickySession bool,
|
groupID int64, sessionHash string, isStickySession bool,
|
||||||
) *handleModelRateLimitResult {
|
) *handleModelRateLimitResult {
|
||||||
|
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||||
|
if !account.ShouldHandleErrorCode(statusCode) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
// 模型级限流处理(优先)
|
// 模型级限流处理(优先)
|
||||||
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
result := s.handleModelRateLimit(&handleModelRateLimitParams{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
|||||||
customCodes: []any{float64(500)},
|
customCodes: []any{float64(500)},
|
||||||
expectHandleError: 0,
|
expectHandleError: 0,
|
||||||
expectUpstream: 1,
|
expectUpstream: 1,
|
||||||
expectStatusCode: 429,
|
expectStatusCode: 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "500_in_custom_codes_matched",
|
name: "500_in_custom_codes_matched",
|
||||||
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
|||||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type epTrackingRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
rateLimitedCalls int
|
||||||
|
rateLimitedID int64
|
||||||
|
setErrCalls int
|
||||||
|
setErrID int64
|
||||||
|
tempCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||||
|
r.rateLimitedCalls++
|
||||||
|
r.rateLimitedID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||||
|
r.setErrCalls++
|
||||||
|
r.setErrID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||||
|
//
|
||||||
|
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||||
|
// 当上游返回 429/500/503/401/403 时:
|
||||||
|
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||||
|
// - 不调用 SetRateLimited(不进入限流状态)
|
||||||
|
// - 不调用 SetError(不停止调度)
|
||||||
|
// - 不调用 handleError
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||||
|
errorCodes := []int{429, 500, 503, 401, 403}
|
||||||
|
|
||||||
|
for _, upstreamStatus := range errorCodes {
|
||||||
|
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{
|
||||||
|
statusCode: upstreamStatus,
|
||||||
|
body: `{"error":"some upstream error"}`,
|
||||||
|
}
|
||||||
|
repo := &epTrackingRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 500,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(599)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
// 不应返回 error(Skipped 不触发账号切换)
|
||||||
|
require.NoError(t, err, "should not return error")
|
||||||
|
require.NotNil(t, result, "result should not be nil")
|
||||||
|
require.NotNil(t, result.resp, "response should not be nil")
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 状态码必须是 500(不透传原始状态码)
|
||||||
|
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||||
|
"skipped error should return 500, not %d", upstreamStatus)
|
||||||
|
|
||||||
|
// 不调用 handleError
|
||||||
|
require.Equal(t, 0, handleErrorCount,
|
||||||
|
"handleError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不标记限流
|
||||||
|
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||||
|
"SetRateLimited should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不停止调度
|
||||||
|
require.Equal(t, 0, repo.setErrCalls,
|
||||||
|
"SetError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 只调用一次上游(不重试)
|
||||||
|
require.Equal(t, 1, upstream.calls,
|
||||||
|
"should call upstream exactly once (no retry)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode int
|
statusCode int
|
||||||
body []byte
|
body []byte
|
||||||
expectedHandled bool
|
expectedHandled bool
|
||||||
|
expectedStatus int // expected outStatus
|
||||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||||
handleErrorCalls int
|
handleErrorCalls int
|
||||||
}{
|
}{
|
||||||
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: false,
|
expectedHandled: false,
|
||||||
|
expectedStatus: 500, // passthrough
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500, // not in custom codes
|
statusCode: 500, // not in custom codes
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 500, // matched → original status
|
||||||
handleErrorCalls: 1,
|
handleErrorCalls: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 503,
|
statusCode: 503,
|
||||||
body: []byte(`overloaded`),
|
body: []byte(`overloaded`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 503, // temp_unscheduled → original status
|
||||||
expectedSwitchErr: true,
|
expectedSwitchErr: true,
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||||
|
|
||||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||||
|
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||||
|
|
||||||
if tt.expectedSwitchErr {
|
if tt.expectedSwitchErr {
|
||||||
|
|||||||
@@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 错误策略优先:匹配则跳过重试直接处理。
|
||||||
|
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||||
|
resp = rebuilt
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
resp = rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
@@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if upstreamReqID == "" {
|
if upstreamReqID == "" {
|
||||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
}
|
}
|
||||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
@@ -1176,6 +1184,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 错误策略优先:匹配则跳过重试直接处理。
|
||||||
|
if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
|
||||||
|
resp = rebuilt
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
resp = rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
@@ -1283,7 +1299,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
c.Data(http.StatusInternalServerError, contentType, respBody)
|
||||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
@@ -1425,6 +1441,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkErrorPolicyInLoop 在重试循环内预检查错误策略。
|
||||||
|
// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。
|
||||||
|
// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。
|
||||||
|
func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop(
|
||||||
|
ctx context.Context, account *Account, resp *http.Response,
|
||||||
|
) (matched bool, rebuilt *http.Response) {
|
||||||
|
if resp.StatusCode < 400 || s.rateLimitService == nil {
|
||||||
|
return false, resp
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
rebuilt = &http.Response{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Header: resp.Header.Clone(),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body)
|
||||||
|
return policy != ErrorPolicyNone, rebuilt
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||||
switch statusCode {
|
switch statusCode {
|
||||||
case 429, 500, 502, 503, 504, 529:
|
case 429, 500, 502, 503, 504, 529:
|
||||||
@@ -2597,6 +2633,10 @@ func asInt(v any) (int, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||||
|
// 遵守自定义错误码策略:未命中则跳过所有限流处理
|
||||||
|
if !account.ShouldHandleErrorCode(statusCode) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user