diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9d2e2e8a..a0a4f05e 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -129,56 +129,80 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } if platform == service.PlatformGemini { - account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) - if err != nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if req.Stream { - sendMockWarmupStream(c, req.Model) - } else { - sendMockWarmupResponse(c, req.Model) + for { + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs) + if err != nil { + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - return - } - // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountReleaseFunc != nil { - defer accountReleaseFunc() - } - - // 转发请求 - result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) - if err != nil { - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) - return - } - - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - ApiKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) + // 检查预热请求拦截(在账号选择后、转发前检查) + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if req.Stream { + sendMockWarmupStream(c, req.Model) + } else { + sendMockWarmupResponse(c, req.Model) + } + return } - }(result, account) - return + + // 3. 获取账号并发槽位 + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + + // 转发请求 + result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Forward request failed: %v", err) + return + } + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return + } } const maxAccountSwitches = 3 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 6a9e2e15..53625669 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "errors" "io" "log" "net/http" @@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) sessionHash := h.gatewayService.GenerateSessionHash(body) - account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName) - if err != nil { - googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) - return - } + const maxAccountSwitches = 3 + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 - // 4) account concurrency slot - accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) - if err != nil { - googleError(c, http.StatusTooManyRequests, err.Error()) - return - } - if accountReleaseFunc != nil { - defer accountReleaseFunc() - } - - // 5) forward (writes response to client) - result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) - if err != nil { - // ForwardNative already wrote the response - log.Printf("Gemini native forward failed: %v", err) - return - } - - // 6) record usage async - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - ApiKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - }); err != nil { - log.Printf("Record usage failed: %v", err) + for { + account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs) + if err != nil { + if len(failedAccountIDs) == 0 { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) + return + } + handleGeminiFailoverExhausted(c, lastFailoverStatus) + return } - }() + + // 4) account concurrency slot + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + + // 5) forward (writes response to client) + result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + handleGeminiFailoverExhausted(c, lastFailoverStatus) + return + } + lastFailoverStatus = failoverErr.StatusCode + switchCount++ + log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // ForwardNative already wrote the response + log.Printf("Gemini native forward failed: %v", err) + return + } + + // 6) record usage async + go func(result *service.ForwardResult, usedAccount *service.Account) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account) + return + } } func parseGeminiModelAction(rest string) (model string, action string, err error) { @@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error return "", "", &pathParseError{"invalid model action path"} } +func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { + status, message := mapGeminiUpstreamError(statusCode) + googleError(c, status, message) +} + +func mapGeminiUpstreamError(statusCode int) (int, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "Upstream request failed" + } +} + type pathParseError struct{ msg string } func (e *pathParseError) Error() string { return e.msg } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index e2462f3a..c4a474c1 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { } func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { cacheKey := "gemini:" + sessionHash if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) if err == nil && accountID > 0 { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) - return account, nil + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + return account, nil + } } } } @@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, var selected *Account for i := range accounts { acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } @@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) } @@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + respBody = unwrapIfNeeded(isOAuth, respBody) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac } } +func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + func sleepGeminiBackoff(attempt int) { delay := geminiRetryBaseDelay * time.Duration(1< geminiRetryMaxDelay {