From 09166a52f89cb873b75a9061285b1864d57ee015 Mon Sep 17 00:00:00 2001 From: erio Date: Tue, 24 Feb 2026 18:08:04 +0800 Subject: [PATCH] refactor: extract failover error handling into FailoverState - Extract duplicated failover logic from gateway_handler.go (3 places) and gemini_v1beta_handler.go into shared failover_loop.go - Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted - Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go - Add comprehensive unit tests (32+ test cases) - Delete redundant gateway_handler_single_account_retry_test.go --- backend/internal/handler/failover_loop.go | 160 ++++ .../internal/handler/failover_loop_test.go | 732 ++++++++++++++++++ backend/internal/handler/gateway_handler.go | 260 ++----- ...teway_handler_single_account_retry_test.go | 51 -- .../internal/handler/gemini_v1beta_handler.go | 75 +- 5 files changed, 975 insertions(+), 303 deletions(-) create mode 100644 backend/internal/handler/failover_loop.go create mode 100644 backend/internal/handler/failover_loop_test.go delete mode 100644 backend/internal/handler/gateway_handler_single_account_retry_test.go diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go new file mode 100644 index 00000000..1f8a7e9a --- /dev/null +++ b/backend/internal/handler/failover_loop.go @@ -0,0 +1,160 @@ +package handler + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 +// GatewayService 隐式实现此接口。 +type TempUnscheduler interface { + TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) +} + +// FailoverAction 表示 failover 错误处理后的下一步动作 +type FailoverAction int + +const ( + // FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue) + FailoverContinue FailoverAction = iota + // FailoverExhausted 切换次数耗尽(调用方应返回错误响应) + FailoverExhausted + // FailoverCanceled context 已取消(调用方应直接 return) + FailoverCanceled +) + +const ( + // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) + maxSameAccountRetries = 2 + // sameAccountRetryDelay 同账号重试间隔 + sameAccountRetryDelay = 500 * time.Millisecond + // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 + // Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s), + // Handler 层只需短暂间隔后重新进入 Service 层即可。 + singleAccountBackoffDelay = 2 * time.Second +) + +// FailoverState 跨循环迭代共享的 failover 状态 +type FailoverState struct { + SwitchCount int + MaxSwitches int + FailedAccountIDs map[int64]struct{} + SameAccountRetryCount map[int64]int + LastFailoverErr *service.UpstreamFailoverError + ForceCacheBilling bool + hasBoundSession bool +} + +// NewFailoverState 创建 failover 状态 +func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState { + return &FailoverState{ + MaxSwitches: maxSwitches, + FailedAccountIDs: make(map[int64]struct{}), + SameAccountRetryCount: make(map[int64]int), + hasBoundSession: hasBoundSession, + } +} + +// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。 +// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。 +func (s *FailoverState) HandleFailoverError( + ctx context.Context, + gatewayService TempUnscheduler, + accountID int64, + platform string, + failoverErr *service.UpstreamFailoverError, +) FailoverAction { + s.LastFailoverErr = failoverErr + + // 缓存计费判断 + if needForceCacheBilling(s.hasBoundSession, failoverErr) { + s.ForceCacheBilling = true + } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { + s.SameAccountRetryCount[accountID]++ + log.Printf("Account %d: retryable error %d, same-account retry %d/%d", + accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries) + if !sleepWithContext(ctx, sameAccountRetryDelay) { + return FailoverCanceled + } + return FailoverContinue + } + + // 同账号重试用尽,执行临时封禁 + if failoverErr.RetryableOnSameAccount { + gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr) + } + + // 加入失败列表 + s.FailedAccountIDs[accountID] = struct{}{} + + // 检查是否耗尽 + if s.SwitchCount >= s.MaxSwitches { + return FailoverExhausted + } + + // 递增切换计数 + s.SwitchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", + accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches) + + // Antigravity 平台换号线性递增延时 + if platform == service.PlatformAntigravity { + delay := time.Duration(s.SwitchCount-1) * time.Second + if !sleepWithContext(ctx, delay) { + return FailoverCanceled + } + } + + return FailoverContinue +} + +// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。 +// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景: +// 清除排除列表、等待退避后重新选号。 +// +// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。 +// 返回 FailoverExhausted 时,调用方应返回错误响应。 +// 返回 FailoverCanceled 时,调用方应直接 return。 +func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction { + if s.LastFailoverErr != nil && + s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && + s.SwitchCount <= s.MaxSwitches { + + log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", + singleAccountBackoffDelay, s.SwitchCount) + if !sleepWithContext(ctx, singleAccountBackoffDelay) { + return FailoverCanceled + } + log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", + s.SwitchCount, s.MaxSwitches) + s.FailedAccountIDs = make(map[int64]struct{}) + return FailoverContinue + } + return FailoverExhausted +} + +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。 +func sleepWithContext(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(d): + return true + } +} diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go new file mode 100644 index 00000000..5a41b2dd --- /dev/null +++ b/backend/internal/handler/failover_loop_test.go @@ -0,0 +1,732 @@ +package handler + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock +// --------------------------------------------------------------------------- + +// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。 +type mockTempUnscheduler struct { + calls []tempUnscheduleCall +} + +type tempUnscheduleCall struct { + accountID int64 + failoverErr *service.UpstreamFailoverError +} + +func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) { + m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr}) +} + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError { + return &service.UpstreamFailoverError{ + StatusCode: statusCode, + RetryableOnSameAccount: retryable, + ForceCacheBilling: forceBilling, + } +} + +// --------------------------------------------------------------------------- +// NewFailoverState 测试 +// --------------------------------------------------------------------------- + +func TestNewFailoverState(t *testing.T) { + t.Run("初始化字段正确", func(t *testing.T) { + fs := NewFailoverState(5, true) + require.Equal(t, 5, fs.MaxSwitches) + require.Equal(t, 0, fs.SwitchCount) + require.NotNil(t, fs.FailedAccountIDs) + require.Empty(t, fs.FailedAccountIDs) + require.NotNil(t, fs.SameAccountRetryCount) + require.Empty(t, fs.SameAccountRetryCount) + require.Nil(t, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.True(t, fs.hasBoundSession) + }) + + t.Run("无绑定会话", func(t *testing.T) { + fs := NewFailoverState(3, false) + require.Equal(t, 3, fs.MaxSwitches) + require.False(t, fs.hasBoundSession) + }) + + t.Run("零最大切换次数", func(t *testing.T) { + fs := NewFailoverState(0, false) + require.Equal(t, 0, fs.MaxSwitches) + }) +} + +// --------------------------------------------------------------------------- +// sleepWithContext 测试 +// --------------------------------------------------------------------------- + +func TestSleepWithContext(t *testing.T) { + t.Run("零时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 0) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("负时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), -1*time.Second) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("正常等待后返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 50*time.Millisecond) + elapsed := time.Since(start) + require.True(t, ok) + require.GreaterOrEqual(t, elapsed, 40*time.Millisecond) + require.Less(t, elapsed, 500*time.Millisecond) + }) + + t.Run("已取消context立即返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + require.False(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("等待期间context取消返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + elapsed := time.Since(start) + require.False(t, ok) + require.Less(t, elapsed, 500*time.Millisecond) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 基本切换流程 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_BasicSwitch(t *testing.T) { + t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Equal(t, err, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.Empty(t, mock.calls, "不应调用 TempUnschedule") + }) + + t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) { + // switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0 + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0") + }) + + t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) { + // switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 模拟已切换一次 + + err := newTestFailoverErr(500, false, false) + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s") + require.Less(t, elapsed, 3*time.Second) + }) + + t.Run("连续切换直到耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + // 第一次切换:0→1 + err1 := newTestFailoverErr(500, false, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + + // 第二次切换:1→2 + err2 := newTestFailoverErr(502, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2) + err3 := newTestFailoverErr(503, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增") + + // 验证失败账号列表 + require.Len(t, fs.FailedAccountIDs, 3) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Contains(t, fs.FailedAccountIDs, int64(300)) + + // LastFailoverErr 应为最后一次的错误 + require.Equal(t, err3, fs.LastFailoverErr) + }) + + t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 0, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 缓存计费 (ForceCacheBilling) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_CacheBilling(t *testing.T) { + t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("两者均为false时不设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.False(t, fs.ForceCacheBilling) + }) + + t.Run("一旦设置不会被后续错误重置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + // 第一次:ForceCacheBilling=true → 设置 + err1 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.True(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=false → 仍然保持 true + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 同账号重试 (RetryableOnSameAccount) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_SameAccountRetry(t *testing.T) { + t.Run("第一次重试返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数") + require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表") + require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule") + // 验证等待了 sameAccountRetryDelay (500ms) + require.GreaterOrEqual(t, elapsed, 400*time.Millisecond) + require.Less(t, elapsed, 2*time.Second) + }) + + t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 第二次 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + }) + + t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次、第二次重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + // 验证 TempUnschedule 被调用 + require.Len(t, mock.calls, 1) + require.Equal(t, int64(100), mock.calls[0].accountID) + require.Equal(t, err, mock.calls[0].failoverErr) + }) + + t.Run("不同账号独立跟踪重试次数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 账号 100 第一次重试 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 账号 200 第一次重试(独立计数) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[200]) + require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响") + }) + + t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 耗尽账号 100 的重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + // 第三次: 重试耗尽 → 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + + // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — TempUnschedule 调用验证 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_TempUnschedule(t *testing.T) { + t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Empty(t, mock.calls) + }) + + t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(502, true, false) + + // 耗尽重试 + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + + require.Len(t, mock.calls, 1) + require.Equal(t, int64(42), mock.calls[0].accountID) + require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode) + require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — Context 取消 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_ContextCanceled(t *testing.T) { + t.Run("同账号重试sleep期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + // 重试计数仍应递增 + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + }) + + t.Run("Antigravity延迟期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s + err := newTestFailoverErr(500, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — FailedAccountIDs 跟踪 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_FailedAccountIDs(t *testing.T) { + t.Run("切换时添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Len(t, fs.FailedAccountIDs, 2) + }) + + t.Run("耗尽时也添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Equal(t, FailoverExhausted, action) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false)) + require.Equal(t, FailoverContinue, action) + require.NotContains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同一账号多次切换不重复添加", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — LastFailoverErr 更新 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_LastFailoverErr(t *testing.T) { + t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, err1, fs.LastFailoverErr) + + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, err2, fs.LastFailoverErr) + }) + + t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err := newTestFailoverErr(400, true, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, err, fs.LastFailoverErr) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 综合集成场景 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_IntegrationScenario(t *testing.T) { + t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + + // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + retryErr := newTestFailoverErr(400, true, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") + + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + + // 2. 账号 100 重试耗尽 → TempUnschedule + 切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Len(t, mock.calls, 1) + + // 3. 账号 200 遇到不可重试错误 → 直接切换 + switchErr := newTestFailoverErr(500, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 4. 账号 300 遇到不可重试错误 → 再切换 + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 3, fs.SwitchCount) + + // 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3) + action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr) + require.Equal(t, FailoverExhausted, action) + + // 最终状态验证 + require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增") + require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中") + require.True(t, fs.ForceCacheBilling) + require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule") + }) + + t.Run("模拟Antigravity平台完整流程", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + err := newTestFailoverErr(500, false, false) + + // 第一次切换:delay = 0s + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0") + + // 第二次切换:delay = 1s + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverContinue, action) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s") + + // 第三次:耗尽(无延迟,因为在检查延迟之前就返回了) + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟") + }) + + t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) // hasBoundSession=false + + // 第一次:ForceCacheBilling=false + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.False(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换) + err2 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling") + + // 第三次:ForceCacheBilling=false,但状态仍保持 true + err3 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.True(t, fs.ForceCacheBilling, "不应重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 边界条件 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_EdgeCases(t *testing.T) { + t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(0, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + }) + + t.Run("AccountID为0也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[0]) + }) + + t.Run("负AccountID也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[-1]) + }) + + t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟") + }) +} + +// --------------------------------------------------------------------------- +// HandleSelectionExhausted 测试 +// --------------------------------------------------------------------------- + +func TestHandleSelectionExhausted(t *testing.T) { + t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + // LastFailoverErr 为 nil + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("非503错误返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(500, false, false) + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.FailedAccountIDs[100] = struct{}{} + fs.SwitchCount = 1 + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表") + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s") + require.Less(t, elapsed, 5*time.Second) + }) + + t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 3 // > MaxSwitches(2) + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 100*time.Millisecond, "不应等待") + }) + + t.Run("503但context已取消_返回Canceled", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + action := fs.HandleSelectionExhausted(ctx) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + }) + + t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试 + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverContinue, action) + }) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4b32969f..d78f83a6 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "strings" "time" @@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 if platform == service.PlatformGemini { - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 @@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { - reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + if len(fs.FailedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gateway.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - reqLog.Warn("gateway.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) reqLog.Error("gateway.forward_failed", @@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Subscription: subscription, UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError + fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession) retryWithFallback := false - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) if err != nil { - if len(failedAccountIDs) == 0 { - reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + if len(fs.FailedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gateway.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) @@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - reqLog.Warn("gateway.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) reqLog.Error("gateway.forward_failed", @@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Subscription: currentSubscription, UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -735,7 +652,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { }) reqLog.Debug("gateway.request_completed", zap.Int64("account_id", account.ID), - zap.Int("switch_count", switchCount), + zap.Int("switch_count", fs.SwitchCount), zap.Bool("fallback_used", fallbackUsed), ) return @@ -982,69 +899,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -// needForceCacheBilling 判断 failover 时是否需要强制缓存计费 -// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费 -func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { - return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) -} - -const ( - // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) - maxSameAccountRetries = 2 - // sameAccountRetryDelay 同账号重试间隔 - sameAccountRetryDelay = 500 * time.Millisecond -) - -// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。 -func sleepSameAccountRetryDelay(ctx context.Context) bool { - select { - case <-ctx.Done(): - return false - case <-time.After(sameAccountRetryDelay): - return true - } -} - -// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… -// 返回 false 表示 context 已取消。 -func sleepFailoverDelay(ctx context.Context, switchCount int) bool { - delay := time.Duration(switchCount-1) * time.Second - if delay <= 0 { - return true - } - select { - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - -// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。 -// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用, -// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试 -// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。 -// 返回 false 表示 context 已取消。 -func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool { - // 固定短延时:2s - // Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数), - // Handler 层只需短暂间隔后重新进入 Service 层即可。 - const delay = 2 * time.Second - - logger.L().With( - zap.String("component", "handler.gateway.failover"), - zap.Duration("delay", delay), - zap.Int("retry_count", retryCount), - ).Info("gateway.single_account_backoff_waiting") - - select { - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody diff --git a/backend/internal/handler/gateway_handler_single_account_retry_test.go b/backend/internal/handler/gateway_handler_single_account_retry_test.go deleted file mode 100644 index 96aa14c6..00000000 --- a/backend/internal/handler/gateway_handler_single_account_retry_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package handler - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// --------------------------------------------------------------------------- -// sleepAntigravitySingleAccountBackoff 测试 -// --------------------------------------------------------------------------- - -func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) { - ctx := context.Background() - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.True(t, ok, "should return true when context is not canceled") - // 固定延迟 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s") - require.Less(t, elapsed, 5*time.Second, "should not wait too long") -} - -func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // 立即取消 - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.False(t, ok, "should return false when context is canceled") - require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel") -} - -func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) { - // 验证不同 retryCount 都使用固定 2s 延迟 - ctx := context.Background() - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 5) - elapsed := time.Since(start) - - require.True(t, ok) - // 即使 retryCount=5,延迟仍然是固定的 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond) - require.Less(t, elapsed, 5*time.Second) -} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index ea212088..2da0570b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 cleanedForUnknownBinding := false - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 @@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - reqLog.Warn("gemini.single_account_retrying", - zap.Int("retry_count", switchCount), - zap.Int("max_retries", maxAccountSwitches), - ) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue - } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return } - h.handleGeminiFailoverExhausted(c, lastFailoverErr) - return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) @@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) @@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - lastFailoverErr = failoverErr - h.handleGeminiFailoverExhausted(c, lastFailoverErr) + failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch failoverAction { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return + case FailoverCanceled: return } - lastFailoverErr = failoverErr - switchCount++ - reqLog.Warn("gemini.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } // ForwardNative already wrote the response reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) @@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { IPAddress: clientIP, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: forceCacheBilling, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( @@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { }) reqLog.Debug("gemini.request_completed", zap.Int64("account_id", account.ID), - zap.Int("switch_count", switchCount), + zap.Int("switch_count", fs.SwitchCount), ) return }