diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 3caf9a93..cc4d2fb9 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log" + "log/slog" mathrand "math/rand" "net" "net/http" @@ -353,87 +354,102 @@ urlFallbackLoop: return nil, fmt.Errorf("upstream request failed after retries: %w", err) } - // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 统一处理错误响应 + if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - // 尝试智能重试处理(OAuth 账号专用) - smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) - switch smartResult.action { - case smartRetryActionContinueURL: - continue urlFallbackLoop - case smartRetryActionBreakWithResp: - if smartResult.err != nil { - return nil, smartResult.err + // ★ 统一入口:自定义错误码 + 临时不可调度 + if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if policyErr != nil { + return nil, policyErr } - // 模型限流时返回切换账号信号 - if smartResult.switchError != nil { - return nil, smartResult.switchError + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), } - resp = smartResult.resp break urlFallbackLoop } - // smartRetryActionContinue: 继续默认重试逻辑 - // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() + // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 尝试智能重试处理(OAuth 账号专用) + smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) + switch smartResult.action { + case smartRetryActionContinueURL: + continue urlFallbackLoop + case smartRetryActionBreakWithResp: + if smartResult.err != nil { + return nil, smartResult.err + } + // 模型限流时返回切换账号信号 + if smartResult.switchError != nil { + return nil, smartResult.switchError + } + resp = smartResult.resp + break urlFallbackLoop } - continue + // smartRetryActionContinue: 继续默认重试逻辑 + + // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + + // 重试用尽,标记账户限流 + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop } - // 重试用尽,标记账户限流 - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - // 其他可重试错误(不包括 429 和 503,因为上面已处理) - if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() - } - continue - } + // 其他可重试错误(500/502/504/529,不包括 429 和 503) + if shouldRetryAntigravityError(resp.StatusCode) { + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + } + + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -442,6 +458,7 @@ urlFallbackLoop: break urlFallbackLoop } + // 成功响应(< 400) break urlFallbackLoop } } @@ -574,6 +591,31 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string { return truncateString(string(body), maxBytes) } +// checkErrorPolicy nil 安全的包装 +func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult { + if s.rateLimitService == nil { + return ErrorPolicyNone + } + return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) +} + +// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { + switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { + case ErrorPolicySkipped: + return true, nil + case ErrorPolicyMatched: + _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, + p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + return true, nil + case ErrorPolicyTempUnscheduled: + slog.Info("temp_unschedulable_matched", + "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) + return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + } + return false, nil +} + // mapAntigravityModel 获取映射后的模型名 // 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping) // 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号 diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go new file mode 100644 index 00000000..9f8ad938 --- /dev/null +++ b/backend/internal/service/error_policy_integration_test.go @@ -0,0 +1,366 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mocks (scoped to this file by naming convention) +// --------------------------------------------------------------------------- + +// epFixedUpstream returns a fixed response for every request. +type epFixedUpstream struct { + statusCode int + body string + calls int +} + +func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.calls++ + return &http.Response{ + StatusCode: u.statusCode, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(u.body)), + }, nil +} + +func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +// epAccountRepo records SetTempUnschedulable / SetError calls. +type epAccountRepo struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int +} + +func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func saveAndSetBaseURLs(t *testing.T) { + t.Helper() + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvail := antigravity.DefaultURLAvailability + antigravity.BaseURLs = []string{"https://ep-test.example"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + t.Cleanup(func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvail + }) +} + +func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams { + return antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[ep-test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: handleError, + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_CustomErrorCodes +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { + tests := []struct { + name string + upstreamStatus int + upstreamBody string + customCodes []any + expectHandleError int + expectUpstream int + expectStatusCode int + }{ + { + name: "429_in_custom_codes_matched", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(429)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "429_not_in_custom_codes_skipped", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(500)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "500_in_custom_codes_matched", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(500)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 500, + }, + { + name: "500_not_in_custom_codes_skipped", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(429)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": tt.customCodes, + }, + } + + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, tt.expectStatusCode, result.resp.StatusCode) + require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count") + require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count") + }) + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_TempUnschedulable +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) { + tempRulesAccount := func(rules []any) *Account { + return &Account{ + ID: 200, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": rules, + }, + } + } + + overloadedRule := map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + } + + rateLimitRule := map[string]any{ + "error_code": float64(429), + "keywords": []any{"rate limited keyword"}, + "duration_minutes": float64(5), + } + + t.Run("503_overloaded_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{rateLimitRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `random`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + + // Use a short-lived context: the backoff sleep (~1s) will be + // interrupted, proving the code entered the default retry path + // instead of breaking early via error policy. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + result, err := svc.antigravityRetryLoop(p) + + // Context cancellation during backoff proves default retry was entered + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once") + }) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NilRateLimitService +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + // rateLimitService is nil — must not panic + svc := &AntigravityGatewayService{rateLimitService: nil} + + account := &Account{ + ID: 300, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + // Should not panic; enters the default retry path (eventually times out) + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + // Plain OAuth account with no error policy configured + account := &Account{ + ID: 400, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") + require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go new file mode 100644 index 00000000..a8b69c22 --- /dev/null +++ b/backend/internal/service/error_policy_test.go @@ -0,0 +1,289 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "no_policy_oauth_returns_none", + account: &Account{ + ID: 1, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + // no custom error codes, no temp rules + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 2, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyMatched, + }, + { + name: "custom_error_codes_miss_returns_skipped", + account: &Account{ + ID: 3, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 503, + body: []byte(`"error"`), + expected: ErrorPolicySkipped, + }, + { + name: "temp_unschedulable_hit_returns_temp_unscheduled", + account: &Account{ + ID: 4, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "temp_unschedulable_body_miss_returns_none", + account: &Account{ + ID: 5, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`random msg`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_override_temp_unschedulable", + account: &Account{ + ID: 6, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult") + }) + } +} + +// --------------------------------------------------------------------------- +// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method +// --------------------------------------------------------------------------- + +func TestApplyErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expectedHandled bool + expectedSwitchErr bool // expect *AntigravityAccountSwitchError + handleErrorCalls int + }{ + { + name: "none_not_handled", + account: &Account{ + ID: 10, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: false, + handleErrorCalls: 0, + }, + { + name: "skipped_handled_no_handleError", + account: &Account{ + ID: 11, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, // not in custom codes + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 0, + }, + { + name: "matched_handled_calls_handleError", + account: &Account{ + ID: 12, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 1, + }, + { + name: "temp_unscheduled_returns_switch_error", + account: &Account{ + ID: 13, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expectedHandled: true, + expectedSwitchErr: true, + handleErrorCalls: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{ + rateLimitService: rlSvc, + } + + var handleErrorCount int + p := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: tt.account, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }, + isStickySession: true, + } + + handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + + require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") + + if tt.expectedSwitchErr { + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, retErr, &switchErr) + require.Equal(t, tt.account.ID, switchErr.OriginalAccountID) + } else { + require.NoError(t, retErr) + } + }) + } +} + +// --------------------------------------------------------------------------- +// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests +// --------------------------------------------------------------------------- + +type errorPolicyRepoStub struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int + lastErrorMsg string +} + +func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrCalls++ + r.lastErrorMsg = errorMsg + return nil +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47286deb..63732dee 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +// ErrorPolicyResult 表示错误策略检查的结果 +type ErrorPolicyResult int + +const ( + ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑 + ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理 + ErrorPolicyMatched // 自定义错误码命中,应停止调度 + ErrorPolicyTempUnscheduled // 临时不可调度规则命中 +) + +// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。 +// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。 +func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult { + if account.IsCustomErrorCodesEnabled() { + if account.ShouldHandleErrorCode(statusCode) { + return ErrorPolicyMatched + } + slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) + return ErrorPolicySkipped + } + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return ErrorPolicyTempUnscheduled + } + return ErrorPolicyNone +} + // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 18bac7ff..8b4d4c06 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1038,10 +1038,7 @@ -
+