feat: unified error policy for Antigravity + enable custom error codes for Gemini accounts
This commit is contained in:
289
backend/internal/service/error_policy_test.go
Normal file
289
backend/internal/service/error_policy_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCheckErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expected ErrorPolicyResult
|
||||
}{
|
||||
{
|
||||
name: "no_policy_oauth_returns_none",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
// no custom error codes, no temp rules
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_hit_returns_matched",
|
||||
account: &Account{
|
||||
ID: 2,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicyMatched,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_miss_returns_skipped",
|
||||
account: &Account{
|
||||
ID: 3,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429), float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`"error"`),
|
||||
expected: ErrorPolicySkipped,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
||||
account: &Account{
|
||||
ID: 4,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded service`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
account: &Account{
|
||||
ID: 5,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`random msg`),
|
||||
expected: ErrorPolicyNone,
|
||||
},
|
||||
{
|
||||
name: "custom_error_codes_override_temp_unschedulable",
|
||||
account: &Account{
|
||||
ID: 6,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(503)},
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
"description": "overloaded rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||||
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestApplyErrorPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
statusCode int
|
||||
body []byte
|
||||
expectedHandled bool
|
||||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||||
handleErrorCalls int
|
||||
}{
|
||||
{
|
||||
name: "none_not_handled",
|
||||
account: &Account{
|
||||
ID: 10,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: false,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "skipped_handled_no_handleError",
|
||||
account: &Account{
|
||||
ID: 11,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(429)},
|
||||
},
|
||||
},
|
||||
statusCode: 500, // not in custom codes
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "matched_handled_calls_handleError",
|
||||
account: &Account{
|
||||
ID: 12,
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"custom_error_codes_enabled": true,
|
||||
"custom_error_codes": []any{float64(500)},
|
||||
},
|
||||
},
|
||||
statusCode: 500,
|
||||
body: []byte(`"error"`),
|
||||
expectedHandled: true,
|
||||
handleErrorCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "temp_unscheduled_returns_switch_error",
|
||||
account: &Account{
|
||||
ID: 13,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(503),
|
||||
"keywords": []any{"overloaded"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 503,
|
||||
body: []byte(`overloaded`),
|
||||
expectedHandled: true,
|
||||
expectedSwitchErr: true,
|
||||
handleErrorCalls: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &errorPolicyRepoStub{}
|
||||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
svc := &AntigravityGatewayService{
|
||||
rateLimitService: rlSvc,
|
||||
}
|
||||
|
||||
var handleErrorCount int
|
||||
p := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: tt.account,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
handleErrorCount++
|
||||
return nil
|
||||
},
|
||||
isStickySession: true,
|
||||
}
|
||||
|
||||
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||||
|
||||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||||
|
||||
if tt.expectedSwitchErr {
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, retErr, &switchErr)
|
||||
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
||||
} else {
|
||||
require.NoError(t, retErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type errorPolicyRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
setErrCalls int
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrCalls++
|
||||
r.lastErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user