When custom error codes are enabled and the upstream status code is NOT in the configured list, return HTTP 500 to the client instead of transparently forwarding the original status code. Also adds an integration test, TestCustomErrorCode599, verifying that upstream 429, 500, 503, 401, and 403 responses all return 500 without triggering SetRateLimited or SetError.
296 lines · 8.3 KiB · Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestCheckErrorPolicy(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
account *Account
|
|
statusCode int
|
|
body []byte
|
|
expected ErrorPolicyResult
|
|
}{
|
|
{
|
|
name: "no_policy_oauth_returns_none",
|
|
account: &Account{
|
|
ID: 1,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
// no custom error codes, no temp rules
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicyNone,
|
|
},
|
|
{
|
|
name: "custom_error_codes_hit_returns_matched",
|
|
account: &Account{
|
|
ID: 2,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429), float64(500)},
|
|
},
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicyMatched,
|
|
},
|
|
{
|
|
name: "custom_error_codes_miss_returns_skipped",
|
|
account: &Account{
|
|
ID: 3,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429), float64(500)},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicySkipped,
|
|
},
|
|
{
|
|
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
|
account: &Account{
|
|
ID: 4,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded service`),
|
|
expected: ErrorPolicyTempUnscheduled,
|
|
},
|
|
{
|
|
name: "temp_unschedulable_body_miss_returns_none",
|
|
account: &Account{
|
|
ID: 5,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`random msg`),
|
|
expected: ErrorPolicyNone,
|
|
},
|
|
{
|
|
name: "custom_error_codes_override_temp_unschedulable",
|
|
account: &Account{
|
|
ID: 6,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(503)},
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded`),
|
|
expected: ErrorPolicyMatched, // custom codes take precedence
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
repo := &errorPolicyRepoStub{}
|
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
|
|
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
|
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestApplyErrorPolicy(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
account *Account
|
|
statusCode int
|
|
body []byte
|
|
expectedHandled bool
|
|
expectedStatus int // expected outStatus
|
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
|
handleErrorCalls int
|
|
}{
|
|
{
|
|
name: "none_not_handled",
|
|
account: &Account{
|
|
ID: 10,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expectedHandled: false,
|
|
expectedStatus: 500, // passthrough
|
|
handleErrorCalls: 0,
|
|
},
|
|
{
|
|
name: "skipped_handled_no_handleError",
|
|
account: &Account{
|
|
ID: 11,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429)},
|
|
},
|
|
},
|
|
statusCode: 500, // not in custom codes
|
|
body: []byte(`"error"`),
|
|
expectedHandled: true,
|
|
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
|
handleErrorCalls: 0,
|
|
},
|
|
{
|
|
name: "matched_handled_calls_handleError",
|
|
account: &Account{
|
|
ID: 12,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(500)},
|
|
},
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expectedHandled: true,
|
|
expectedStatus: 500, // matched → original status
|
|
handleErrorCalls: 1,
|
|
},
|
|
{
|
|
name: "temp_unscheduled_returns_switch_error",
|
|
account: &Account{
|
|
ID: 13,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded`),
|
|
expectedHandled: true,
|
|
expectedStatus: 503, // temp_unscheduled → original status
|
|
expectedSwitchErr: true,
|
|
handleErrorCalls: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
repo := &errorPolicyRepoStub{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{
|
|
rateLimitService: rlSvc,
|
|
}
|
|
|
|
var handleErrorCount int
|
|
p := antigravityRetryLoopParams{
|
|
ctx: context.Background(),
|
|
prefix: "[test]",
|
|
account: tt.account,
|
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
|
handleErrorCount++
|
|
return nil
|
|
},
|
|
isStickySession: true,
|
|
}
|
|
|
|
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
|
|
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
|
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
|
|
|
if tt.expectedSwitchErr {
|
|
var switchErr *AntigravityAccountSwitchError
|
|
require.ErrorAs(t, retErr, &switchErr)
|
|
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
|
} else {
|
|
require.NoError(t, retErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
type errorPolicyRepoStub struct {
|
|
mockAccountRepoForGemini
|
|
tempCalls int
|
|
setErrCalls int
|
|
lastErrorMsg string
|
|
}
|
|
|
|
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
r.tempCalls++
|
|
return nil
|
|
}
|
|
|
|
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
|
r.setErrCalls++
|
|
r.lastErrorMsg = errorMsg
|
|
return nil
|
|
}
|