diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index afb1c572..a0a4f05e 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "errors" "fmt" "io" "log" @@ -127,66 +128,158 @@ func (h *GatewayHandler) Messages(c *gin.Context) { platform = apiKey.Group.Platform } - // 选择支持该模型的账号 - var account *service.Account if platform == service.PlatformGemini { - account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) - } else { - account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) - } - if err != nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) - } else { - sendMockWarmupResponse(c, req.Model) + for { + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + if err != nil { + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + + // 检查预热请求拦截(在账号选择后、转发前检查) + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if req.Stream { + sendMockWarmupStream(c, req.Model) + } else { + sendMockWarmupResponse(c, req.Model) + } + return + } + + // 3. 获取账号并发槽位 + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + + // 转发请求 + result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Forward request failed: %v", err) + return + } + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return } - return } - // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountReleaseFunc != nil { - defer accountReleaseFunc() - } + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 转发请求 - var result *service.ForwardResult - if platform == service.PlatformGemini { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) - } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) - } - if err != nil { - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) - return - } - - // 异步记录使用量(subscription已在函数开头获取) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - ApiKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) + for { + // 选择支持该模型的账号 + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + if err != nil { + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - }() + + // 检查预热请求拦截(在账号选择后、转发前检查) + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if req.Stream { + sendMockWarmupStream(c, req.Model) + } else { + sendMockWarmupResponse(c, req.Model) + } + return + } + + // 3. 获取账号并发槽位 + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + + // 转发请求 + result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Forward request failed: %v", err) + return + } + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return + } } // Models handles listing available models @@ -314,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } +func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + // handleStreamingAwareError handles errors that may occur after streaming has started func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 6a9e2e15..53625669 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "errors" "io" "log" "net/http" @@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) sessionHash := h.gatewayService.GenerateSessionHash(body) - account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName) - if err != nil { - googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) - return - } + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 4) account concurrency slot - accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) - if err != nil { - googleError(c, http.StatusTooManyRequests, err.Error()) - return - } - if accountReleaseFunc != nil { - defer accountReleaseFunc() - } - - // 5) forward (writes response to client) - result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) - if err != nil { - // ForwardNative already wrote the response - log.Printf("Gemini native forward failed: %v", err) - return - } - - // 6) record usage async - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - ApiKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) + for { + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) + if err != nil { + if len(failedAccountIDs) == 0 { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) + return + } + handleGeminiFailoverExhausted(c, lastFailoverStatus) + return } - }() + + // 4) account concurrency slot + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + + // 5) forward (writes response to client) + result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + handleGeminiFailoverExhausted(c, lastFailoverStatus) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // ForwardNative already wrote the response + log.Printf("Gemini native forward failed: %v", err) + return + } + + // 6) record usage async + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return + } } func parseGeminiModelAction(rest string) (model string, action string, err error) { @@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error return "", "", &pathParseError{"invalid model action path"} } +func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { + status, message := mapGeminiUpstreamError(statusCode) + googleError(c, status, message) +} + +func mapGeminiUpstreamError(statusCode int) (int, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "Upstream request failed" + } +} + type pathParseError struct{ msg string } func (e *pathParseError) Error() string { return e.msg } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index b082d727..2dee9ccd 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "errors" "fmt" "io" "log" @@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Generate session hash (from header for OpenAI) sessionHash := h.gatewayService.GenerateSessionHash(c) - // Select account supporting the requested model - log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel) - if err != nil { - log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 3. Acquire account concurrency slot - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountReleaseFunc != nil { - defer accountReleaseFunc() - } - - // Forward request - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) - if err != nil { - // Error response already handled in Forward, just log - log.Printf("Forward request failed: %v", err) - return - } - - // Async record usage - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - ApiKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) + for { + // Select account supporting the requested model + log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) + account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + if err != nil { + log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - }() + log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) + + // 3. Acquire account concurrency slot + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + + // Forward request + result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // Error response already handled in Forward, just log + log.Printf("Forward request failed: %v", err) + return + } + + // Async record usage + go func(result *service.OpenAIForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return + } } // handleConcurrencyError handles concurrency-related errors with proper 429 response @@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } +func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + // handleStreamingAwareError handles errors that may occur after streaming has started func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d25bb314..bda31a7d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -81,6 +81,15 @@ type ForwardResult struct { FirstTokenMs *int // 首字时间(流式请求) } +// UpstreamFailoverError indicates an upstream error that should trigger account failover. +type UpstreamFailoverError struct { + StatusCode int +} + +func (e *UpstreamFailoverError) Error() string { + return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) +} + // GatewayService handles API gateway operations type GatewayService struct { accountRepo AccountRepository @@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 { - account, err := s.accountRepo.GetByID(ctx, accountID) - // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 - // 同时检查模型支持 - if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - // 续期粘性会话 - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 + // 同时检查模型支持 + if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + // 续期粘性会话 + if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + return account, nil } - return account, nil } } } @@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int var selected *Account for i := range accounts { acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } // 检查模型支持 if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue @@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i return !account.ShouldHandleErrorCode(statusCode) } +// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. +func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { startTime := time.Now() @@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理重试耗尽的情况 if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + s.handleRetryExhaustedSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } return s.handleRetryExhaustedError(ctx, resp, c, account) } + // 处理可切换账号的错误 + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { return s.handleErrorResponse(ctx, resp, c, account) @@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) } -// handleRetryExhaustedError 处理重试耗尽后的错误 -// OAuth 403:标记账号异常 -// API Key 未配置错误码:仅返回错误,不标记账号 -func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { +func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(resp.Body) statusCode := resp.StatusCode @@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht // API Key 未配置错误码:不标记账号状态 log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) } +} + +func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(resp.Body) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +// handleRetryExhaustedError 处理重试耗尽后的错误 +// OAuth 403:标记账号异常 +// API Key 未配置错误码:仅返回错误,不标记账号 +func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + s.handleRetryExhaustedSideEffects(ctx, resp, account) // 返回统一的重试耗尽错误响应 c.JSON(http.StatusBadGateway, gin.H{ @@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht }, }) - return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode) + return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) } // streamingResult 流式响应结果 diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index e2462f3a..c4a474c1 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { } func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { cacheKey := "gemini:" + sessionHash if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) if err == nil && accountID > 0 { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) - return account, nil + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + return account, nil + } } } } @@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, var selected *Account for i := range accounts { acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } @@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) } @@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + respBody = unwrapIfNeeded(isOAuth, respBody) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac } } +func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + func sleepGeminiBackoff(attempt int) { delay := geminiRetryBaseDelay * time.Duration(1< geminiRetryMaxDelay { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f57d361b..7900ff3e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64 // SelectAccountForModel selects an account supporting the requested model func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 1. Check sticky session if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) if err == nil && accountID > 0 { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - // Refresh sticky session TTL - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) - return account, nil + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + // Refresh sticky session TTL + _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + return account, nil + } } } } @@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI var selected *Account for i := range accounts { acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } // Check model support if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue @@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco } } +func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(resp.Body) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + // Forward forwards request to OpenAI API func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() @@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Handle error response if resp.StatusCode >= 400 { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } return s.handleErrorResponse(ctx, resp, c, account) }