- Extract PublicSettingsInjectionPayload named struct with drift test - Add channel_monitor_default_interval_seconds to SSR injection - Add image_output_price to SupportedModelChip - Simplify AppSidebar buildSelfNavItems (admins see available channels) - Add gateway WARN logs for 503 no-available-accounts branches - Wire ChannelMonitorRunner into provideCleanup for graceful shutdown - Add migrations 130/131 (CC template userid fix + mimicry field cleanup) - Clean up fork-only features (sora, claude max simulation, client affinity) - Remove ~320 obsolete i18n keys - Add codexUsage utility, WechatServiceButton, BulkEditAccountModal - Tidy go.sum
434 lines
13 KiB
Go
434 lines
13 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"net/http"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy — table-driven cases for the pure logic function
// ---------------------------------------------------------------------------
|
||
|
||
func TestCheckErrorPolicy(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
account *Account
|
||
statusCode int
|
||
body []byte
|
||
expected ErrorPolicyResult
|
||
}{
|
||
{
|
||
name: "no_policy_oauth_returns_none",
|
||
account: &Account{
|
||
ID: 1,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
// no custom error codes, no temp rules
|
||
},
|
||
statusCode: 500,
|
||
body: []byte(`"error"`),
|
||
expected: ErrorPolicyNone,
|
||
},
|
||
{
|
||
name: "custom_error_codes_hit_returns_matched",
|
||
account: &Account{
|
||
ID: 2,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(429), float64(500)},
|
||
},
|
||
},
|
||
statusCode: 500,
|
||
body: []byte(`"error"`),
|
||
expected: ErrorPolicyMatched,
|
||
},
|
||
{
|
||
name: "custom_error_codes_miss_returns_skipped",
|
||
account: &Account{
|
||
ID: 3,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(429), float64(500)},
|
||
},
|
||
},
|
||
statusCode: 503,
|
||
body: []byte(`"error"`),
|
||
expected: ErrorPolicySkipped,
|
||
},
|
||
{
|
||
name: "temp_unschedulable_hit_returns_temp_unscheduled",
|
||
account: &Account{
|
||
ID: 4,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(503),
|
||
"keywords": []any{"overloaded"},
|
||
"duration_minutes": float64(10),
|
||
"description": "overloaded rule",
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 503,
|
||
body: []byte(`overloaded service`),
|
||
expected: ErrorPolicyTempUnscheduled,
|
||
},
|
||
{
|
||
name: "temp_unschedulable_401_first_hit_returns_temp_unscheduled",
|
||
account: &Account{
|
||
ID: 14,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(401),
|
||
"keywords": []any{"unauthorized"},
|
||
"duration_minutes": float64(10),
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 401,
|
||
body: []byte(`unauthorized`),
|
||
expected: ErrorPolicyTempUnscheduled,
|
||
},
|
||
{
|
||
// Gemini OAuth 401 second hit 会升级为 error(返回 None,交由默认错误逻辑处理)。
|
||
name: "temp_unschedulable_401_second_hit_gemini_escalates",
|
||
account: &Account{
|
||
ID: 15,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformGemini, // 非 Antigravity 平台 401 second hit 升级
|
||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(401),
|
||
"keywords": []any{"unauthorized"},
|
||
"duration_minutes": float64(10),
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 401,
|
||
body: []byte(`unauthorized`),
|
||
expected: ErrorPolicyNone, // Gemini 401 second hit 升级为 error
|
||
},
|
||
{
|
||
name: "temp_unschedulable_401_antigravity_no_escalation",
|
||
account: &Account{
|
||
ID: 16,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity, // Antigravity 跳过 401 升级,由 rules 正常处理
|
||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(401),
|
||
"keywords": []any{"unauthorized"},
|
||
"duration_minutes": float64(10),
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 401,
|
||
body: []byte(`unauthorized`),
|
||
expected: ErrorPolicyTempUnscheduled, // Antigravity 不升级,继续走规则匹配
|
||
},
|
||
{
|
||
name: "temp_unschedulable_body_miss_returns_none",
|
||
account: &Account{
|
||
ID: 5,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(503),
|
||
"keywords": []any{"overloaded"},
|
||
"duration_minutes": float64(10),
|
||
"description": "overloaded rule",
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 503,
|
||
body: []byte(`random msg`),
|
||
expected: ErrorPolicyNone,
|
||
},
|
||
{
|
||
name: "custom_error_codes_override_temp_unschedulable",
|
||
account: &Account{
|
||
ID: 6,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(503)},
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(503),
|
||
"keywords": []any{"overloaded"},
|
||
"duration_minutes": float64(10),
|
||
"description": "overloaded rule",
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 503,
|
||
body: []byte(`overloaded`),
|
||
expected: ErrorPolicyMatched, // custom codes take precedence
|
||
},
|
||
{
|
||
name: "pool_mode_custom_error_codes_hit_returns_matched",
|
||
account: &Account{
|
||
ID: 7,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformOpenAI,
|
||
Credentials: map[string]any{
|
||
"pool_mode": true,
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(401), float64(403)},
|
||
},
|
||
},
|
||
statusCode: 401,
|
||
body: []byte(`unauthorized`),
|
||
expected: ErrorPolicyMatched,
|
||
},
|
||
{
|
||
name: "pool_mode_without_custom_error_codes_returns_skipped",
|
||
account: &Account{
|
||
ID: 8,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformOpenAI,
|
||
Credentials: map[string]any{
|
||
"pool_mode": true,
|
||
},
|
||
},
|
||
statusCode: 401,
|
||
body: []byte(`unauthorized`),
|
||
expected: ErrorPolicySkipped,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
repo := &errorPolicyRepoStub{}
|
||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
|
||
result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body)
|
||
require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult")
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestHandleUpstreamError_PoolModeCustomErrorCodesOverride(t *testing.T) {
|
||
t.Run("pool_mode_without_custom_error_codes_still_skips", func(t *testing.T) {
|
||
repo := &errorPolicyRepoStub{}
|
||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
account := &Account{
|
||
ID: 30,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformOpenAI,
|
||
Credentials: map[string]any{
|
||
"pool_mode": true,
|
||
},
|
||
}
|
||
|
||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.False(t, shouldDisable)
|
||
require.Equal(t, 0, repo.setErrCalls)
|
||
require.Equal(t, 0, repo.tempCalls)
|
||
})
|
||
|
||
t.Run("pool_mode_with_custom_error_codes_uses_local_error_policy", func(t *testing.T) {
|
||
repo := &errorPolicyRepoStub{}
|
||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
account := &Account{
|
||
ID: 31,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformOpenAI,
|
||
Credentials: map[string]any{
|
||
"pool_mode": true,
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(401)},
|
||
},
|
||
}
|
||
|
||
shouldDisable := svc.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||
|
||
require.True(t, shouldDisable)
|
||
require.Equal(t, 1, repo.setErrCalls)
|
||
require.Equal(t, 0, repo.tempCalls)
|
||
})
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method
// ---------------------------------------------------------------------------
|
||
|
||
func TestApplyErrorPolicy(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
account *Account
|
||
statusCode int
|
||
body []byte
|
||
expectedHandled bool
|
||
expectedStatus int // expected outStatus
|
||
expectedSwitchErr bool // expect *AntigravityAccountSwitchError
|
||
handleErrorCalls int
|
||
}{
|
||
{
|
||
name: "none_not_handled",
|
||
account: &Account{
|
||
ID: 10,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
},
|
||
statusCode: 500,
|
||
body: []byte(`"error"`),
|
||
expectedHandled: false,
|
||
expectedStatus: 500, // passthrough
|
||
handleErrorCalls: 0,
|
||
},
|
||
{
|
||
name: "skipped_handled_no_handleError",
|
||
account: &Account{
|
||
ID: 11,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(429)},
|
||
},
|
||
},
|
||
statusCode: 500, // not in custom codes
|
||
body: []byte(`"error"`),
|
||
expectedHandled: true,
|
||
expectedStatus: http.StatusInternalServerError, // skipped → 500
|
||
handleErrorCalls: 0,
|
||
},
|
||
{
|
||
name: "matched_handled_calls_handleError",
|
||
account: &Account{
|
||
ID: 12,
|
||
Type: AccountTypeAPIKey,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"custom_error_codes_enabled": true,
|
||
"custom_error_codes": []any{float64(500)},
|
||
},
|
||
},
|
||
statusCode: 500,
|
||
body: []byte(`"error"`),
|
||
expectedHandled: true,
|
||
expectedStatus: 500, // matched → original status
|
||
handleErrorCalls: 1,
|
||
},
|
||
{
|
||
name: "temp_unscheduled_returns_switch_error",
|
||
account: &Account{
|
||
ID: 13,
|
||
Type: AccountTypeOAuth,
|
||
Platform: PlatformAntigravity,
|
||
Credentials: map[string]any{
|
||
"temp_unschedulable_enabled": true,
|
||
"temp_unschedulable_rules": []any{
|
||
map[string]any{
|
||
"error_code": float64(503),
|
||
"keywords": []any{"overloaded"},
|
||
"duration_minutes": float64(10),
|
||
},
|
||
},
|
||
},
|
||
},
|
||
statusCode: 503,
|
||
body: []byte(`overloaded`),
|
||
expectedHandled: true,
|
||
expectedStatus: 503, // temp_unscheduled → original status
|
||
expectedSwitchErr: true,
|
||
handleErrorCalls: 0,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
repo := &errorPolicyRepoStub{}
|
||
rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||
svc := &AntigravityGatewayService{
|
||
rateLimitService: rlSvc,
|
||
}
|
||
|
||
var handleErrorCount int
|
||
p := antigravityRetryLoopParams{
|
||
ctx: context.Background(),
|
||
prefix: "[test]",
|
||
account: tt.account,
|
||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||
handleErrorCount++
|
||
return nil
|
||
},
|
||
isStickySession: true,
|
||
}
|
||
|
||
handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body)
|
||
|
||
require.Equal(t, tt.expectedHandled, handled, "handled mismatch")
|
||
require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch")
|
||
require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch")
|
||
|
||
if tt.expectedSwitchErr {
|
||
var switchErr *AntigravityAccountSwitchError
|
||
require.ErrorAs(t, retErr, &switchErr)
|
||
require.Equal(t, tt.account.ID, switchErr.OriginalAccountID)
|
||
} else {
|
||
require.NoError(t, retErr)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests
// ---------------------------------------------------------------------------
|
||
|
||
type errorPolicyRepoStub struct {
|
||
mockAccountRepoForGemini
|
||
tempCalls int
|
||
setErrCalls int
|
||
lastErrorMsg string
|
||
}
|
||
|
||
func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||
r.tempCalls++
|
||
return nil
|
||
}
|
||
|
||
func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||
r.setErrCalls++
|
||
r.lastErrorMsg = errorMsg
|
||
return nil
|
||
}
|