refactor: extract failover error handling into FailoverState
- Extract duplicated failover logic from gateway_handler.go (3 places) and gemini_v1beta_handler.go into shared failover_loop.go - Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted - Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go - Add comprehensive unit tests (32+ test cases) - Delete redundant gateway_handler_single_account_retry_test.go
This commit is contained in:
160
backend/internal/handler/failover_loop.go
Normal file
160
backend/internal/handler/failover_loop.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||||
|
// GatewayService 隐式实现此接口。
|
||||||
|
type TempUnscheduler interface {
|
||||||
|
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FailoverAction 表示 failover 错误处理后的下一步动作
|
||||||
|
type FailoverAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
|
||||||
|
FailoverContinue FailoverAction = iota
|
||||||
|
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
|
||||||
|
FailoverExhausted
|
||||||
|
// FailoverCanceled context 已取消(调用方应直接 return)
|
||||||
|
FailoverCanceled
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||||
|
maxSameAccountRetries = 2
|
||||||
|
// sameAccountRetryDelay 同账号重试间隔
|
||||||
|
sameAccountRetryDelay = 500 * time.Millisecond
|
||||||
|
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||||
|
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
|
||||||
|
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||||
|
singleAccountBackoffDelay = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// FailoverState 跨循环迭代共享的 failover 状态
|
||||||
|
type FailoverState struct {
|
||||||
|
SwitchCount int
|
||||||
|
MaxSwitches int
|
||||||
|
FailedAccountIDs map[int64]struct{}
|
||||||
|
SameAccountRetryCount map[int64]int
|
||||||
|
LastFailoverErr *service.UpstreamFailoverError
|
||||||
|
ForceCacheBilling bool
|
||||||
|
hasBoundSession bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFailoverState 创建 failover 状态
|
||||||
|
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
|
||||||
|
return &FailoverState{
|
||||||
|
MaxSwitches: maxSwitches,
|
||||||
|
FailedAccountIDs: make(map[int64]struct{}),
|
||||||
|
SameAccountRetryCount: make(map[int64]int),
|
||||||
|
hasBoundSession: hasBoundSession,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
|
||||||
|
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
|
||||||
|
func (s *FailoverState) HandleFailoverError(
|
||||||
|
ctx context.Context,
|
||||||
|
gatewayService TempUnscheduler,
|
||||||
|
accountID int64,
|
||||||
|
platform string,
|
||||||
|
failoverErr *service.UpstreamFailoverError,
|
||||||
|
) FailoverAction {
|
||||||
|
s.LastFailoverErr = failoverErr
|
||||||
|
|
||||||
|
// 缓存计费判断
|
||||||
|
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
|
||||||
|
s.ForceCacheBilling = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||||
|
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||||
|
s.SameAccountRetryCount[accountID]++
|
||||||
|
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||||
|
accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
|
||||||
|
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同账号重试用尽,执行临时封禁
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加入失败列表
|
||||||
|
s.FailedAccountIDs[accountID] = struct{}{}
|
||||||
|
|
||||||
|
// 检查是否耗尽
|
||||||
|
if s.SwitchCount >= s.MaxSwitches {
|
||||||
|
return FailoverExhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
// 递增切换计数
|
||||||
|
s.SwitchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d",
|
||||||
|
accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
|
||||||
|
|
||||||
|
// Antigravity 平台换号线性递增延时
|
||||||
|
if platform == service.PlatformAntigravity {
|
||||||
|
delay := time.Duration(s.SwitchCount-1) * time.Second
|
||||||
|
if !sleepWithContext(ctx, delay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
|
||||||
|
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
|
||||||
|
// 清除排除列表、等待退避后重新选号。
|
||||||
|
//
|
||||||
|
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
|
||||||
|
// 返回 FailoverExhausted 时,调用方应返回错误响应。
|
||||||
|
// 返回 FailoverCanceled 时,调用方应直接 return。
|
||||||
|
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
|
||||||
|
if s.LastFailoverErr != nil &&
|
||||||
|
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||||
|
s.SwitchCount <= s.MaxSwitches {
|
||||||
|
|
||||||
|
log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
|
||||||
|
singleAccountBackoffDelay, s.SwitchCount)
|
||||||
|
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||||
|
return FailoverCanceled
|
||||||
|
}
|
||||||
|
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
|
||||||
|
s.SwitchCount, s.MaxSwitches)
|
||||||
|
s.FailedAccountIDs = make(map[int64]struct{})
|
||||||
|
return FailoverContinue
|
||||||
|
}
|
||||||
|
return FailoverExhausted
|
||||||
|
}
|
||||||
|
|
||||||
|
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
|
||||||
|
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
|
||||||
|
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||||
|
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
|
||||||
|
func sleepWithContext(ctx context.Context, d time.Duration) bool {
|
||||||
|
if d <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(d):
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
732
backend/internal/handler/failover_loop_test.go
Normal file
732
backend/internal/handler/failover_loop_test.go
Normal file
@@ -0,0 +1,732 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Mock
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
|
||||||
|
type mockTempUnscheduler struct {
|
||||||
|
calls []tempUnscheduleCall
|
||||||
|
}
|
||||||
|
|
||||||
|
type tempUnscheduleCall struct {
|
||||||
|
accountID int64
|
||||||
|
failoverErr *service.UpstreamFailoverError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
|
||||||
|
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helper
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
|
||||||
|
return &service.UpstreamFailoverError{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
RetryableOnSameAccount: retryable,
|
||||||
|
ForceCacheBilling: forceBilling,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// NewFailoverState 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNewFailoverState(t *testing.T) {
|
||||||
|
t.Run("初始化字段正确", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(5, true)
|
||||||
|
require.Equal(t, 5, fs.MaxSwitches)
|
||||||
|
require.Equal(t, 0, fs.SwitchCount)
|
||||||
|
require.NotNil(t, fs.FailedAccountIDs)
|
||||||
|
require.Empty(t, fs.FailedAccountIDs)
|
||||||
|
require.NotNil(t, fs.SameAccountRetryCount)
|
||||||
|
require.Empty(t, fs.SameAccountRetryCount)
|
||||||
|
require.Nil(t, fs.LastFailoverErr)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
require.True(t, fs.hasBoundSession)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("无绑定会话", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
require.Equal(t, 3, fs.MaxSwitches)
|
||||||
|
require.False(t, fs.hasBoundSession)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("零最大切换次数", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
require.Equal(t, 0, fs.MaxSwitches)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// sleepWithContext 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSleepWithContext(t *testing.T) {
|
||||||
|
t.Run("零时长立即返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), 0)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("负时长立即返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), -1*time.Second)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("正常等待后返回true", func(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
|
||||||
|
require.Less(t, elapsed, 500*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("已取消context立即返回false", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(ctx, 5*time.Second)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("等待期间context取消返回false", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
ok := sleepWithContext(ctx, 5*time.Second)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.False(t, ok)
|
||||||
|
require.Less(t, elapsed, 500*time.Millisecond)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 基本切换流程
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
|
||||||
|
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
require.Equal(t, err, fs.LastFailoverErr)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
|
||||||
|
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
|
||||||
|
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1 // 模拟已切换一次
|
||||||
|
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
|
||||||
|
require.Less(t, elapsed, 3*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("连续切换直到耗尽", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
|
||||||
|
// 第一次切换:0→1
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 第二次切换:1→2
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
|
||||||
|
err3 := newTestFailoverErr(503, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
|
||||||
|
|
||||||
|
// 验证失败账号列表
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 3)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(300))
|
||||||
|
|
||||||
|
// LastFailoverErr 应为最后一次的错误
|
||||||
|
require.Equal(t, err3, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Equal(t, 0, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_CacheBilling(t *testing.T) {
|
||||||
|
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("两者均为false时不设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
// 第一次:ForceCacheBilling=true → 设置
|
||||||
|
err1 := newTestFailoverErr(500, false, true)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
|
||||||
|
// 第二次:ForceCacheBilling=false → 仍然保持 true
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||||
|
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
|
||||||
|
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
|
||||||
|
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
|
||||||
|
// 验证等待了 sameAccountRetryDelay (500ms)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||||
|
require.Less(t, elapsed, 2*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 第一次
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 第二次
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 第一次、第二次重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
|
||||||
|
// 验证 TempUnschedule 被调用
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
require.Equal(t, int64(100), mock.calls[0].accountID)
|
||||||
|
require.Equal(t, err, mock.calls[0].failoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 账号 100 第一次重试
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
|
||||||
|
// 账号 200 第一次重试(独立计数)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[200])
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
// 耗尽账号 100 的重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
// 第三次: 重试耗尽 → 切换
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
|
||||||
|
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — TempUnschedule 调用验证
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||||
|
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Empty(t, mock.calls)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(502, true, false)
|
||||||
|
|
||||||
|
// 耗尽重试
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||||
|
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
require.Equal(t, int64(42), mock.calls[0].accountID)
|
||||||
|
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
|
||||||
|
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — Context 取消
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
|
||||||
|
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||||
|
// 重试计数仍应递增
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — FailedAccountIDs 跟踪
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
|
||||||
|
t.Run("切换时添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(0, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.NotContains(t, fs.FailedAccountIDs, int64(100))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(5, false)
|
||||||
|
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — LastFailoverErr 更新
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
|
||||||
|
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.Equal(t, err1, fs.LastFailoverErr)
|
||||||
|
|
||||||
|
err2 := newTestFailoverErr(502, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.Equal(t, err2, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
|
||||||
|
err := newTestFailoverErr(400, true, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, err, fs.LastFailoverErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 综合集成场景
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||||
|
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||||
|
|
||||||
|
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||||
|
retryErr := newTestFailoverErr(400, true, false)
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||||
|
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
|
||||||
|
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SwitchCount)
|
||||||
|
require.Len(t, mock.calls, 1)
|
||||||
|
|
||||||
|
// 3. 账号 200 遇到不可重试错误 → 直接切换
|
||||||
|
switchErr := newTestFailoverErr(500, false, false)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 2, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 4. 账号 300 遇到不可重试错误 → 再切换
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 3, fs.SwitchCount)
|
||||||
|
|
||||||
|
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
|
||||||
|
// 最终状态验证
|
||||||
|
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
|
||||||
|
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
|
||||||
|
require.True(t, fs.ForceCacheBilling)
|
||||||
|
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
// 第一次切换:delay = 0s
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
|
||||||
|
|
||||||
|
// 第二次切换:delay = 1s
|
||||||
|
start = time.Now()
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
|
||||||
|
|
||||||
|
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
|
||||||
|
start = time.Now()
|
||||||
|
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false) // hasBoundSession=false
|
||||||
|
|
||||||
|
// 第一次:ForceCacheBilling=false
|
||||||
|
err1 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||||
|
require.False(t, fs.ForceCacheBilling)
|
||||||
|
|
||||||
|
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
|
||||||
|
err2 := newTestFailoverErr(500, false, true)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
|
||||||
|
|
||||||
|
// 第三次:ForceCacheBilling=false,但状态仍保持 true
|
||||||
|
err3 := newTestFailoverErr(500, false, false)
|
||||||
|
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||||
|
require.True(t, fs.ForceCacheBilling, "不应重置")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleFailoverError — 边界条件
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleFailoverError_EdgeCases(t *testing.T) {
|
||||||
|
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(0, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, true, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[0])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
err := newTestFailoverErr(500, true, false)
|
||||||
|
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
|
||||||
|
mock := &mockTempUnscheduler{}
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.SwitchCount = 1
|
||||||
|
err := newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HandleSelectionExhausted 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleSelectionExhausted(t *testing.T) {
|
||||||
|
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
// LastFailoverErr 为 nil
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("非503错误返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.FailedAccountIDs[100] = struct{}{}
|
||||||
|
fs.SwitchCount = 1
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
|
||||||
|
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
|
||||||
|
require.Less(t, elapsed, 5*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.SwitchCount = 3 // > MaxSwitches(2)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverExhausted, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(3, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
action := fs.HandleSelectionExhausted(ctx)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
require.Equal(t, FailoverCanceled, action)
|
||||||
|
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
|
||||||
|
fs := NewFailoverState(2, false)
|
||||||
|
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||||
|
fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试
|
||||||
|
|
||||||
|
action := fs.HandleSelectionExhausted(context.Background())
|
||||||
|
require.Equal(t, FailoverContinue, action)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gateway.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
if fs.LastFailoverErr != nil {
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
} else {
|
||||||
c.Request = c.Request.WithContext(ctx)
|
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
|
||||||
} else {
|
|
||||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity {
|
if account.Platform == service.PlatformAntigravity {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||||
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
lastFailoverErr = failoverErr
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch action {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
|
||||||
|
|
||||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
|
||||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
|
||||||
sameAccountRetryCount[account.ID]++
|
|
||||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
|
||||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
|
||||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
case FailoverExhausted:
|
||||||
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||||
// 同账号重试用尽,执行临时封禁并切换账号
|
return
|
||||||
if failoverErr.RetryableOnSameAccount {
|
case FailoverCanceled:
|
||||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gateway.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed",
|
reqLog.Error("gateway.forward_failed",
|
||||||
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
retryWithFallback := false
|
retryWithFallback := false
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gateway.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
if fs.LastFailoverErr != nil {
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
} else {
|
||||||
c.Request = c.Request.WithContext(ctx)
|
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
|
||||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
|
||||||
} else {
|
|
||||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||||
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
lastFailoverErr = failoverErr
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch action {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
|
||||||
|
|
||||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
|
||||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
|
||||||
sameAccountRetryCount[account.ID]++
|
|
||||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
|
||||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
|
||||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
case FailoverExhausted:
|
||||||
|
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
|
||||||
// 同账号重试用尽,执行临时封禁并切换账号
|
return
|
||||||
if failoverErr.RetryableOnSameAccount {
|
case FailoverCanceled:
|
||||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
|
||||||
if switchCount >= maxAccountSwitches {
|
|
||||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gateway.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed",
|
reqLog.Error("gateway.forward_failed",
|
||||||
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
Subscription: currentSubscription,
|
Subscription: currentSubscription,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -735,7 +652,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
reqLog.Debug("gateway.request_completed",
|
reqLog.Debug("gateway.request_completed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", fs.SwitchCount),
|
||||||
zap.Bool("fallback_used", fallbackUsed),
|
zap.Bool("fallback_used", fallbackUsed),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -982,69 +899,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
|
||||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
|
||||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
|
||||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
|
||||||
maxSameAccountRetries = 2
|
|
||||||
// sameAccountRetryDelay 同账号重试间隔
|
|
||||||
sameAccountRetryDelay = 500 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
|
||||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(sameAccountRetryDelay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
|
||||||
// 返回 false 表示 context 已取消。
|
|
||||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
|
||||||
delay := time.Duration(switchCount-1) * time.Second
|
|
||||||
if delay <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(delay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
|
||||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
|
||||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
|
||||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
|
||||||
// 返回 false 表示 context 已取消。
|
|
||||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
|
||||||
// 固定短延时:2s
|
|
||||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
|
||||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
|
||||||
const delay = 2 * time.Second
|
|
||||||
|
|
||||||
logger.L().With(
|
|
||||||
zap.String("component", "handler.gateway.failover"),
|
|
||||||
zap.Duration("delay", delay),
|
|
||||||
zap.Int("retry_count", retryCount),
|
|
||||||
).Info("gateway.single_account_backoff_waiting")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
case <-time.After(delay):
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||||
statusCode := failoverErr.StatusCode
|
statusCode := failoverErr.StatusCode
|
||||||
responseBody := failoverErr.ResponseBody
|
responseBody := failoverErr.ResponseBody
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// sleepAntigravitySingleAccountBackoff 测试
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.True(t, ok, "should return true when context is not canceled")
|
|
||||||
// 固定延迟 2s
|
|
||||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
|
||||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // 立即取消
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.False(t, ok, "should return false when context is canceled")
|
|
||||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
|
||||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
require.True(t, ok)
|
|
||||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
|
||||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
|
||||||
require.Less(t, elapsed, 5*time.Second)
|
|
||||||
}
|
|
||||||
@@ -344,11 +344,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
cleanedForUnknownBinding := false
|
cleanedForUnknownBinding := false
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||||
switchCount := 0
|
|
||||||
failedAccountIDs := make(map[int64]struct{})
|
|
||||||
var lastFailoverErr *service.UpstreamFailoverError
|
|
||||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
|
||||||
|
|
||||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||||
@@ -358,30 +354,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
switch action {
|
||||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
case FailoverContinue:
|
||||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
c.Request = c.Request.WithContext(ctx)
|
||||||
reqLog.Warn("gemini.single_account_retrying",
|
continue
|
||||||
zap.Int("retry_count", switchCount),
|
case FailoverCanceled:
|
||||||
zap.Int("max_retries", maxAccountSwitches),
|
return
|
||||||
)
|
default: // FailoverExhausted
|
||||||
failedAccountIDs = make(map[int64]struct{})
|
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
return
|
||||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
|
||||||
c.Request = c.Request.WithContext(ctx)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
account := selection.Account
|
account := selection.Account
|
||||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
@@ -465,8 +455,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
// 5) forward (根据平台分流)
|
// 5) forward (根据平台分流)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if switchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||||
}
|
}
|
||||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||||
@@ -479,29 +469,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
failedAccountIDs[account.ID] = struct{}{}
|
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
switch failoverAction {
|
||||||
forceCacheBilling = true
|
case FailoverContinue:
|
||||||
}
|
continue
|
||||||
if switchCount >= maxAccountSwitches {
|
case FailoverExhausted:
|
||||||
lastFailoverErr = failoverErr
|
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
return
|
||||||
|
case FailoverCanceled:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastFailoverErr = failoverErr
|
|
||||||
switchCount++
|
|
||||||
reqLog.Warn("gemini.upstream_failover_switching",
|
|
||||||
zap.Int64("account_id", account.ID),
|
|
||||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
||||||
zap.Int("switch_count", switchCount),
|
|
||||||
zap.Int("max_switches", maxAccountSwitches),
|
|
||||||
)
|
|
||||||
if account.Platform == service.PlatformAntigravity {
|
|
||||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// ForwardNative already wrote the response
|
// ForwardNative already wrote the response
|
||||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||||
@@ -539,7 +516,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: forceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
@@ -554,7 +531,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
reqLog.Debug("gemini.request_completed",
|
reqLog.Debug("gemini.request_completed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
zap.Int("switch_count", switchCount),
|
zap.Int("switch_count", fs.SwitchCount),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user