290 lines
8.0 KiB
Go
290 lines
8.0 KiB
Go
//go:build unit
|
|
|
|
package service
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestCheckErrorPolicy(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
account *Account
|
|
statusCode int
|
|
body []byte
|
|
expected ErrorPolicyResult
|
|
}{
|
|
{
|
|
name: "no_policy_oauth_returns_none",
|
|
account: &Account{
|
|
ID: 1,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
// no custom error codes, no temp rules
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicyNone,
|
|
},
|
|
{
|
|
name: "custom_error_codes_hit_returns_matched",
|
|
account: &Account{
|
|
ID: 2,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429), float64(500)},
|
|
},
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicyMatched,
|
|
},
|
|
{
|
|
name: "custom_error_codes_miss_returns_skipped",
|
|
account: &Account{
|
|
ID: 3,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429), float64(500)},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`"error"`),
|
|
expected: ErrorPolicySkipped,
|
|
},
|
|
{
|
|
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
|
account: &Account{
|
|
ID: 4,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded service`),
|
|
expected: ErrorPolicyTempUnscheduled,
|
|
},
|
|
{
|
|
name: "temp_unschedulable_body_miss_returns_none",
|
|
account: &Account{
|
|
ID: 5,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`random msg`),
|
|
expected: ErrorPolicyNone,
|
|
},
|
|
{
|
|
name: "custom_error_codes_override_temp_unschedulable",
|
|
account: &Account{
|
|
ID: 6,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(503)},
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
"description": "overloaded rule",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded`),
|
|
expected: ErrorPolicyMatched, // custom codes take precedence
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
repo := &errorPolicyRepoStub{}
|
|
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
|
|
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
|
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestApplyErrorPolicy(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
account *Account
|
|
statusCode int
|
|
body []byte
|
|
expectedHandled bool
|
|
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
|
handleErrorCalls int
|
|
}{
|
|
{
|
|
name: "none_not_handled",
|
|
account: &Account{
|
|
ID: 10,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expectedHandled: false,
|
|
handleErrorCalls: 0,
|
|
},
|
|
{
|
|
name: "skipped_handled_no_handleError",
|
|
account: &Account{
|
|
ID: 11,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(429)},
|
|
},
|
|
},
|
|
statusCode: 500, // not in custom codes
|
|
body: []byte(`"error"`),
|
|
expectedHandled: true,
|
|
handleErrorCalls: 0,
|
|
},
|
|
{
|
|
name: "matched_handled_calls_handleError",
|
|
account: &Account{
|
|
ID: 12,
|
|
Type: AccountTypeAPIKey,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"custom_error_codes_enabled": true,
|
|
"custom_error_codes": []any{float64(500)},
|
|
},
|
|
},
|
|
statusCode: 500,
|
|
body: []byte(`"error"`),
|
|
expectedHandled: true,
|
|
handleErrorCalls: 1,
|
|
},
|
|
{
|
|
name: "temp_unscheduled_returns_switch_error",
|
|
account: &Account{
|
|
ID: 13,
|
|
Type: AccountTypeOAuth,
|
|
Platform: PlatformAntigravity,
|
|
Credentials: map[string]any{
|
|
"temp_unschedulable_enabled": true,
|
|
"temp_unschedulable_rules": []any{
|
|
map[string]any{
|
|
"error_code": float64(503),
|
|
"keywords": []any{"overloaded"},
|
|
"duration_minutes": float64(10),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
statusCode: 503,
|
|
body: []byte(`overloaded`),
|
|
expectedHandled: true,
|
|
expectedSwitchErr: true,
|
|
handleErrorCalls: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
repo := &errorPolicyRepoStub{}
|
|
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
|
svc := &AntigravityGatewayService{
|
|
rateLimitService: rlSvc,
|
|
}
|
|
|
|
var handleErrorCount int
|
|
p := antigravityRetryLoopParams{
|
|
ctx: context.Background(),
|
|
prefix: "[test]",
|
|
account: tt.account,
|
|
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
|
handleErrorCount++
|
|
return nil
|
|
},
|
|
isStickySession: true,
|
|
}
|
|
|
|
handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
|
|
|
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
|
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
|
|
|
if tt.expectedSwitchErr {
|
|
var switchErr *AntigravityAccountSwitchError
|
|
require.ErrorAs(t, retErr, &switchErr)
|
|
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
|
} else {
|
|
require.NoError(t, retErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
type errorPolicyRepoStub struct {
|
|
mockAccountRepoForGemini
|
|
tempCalls int
|
|
setErrCalls int
|
|
lastErrorMsg string
|
|
}
|
|
|
|
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
|
r.tempCalls++
|
|
return nil
|
|
}
|
|
|
|
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
|
r.setErrCalls++
|
|
r.lastErrorMsg = errorMsg
|
|
return nil
|
|
}
|