feat: unified error policy for Antigravity + enable custom error codes for Gemini accounts
This commit is contained in:
366
backend/internal/service/error_policy_integration_test.go
Normal file
366
backend/internal/service/error_policy_integration_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mocks (scoped to this file by naming convention)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// epFixedUpstream returns a fixed response for every request.
|
||||
type epFixedUpstream struct {
|
||||
statusCode int
|
||||
body string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.calls++
|
||||
return &http.Response{
|
||||
StatusCode: u.statusCode,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(u.body)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// epAccountRepo records SetTempUnschedulable / SetError calls.
|
||||
type epAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
|
||||
r.setErrCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func saveAndSetBaseURLs(t *testing.T) {
|
||||
t.Helper()
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvail := antigravity.DefaultURLAvailability
|
||||
antigravity.BaseURLs = []string{"https://ep-test.example"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
t.Cleanup(func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvail
|
||||
})
|
||||
}
|
||||
|
||||
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
|
||||
return antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[ep-test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: handleError,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
upstreamStatus int
|
||||
upstreamBody string
|
||||
customCodes []any
|
||||
expectHandleError int
|
||||
expectUpstream int
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "429_in_custom_codes_matched",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "429_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 429,
|
||||
upstreamBody: `{"error":"rate limited"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "500_in_custom_codes_matched",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(500)},
|
||||
expectHandleError: 1,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "500_not_in_custom_codes_skipped",
|
||||
upstreamStatus: 500,
|
||||
upstreamBody: `{"error":"internal"}`,
|
||||
customCodes: []any{float64(429)},
|
||||
expectHandleError: 0,
|
||||
expectUpstream: 1,
|
||||
expectStatusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": tt.customCodes,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
|
||||
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
|
||||
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_TempUnschedulable
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
|
||||
tempRulesAccount := func(rules []any) *Account {
|
||||
return &Account{
|
||||
ID: 200,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": rules,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
overloadedRule := map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
}
|
||||
|
||||
rateLimitRule := map[string]any{
|
||||
"error_code": float64(429),
|
||||
"keywords": []any{"rate limited keyword"},
|
||||
"duration_minutes": float64(5),
|
||||
}
|
||||
|
||||
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{rateLimitRule})
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
t.Error("handleError should not be called for temp unschedulable")
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, 1, upstream.calls, "should not retry")
|
||||
})
|
||||
|
||||
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
account := tempRulesAccount([]any{overloadedRule})
|
||||
|
||||
// Use a short-lived context: the backoff sleep (~1s) will be
|
||||
// interrupted, proving the code entered the default retry path
|
||||
// instead of breaking early via error policy.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
// Context cancellation during backoff proves default retry was entered
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NilRateLimitService
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
// rateLimitService is nil — must not panic
|
||||
svc := &AntigravityGatewayService{rateLimitService: nil}
|
||||
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
})
|
||||
p.ctx = ctx
|
||||
|
||||
// Should not panic; enters the default retry path (eventually times out)
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.Nil(t, result)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.GreaterOrEqual(t, upstream.calls, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
||||
saveAndSetBaseURLs(t)
|
||||
|
||||
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
||||
repo := &epAccountRepo{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
||||
|
||||
// Plain OAuth account with no error policy configured
|
||||
account := &Account{
|
||||
ID: 400,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
})
|
||||
|
||||
result, err := svc.antigravityRetryLoop(p)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
||||
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
||||
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
||||
}
|
||||
Reference in New Issue
Block a user