feat: ErrorPolicySkipped returns 500 instead of upstream status code
When custom error codes are enabled and the upstream error code is NOT in the configured list, return HTTP 500 to the client instead of transparently forwarding the original status code. Also adds integration test TestCustomErrorCode599 verifying that 429, 500, 503, 401, 403 all return 500 without triggering SetRateLimited or SetError.
This commit is contained in:
@@ -371,12 +371,12 @@ urlFallbackLoop:
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
// ★ 统一入口:自定义错误码 + 临时不可调度
|
// ★ 统一入口:自定义错误码 + 临时不可调度
|
||||||
if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled {
|
||||||
if policyErr != nil {
|
if policyErr != nil {
|
||||||
return nil, policyErr
|
return nil, policyErr
|
||||||
}
|
}
|
||||||
resp = &http.Response{
|
resp = &http.Response{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: outStatus,
|
||||||
Header: resp.Header.Clone(),
|
Header: resp.Header.Clone(),
|
||||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||||
}
|
}
|
||||||
@@ -610,21 +610,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun
|
|||||||
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环
|
// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。
|
||||||
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) {
|
// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。
|
||||||
|
func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) {
|
||||||
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) {
|
||||||
case ErrorPolicySkipped:
|
case ErrorPolicySkipped:
|
||||||
return true, nil
|
return true, http.StatusInternalServerError, nil
|
||||||
case ErrorPolicyMatched:
|
case ErrorPolicyMatched:
|
||||||
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
_ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody,
|
||||||
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
p.requestedModel, p.groupID, p.sessionHash, p.isStickySession)
|
||||||
return true, nil
|
return true, statusCode, nil
|
||||||
case ErrorPolicyTempUnscheduled:
|
case ErrorPolicyTempUnscheduled:
|
||||||
slog.Info("temp_unschedulable_matched",
|
slog.Info("temp_unschedulable_matched",
|
||||||
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
"prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID)
|
||||||
return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession}
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, statusCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapAntigravityModel 获取映射后的模型名
|
// mapAntigravityModel 获取映射后的模型名
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
|||||||
customCodes: []any{float64(500)},
|
customCodes: []any{float64(500)},
|
||||||
expectHandleError: 0,
|
expectHandleError: 0,
|
||||||
expectUpstream: 1,
|
expectUpstream: 1,
|
||||||
expectStatusCode: 429,
|
expectStatusCode: 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "500_in_custom_codes_matched",
|
name: "500_in_custom_codes_matched",
|
||||||
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
|||||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type epTrackingRepo struct {
|
||||||
|
mockAccountRepoForGemini
|
||||||
|
rateLimitedCalls int
|
||||||
|
rateLimitedID int64
|
||||||
|
setErrCalls int
|
||||||
|
setErrID int64
|
||||||
|
tempCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||||
|
r.rateLimitedCalls++
|
||||||
|
r.rateLimitedID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||||
|
r.setErrCalls++
|
||||||
|
r.setErrID = id
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||||
|
r.tempCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||||
|
//
|
||||||
|
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||||
|
// 当上游返回 429/500/503/401 时:
|
||||||
|
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||||
|
// - 不调用 SetRateLimited(不进入限流状态)
|
||||||
|
// - 不调用 SetError(不停止调度)
|
||||||
|
// - 不调用 handleError
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||||
|
errorCodes := []int{429, 500, 503, 401, 403}
|
||||||
|
|
||||||
|
for _, upstreamStatus := range errorCodes {
|
||||||
|
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||||
|
saveAndSetBaseURLs(t)
|
||||||
|
|
||||||
|
upstream := &epFixedUpstream{
|
||||||
|
statusCode: upstreamStatus,
|
||||||
|
body: `{"error":"some upstream error"}`,
|
||||||
|
}
|
||||||
|
repo := &epTrackingRepo{}
|
||||||
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||||
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 500,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Schedulable: true,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"custom_error_codes_enabled": true,
|
||||||
|
"custom_error_codes": []any{float64(599)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var handleErrorCount int
|
||||||
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||||
|
handleErrorCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := svc.antigravityRetryLoop(p)
|
||||||
|
|
||||||
|
// 不应返回 error(Skipped 不触发账号切换)
|
||||||
|
require.NoError(t, err, "should not return error")
|
||||||
|
require.NotNil(t, result, "result should not be nil")
|
||||||
|
require.NotNil(t, result.resp, "response should not be nil")
|
||||||
|
defer func() { _ = result.resp.Body.Close() }()
|
||||||
|
|
||||||
|
// 状态码必须是 500(不透传原始状态码)
|
||||||
|
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||||
|
"skipped error should return 500, not %d", upstreamStatus)
|
||||||
|
|
||||||
|
// 不调用 handleError
|
||||||
|
require.Equal(t, 0, handleErrorCount,
|
||||||
|
"handleError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不标记限流
|
||||||
|
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||||
|
"SetRateLimited should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 不停止调度
|
||||||
|
require.Equal(t, 0, repo.setErrCalls,
|
||||||
|
"SetError should NOT be called for skipped errors")
|
||||||
|
|
||||||
|
// 只调用一次上游(不重试)
|
||||||
|
require.Equal(t, 1, upstream.calls,
|
||||||
|
"should call upstream exactly once (no retry)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode int
|
statusCode int
|
||||||
body []byte
|
body []byte
|
||||||
expectedHandled bool
|
expectedHandled bool
|
||||||
|
expectedStatus int // expected outStatus
|
||||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||||
handleErrorCalls int
|
handleErrorCalls int
|
||||||
}{
|
}{
|
||||||
@@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: false,
|
expectedHandled: false,
|
||||||
|
expectedStatus: 500, // passthrough
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500, // not in custom codes
|
statusCode: 500, // not in custom codes
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
body: []byte(`"error"`),
|
body: []byte(`"error"`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 500, // matched → original status
|
||||||
handleErrorCalls: 1,
|
handleErrorCalls: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
statusCode: 503,
|
statusCode: 503,
|
||||||
body: []byte(`overloaded`),
|
body: []byte(`overloaded`),
|
||||||
expectedHandled: true,
|
expectedHandled: true,
|
||||||
|
expectedStatus: 503, // temp_unscheduled → original status
|
||||||
expectedSwitchErr: true,
|
expectedSwitchErr: true,
|
||||||
handleErrorCalls: 0,
|
handleErrorCalls: 0,
|
||||||
},
|
},
|
||||||
@@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) {
|
|||||||
isStickySession: true,
|
isStickySession: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||||
|
|
||||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||||
|
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||||
|
|
||||||
if tt.expectedSwitchErr {
|
if tt.expectedSwitchErr {
|
||||||
|
|||||||
@@ -839,7 +839,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if upstreamReqID == "" {
|
if upstreamReqID == "" {
|
||||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||||
}
|
}
|
||||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
return nil, s.writeGeminiMappedError(c, account, http.StatusInternalServerError, upstreamReqID, respBody)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||||
@@ -1283,7 +1283,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
contentType = "application/json"
|
contentType = "application/json"
|
||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
c.Data(http.StatusInternalServerError, contentType, respBody)
|
||||||
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode)
|
||||||
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
case ErrorPolicyMatched, ErrorPolicyTempUnscheduled:
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
|||||||
Reference in New Issue
Block a user