//go:build unit package service import ( "context" "net/http" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // TestCheckErrorPolicy — 6 table-driven cases for the pure logic function // --------------------------------------------------------------------------- func TestCheckErrorPolicy(t *testing.T) { tests := []struct { name string account *Account statusCode int body []byte expected ErrorPolicyResult }{ { name: "no_policy_oauth_returns_none", account: &Account{ ID: 1, Type: AccountTypeOAuth, Platform: PlatformAntigravity, // no custom error codes, no temp rules }, statusCode: 500, body: []byte(`"error"`), expected: ErrorPolicyNone, }, { name: "custom_error_codes_hit_returns_matched", account: &Account{ ID: 2, Type: AccountTypeAPIKey, Platform: PlatformAntigravity, Credentials: map[string]any{ "custom_error_codes_enabled": true, "custom_error_codes": []any{float64(429), float64(500)}, }, }, statusCode: 500, body: []byte(`"error"`), expected: ErrorPolicyMatched, }, { name: "custom_error_codes_miss_returns_skipped", account: &Account{ ID: 3, Type: AccountTypeAPIKey, Platform: PlatformAntigravity, Credentials: map[string]any{ "custom_error_codes_enabled": true, "custom_error_codes": []any{float64(429), float64(500)}, }, }, statusCode: 503, body: []byte(`"error"`), expected: ErrorPolicySkipped, }, { name: "temp_unschedulable_hit_returns_temp_unscheduled", account: &Account{ ID: 4, Type: AccountTypeOAuth, Platform: PlatformAntigravity, Credentials: map[string]any{ "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": float64(503), "keywords": []any{"overloaded"}, "duration_minutes": float64(10), "description": "overloaded rule", }, }, }, }, statusCode: 503, body: []byte(`overloaded service`), expected: ErrorPolicyTempUnscheduled, }, { name: "temp_unschedulable_body_miss_returns_none", account: &Account{ ID: 5, Type: AccountTypeOAuth, Platform: PlatformAntigravity, Credentials: map[string]any{ "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": float64(503), "keywords": []any{"overloaded"}, "duration_minutes": float64(10), "description": "overloaded rule", }, }, }, }, statusCode: 503, body: []byte(`random msg`), expected: ErrorPolicyNone, }, { name: "custom_error_codes_override_temp_unschedulable", account: &Account{ ID: 6, Type: AccountTypeAPIKey, Platform: PlatformAntigravity, Credentials: map[string]any{ "custom_error_codes_enabled": true, "custom_error_codes": []any{float64(503)}, "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": float64(503), "keywords": []any{"overloaded"}, "duration_minutes": float64(10), "description": "overloaded rule", }, }, }, }, statusCode: 503, body: []byte(`overloaded`), expected: ErrorPolicyMatched, // custom codes take precedence }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo := &errorPolicyRepoStub{} svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult") }) } } // --------------------------------------------------------------------------- // TestApplyErrorPolicy — 4 table-driven cases for the wrapper method // --------------------------------------------------------------------------- func TestApplyErrorPolicy(t *testing.T) { tests := []struct { name string account *Account statusCode int body []byte expectedHandled bool expectedSwitchErr bool // expect *AntigravityAccountSwitchError handleErrorCalls int }{ { name: "none_not_handled", account: &Account{ ID: 10, Type: AccountTypeOAuth, Platform: PlatformAntigravity, }, statusCode: 500, body: []byte(`"error"`), expectedHandled: false, handleErrorCalls: 0, }, { name: "skipped_handled_no_handleError", account: &Account{ ID: 11, Type: AccountTypeAPIKey, Platform: PlatformAntigravity, Credentials: map[string]any{ "custom_error_codes_enabled": true, "custom_error_codes": []any{float64(429)}, }, }, statusCode: 500, // not in custom codes body: []byte(`"error"`), expectedHandled: true, handleErrorCalls: 0, }, { name: "matched_handled_calls_handleError", account: &Account{ ID: 12, Type: AccountTypeAPIKey, Platform: PlatformAntigravity, Credentials: map[string]any{ "custom_error_codes_enabled": true, "custom_error_codes": []any{float64(500)}, }, }, statusCode: 500, body: []byte(`"error"`), expectedHandled: true, handleErrorCalls: 1, }, { name: "temp_unscheduled_returns_switch_error", account: &Account{ ID: 13, Type: AccountTypeOAuth, Platform: PlatformAntigravity, Credentials: map[string]any{ "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": float64(503), "keywords": []any{"overloaded"}, "duration_minutes": float64(10), }, }, }, }, statusCode: 503, body: []byte(`overloaded`), expectedHandled: true, expectedSwitchErr: true, handleErrorCalls: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo := &errorPolicyRepoStub{} rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) svc := &AntigravityGatewayService{ rateLimitService: rlSvc, } var handleErrorCount int p := antigravityRetryLoopParams{ ctx: context.Background(), prefix: "[test]", account: tt.account, handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleErrorCount++ return nil }, isStickySession: true, } handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) require.Equal(t, tt.expectedHandled, handled, "handled mismatch") require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") if tt.expectedSwitchErr { var switchErr *AntigravityAccountSwitchError require.ErrorAs(t, retErr, &switchErr) require.Equal(t, tt.account.ID, switchErr.OriginalAccountID) } else { require.NoError(t, retErr) } }) } } // --------------------------------------------------------------------------- // errorPolicyRepoStub — minimal AccountRepository stub for error policy tests // --------------------------------------------------------------------------- type errorPolicyRepoStub struct { mockAccountRepoForGemini tempCalls int setErrCalls int lastErrorMsg string } func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ return nil } func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrCalls++ r.lastErrorMsg = errorMsg return nil }