feat: ErrorPolicySkipped returns 500 instead of upstream status code
When custom error codes are enabled and the upstream error code is NOT in the configured list, return HTTP 500 to the client instead of transparently forwarding the original status code. Also adds integration test TestCustomErrorCode599 verifying that 429, 500, 503, 401, 403 all return 500 without triggering SetRateLimited or SetError.
This commit is contained in:
@@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_in_custom_codes_matched",
|
||||
@@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// epTrackingRepo — records SetRateLimited / SetError calls for verification.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type epTrackingRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
rateLimitedCalls int
|
||||
rateLimitedID int64
|
||||
setErrCalls int
|
||||
setErrID int64
|
||||
tempCalls int
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error {
|
||||
r.rateLimitedCalls++
|
||||
r.rateLimitedID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
r.setErrID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit
|
||||
//
|
||||
// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码),
|
||||
// 当上游返回 429/500/503/401 时:
|
||||
// - 返回给客户端的状态码必须是 500(而不是透传原始状态码)
|
||||
// - 不调用 SetRateLimited(不进入限流状态)
|
||||
// - 不调用 SetError(不停止调度)
|
||||
// - 不调用 handleError
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) {
|
||||
errorCodes := []int{429, 500, 503, 401, 403}
|
||||
|
||||
for _, upstreamStatus := range errorCodes {
|
||||
t.Run(http.StatusText(upstreamStatus), func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{
|
||||
statusCode: upstreamStatus,
|
||||
body: `{"error":"some upstream error"}`,
|
||||
}
|
||||
repo := &epTrackingRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := &Account{
|
||||
ID: 500,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(599)},
|
||||
},
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// 不应返回 error(Skipped 不触发账号切换)
|
||||
require.NoError(t, err, "should not return error")
|
||||
require.NotNil(t, result, "result should not be nil")
|
||||
require.NotNil(t, result.resp, "response should not be nil")
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
// 状态码必须是 500(不透传原始状态码)
|
||||
require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode,
|
||||
"skipped error should return 500, not %d", upstreamStatus)
|
||||
|
||||
// 不调用 handleError
|
||||
require.Equal(t, 0, handleErrorCount,
|
||||
"handleError should NOT be called for skipped errors")
|
||||
|
||||
// 不标记限流
|
||||
require.Equal(t, 0, repo.rateLimitedCalls,
|
||||
"SetRateLimited should NOT be called for skipped errors")
|
||||
|
||||
// 不停止调度
|
||||
require.Equal(t, 0, repo.setErrCalls,
|
||||
"SetError should NOT be called for skipped errors")
|
||||
|
||||
// 只调用一次上游(不重试)
|
||||
require.Equal(t, 1, upstream.calls,
|
||||
"should call upstream exactly once (no retry)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user