refactor: extract failover error handling into FailoverState
- Extract duplicated failover logic from gateway_handler.go (3 places) and gemini_v1beta_handler.go into shared failover_loop.go - Introduce FailoverState with HandleFailoverError and HandleSelectionExhausted - Move helper functions (needForceCacheBilling, sleepWithContext) into failover_loop.go - Add comprehensive unit tests (32+ test cases) - Delete redundant gateway_handler_single_account_retry_test.go
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -257,12 +256,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
@@ -272,35 +266,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -376,8 +363,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
@@ -390,45 +377,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
@@ -453,7 +411,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -486,45 +444,33 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
for {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int) // 同账号重试计数
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession)
|
||||
retryWithFallback := false
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
reqLog.Warn("gateway.single_account_retrying",
|
||||
zap.Int("retry_count", switchCount),
|
||||
zap.Int("max_retries", maxAccountSwitches),
|
||||
)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
if fs.LastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
@@ -600,8 +546,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
@@ -657,45 +603,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
lastFailoverErr = failoverErr
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
|
||||
account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries)
|
||||
if !sleepSameAccountRetryDelay(c.Request.Context()) {
|
||||
return
|
||||
}
|
||||
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁并切换账号
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr)
|
||||
}
|
||||
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
case FailoverExhausted:
|
||||
h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("gateway.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
@@ -720,7 +637,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: forceCacheBilling,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
@@ -735,7 +652,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
})
|
||||
reqLog.Debug("gateway.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("switch_count", fs.SwitchCount),
|
||||
zap.Bool("fallback_used", fallbackUsed),
|
||||
)
|
||||
return
|
||||
@@ -982,69 +899,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。
|
||||
func sleepSameAccountRetryDelay(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s…
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
|
||||
delay := time.Duration(switchCount-1) * time.Second
|
||||
if delay <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。
|
||||
// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用,
|
||||
// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试
|
||||
// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。
|
||||
// 返回 false 表示 context 已取消。
|
||||
func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool {
|
||||
// 固定短延时:2s
|
||||
// Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
const delay = 2 * time.Second
|
||||
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.failover"),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Int("retry_count", retryCount),
|
||||
).Info("gateway.single_account_backoff_waiting")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
Reference in New Issue
Block a user