367 lines
12 KiB
Go
367 lines
12 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Mocks (scoped to this file by naming convention)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// epFixedUpstream returns a fixed response for every request.
|
|
type epFixedUpstream struct {
|
|
statusCode int
|
|
body string
|
|
calls int
|
|
}
|
|
|
|
func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
|
u.calls++
|
|
return &http.Response{
|
|
StatusCode: u.statusCode,
|
|
Header: http.Header{},
|
|
Body: io.NopCloser(strings.NewReader(u.body)),
|
|
}, nil
|
|
}
|
|
|
|
func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
|
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
|
}
|
|
|
|
// epAccountRepo records SetTempUnschedulable / SetError calls.
|
|
type epAccountRepo struct {
|
|
mockAccountRepoForGemini
|
|
tempCalls int
|
|
setErrCalls int
|
|
}
|
|
|
|
func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error {
|
|
r.tempCalls++
|
|
return nil
|
|
}
|
|
|
|
func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error {
|
|
r.setErrCalls++
|
|
return nil
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func saveAndSetBaseURLs(t *testing.T) {
|
|
t.Helper()
|
|
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
|
oldAvail := antigravity.DefaultURLAvailability
|
|
antigravity.BaseURLs = []string{"https://ep-test.example"}
|
|
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
|
t.Cleanup(func() {
|
|
antigravity.BaseURLs = oldBaseURLs
|
|
antigravity.DefaultURLAvailability = oldAvail
|
|
})
|
|
}
|
|
|
|
func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams {
|
|
return antigravityRetryLoopParams{
|
|
ctx: context.Background(),
|
|
prefix: "[ep-test]",
|
|
account: account,
|
|
accessToken: "token",
|
|
action: "generateContent",
|
|
body: []byte(`{"input":"test"}`),
|
|
httpUpstream: upstream,
|
|
requestedModel: "claude-sonnet-4-5",
|
|
handleError: handleError,
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestRetryLoop_ErrorPolicy_CustomErrorCodes
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
upstreamStatus int
|
|
upstreamBody string
|
|
customCodes []any
|
|
expectHandleError int
|
|
expectUpstream int
|
|
expectStatusCode int
|
|
}{
|
|
{
|
|
name: "429_in_custom_codes_matched",
|
|
upstreamStatus: 429,
|
|
upstreamBody: `{"error":"rate limited"}`,
|
|
customCodes: []any{float64(429)},
|
|
expectHandleError: 1,
|
|
expectUpstream: 1,
|
|
expectStatusCode: 429,
|
|
},
|
|
{
|
|
name: "429_not_in_custom_codes_skipped",
|
|
upstreamStatus: 429,
|
|
upstreamBody: `{"error":"rate limited"}`,
|
|
customCodes: []any{float64(500)},
|
|
expectHandleError: 0,
|
|
expectUpstream: 1,
|
|
expectStatusCode: 429,
|
|
},
|
|
{
|
|
name: "500_in_custom_codes_matched",
|
|
upstreamStatus: 500,
|
|
upstreamBody: `{"error":"internal"}`,
|
|
customCodes: []any{float64(500)},
|
|
expectHandleError: 1,
|
|
expectUpstream: 1,
|
|
expectStatusCode: 500,
|
|
},
|
|
{
|
|
name: "500_not_in_custom_codes_skipped",
|
|
upstreamStatus: 500,
|
|
upstreamBody: `{"error":"internal"}`,
|
|
customCodes: []any{float64(429)},
|
|
expectHandleError: 0,
|
|
expectUpstream: 1,
|
|
expectStatusCode: 500,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody}
|
|
repo := &epAccountRepo{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
|
|
account := &Account{
|
|
ID: 100,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Schedulable: true,
|
|
Status: StatusActive,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": tt.customCodes,
|
|
},
|
|
}
|
|
|
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
|
|
|
var handleErrorCount int
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
handleErrorCount++
|
|
return nil
|
|
})
|
|
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.resp)
|
|
defer func() { _ = result.resp.Body.Close() }()
|
|
|
|
require.Equal(t, tt.expectStatusCode, result.resp.StatusCode)
|
|
require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count")
|
|
require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count")
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestRetryLoop_ErrorPolicy_TempUnschedulable
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) {
|
|
tempRulesAccount := func(rules []any) *Account {
|
|
return &Account{
|
|
ID: 200,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Schedulable: true,
|
|
Status: StatusActive,
|
|
Concurrency: 1,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": rules,
|
|
},
|
|
}
|
|
}
|
|
|
|
overloadedRule := map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
}
|
|
|
|
rateLimitRule := map[string]any{
|
|
"error_code": float64(429),
|
|
"keywords": []any{"rate limited keyword"},
|
|
"duration_minutes": float64(5),
|
|
}
|
|
|
|
t.Run("503_overloaded_matches_rule", func(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`}
|
|
repo := &epAccountRepo{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
|
|
|
account := tempRulesAccount([]any{overloadedRule})
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
t.Error("handleError should not be called for temp unschedulable")
|
|
return nil
|
|
})
|
|
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
require.Nil(t, result)
|
|
var switchErr *AntigravityAccountSwitchError
|
|
require.ErrorAs(t, err, &switchErr)
|
|
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
|
require.Equal(t, 1, upstream.calls, "should not retry")
|
|
})
|
|
|
|
t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`}
|
|
repo := &epAccountRepo{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
|
|
|
account := tempRulesAccount([]any{rateLimitRule})
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
t.Error("handleError should not be called for temp unschedulable")
|
|
return nil
|
|
})
|
|
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
require.Nil(t, result)
|
|
var switchErr *AntigravityAccountSwitchError
|
|
require.ErrorAs(t, err, &switchErr)
|
|
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
|
require.Equal(t, 1, upstream.calls, "should not retry")
|
|
})
|
|
|
|
t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: 503, body: `random`}
|
|
repo := &epAccountRepo{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
|
|
|
account := tempRulesAccount([]any{overloadedRule})
|
|
|
|
// Use a short-lived context: the backoff sleep (~1s) will be
|
|
// interrupted, proving the code entered the default retry path
|
|
// instead of breaking early via error policy.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
return nil
|
|
})
|
|
p.ctx = ctx
|
|
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
// Context cancellation during backoff proves default retry was entered
|
|
require.Nil(t, result)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once")
|
|
})
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestRetryLoop_ErrorPolicy_NilRateLimitService
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
|
// rateLimitService is nil — must not panic
|
|
svc := &AntigravityGatewayService{rateLimitService: nil}
|
|
|
|
account := &Account{
|
|
ID: 300,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Schedulable: true,
|
|
Status: StatusActive,
|
|
Concurrency: 1,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
return nil
|
|
})
|
|
p.ctx = ctx
|
|
|
|
// Should not panic; enters the default retry path (eventually times out)
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
require.Nil(t, result)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.GreaterOrEqual(t, upstream.calls, 1)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) {
|
|
saveAndSetBaseURLs(t)
|
|
|
|
upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`}
|
|
repo := &epAccountRepo{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{rateLimitService: rlSvc}
|
|
|
|
// Plain OAuth account with no error policy configured
|
|
account := &Account{
|
|
ID: 400,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Schedulable: true,
|
|
Status: StatusActive,
|
|
Concurrency: 1,
|
|
}
|
|
|
|
var handleErrorCount int
|
|
p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult {
|
|
handleErrorCount++
|
|
return nil
|
|
})
|
|
|
|
result, err := svc.antigravityRetryLoop(p)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, result)
|
|
require.NotNil(t, result.resp)
|
|
defer func() { _ = result.resp.Body.Close() }()
|
|
|
|
require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode)
|
|
require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries")
|
|
require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted")
|
|
}
|