feat: cc/codex/gemini 增加账号重试

This commit is contained in:
daodao97
2025-12-27 12:27:47 +08:00
parent 95583fce83
commit f0f920e49f
3 changed files with 183 additions and 86 deletions

View File

@@ -129,56 +129,80 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) const maxAccountSwitches = 3
if err != nil { switchCount := 0
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) failedAccountIDs := make(map[int64]struct{})
return lastFailoverStatus := 0
}
// 检查预热请求拦截(在账号选择后、转发前检查) for {
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
if req.Stream { if err != nil {
sendMockWarmupStream(c, req.Model) if len(failedAccountIDs) == 0 {
} else { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
sendMockWarmupResponse(c, req.Model) return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
} }
return
}
// 3. 获取账号并发槽位 // 检查预热请求拦截(在账号选择后、转发前检查)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if err != nil { if req.Stream {
log.Printf("Account concurrency acquire failed: %v", err) sendMockWarmupStream(c, req.Model)
h.handleConcurrencyError(c, err, "account", streamStarted) } else {
return sendMockWarmupResponse(c, req.Model)
} }
if accountReleaseFunc != nil { return
defer accountReleaseFunc()
}
// 转发请求
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
} }
}(result, account)
return // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
} }
const maxAccountSwitches = 3 const maxAccountSwitches = 3

View File

@@ -2,6 +2,7 @@ package handler
import ( import (
"context" "context"
"errors"
"io" "io"
"log" "log"
"net/http" "net/http"
@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body) sessionHash := h.gatewayService.GenerateSessionHash(body)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName) const maxAccountSwitches = 3
if err != nil { switchCount := 0
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) failedAccountIDs := make(map[int64]struct{})
return lastFailoverStatus := 0
}
// 4) account concurrency slot for {
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
if err != nil { if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error()) if len(failedAccountIDs) == 0 {
return googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
} return
if accountReleaseFunc != nil { }
defer accountReleaseFunc() handleGeminiFailoverExhausted(c, lastFailoverStatus)
} return
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if err != nil {
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
} }
}()
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
} }
func parseGeminiModelAction(rest string) (model string, action string, err error) { func parseGeminiModelAction(rest string) (model string, action string, err error) {
@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
return "", "", &pathParseError{"invalid model action path"} return "", "", &pathParseError{"invalid model action path"}
} }
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)
}
func mapGeminiUpstreamError(statusCode int) (int, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "Upstream request failed"
}
}
type pathParseError struct{ msg string } type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg } func (e *pathParseError) Error() string { return e.msg }

View File

@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
} }
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
cacheKey := "gemini:" + sessionHash cacheKey := "gemini:" + sessionHash
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) if _, excluded := excludedIDs[accountID]; !excluded {
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { account, err := s.accountRepo.GetByID(ctx, accountID)
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
return account, nil _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
} }
} }
} }
@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
} }
@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
respBody = unwrapIfNeeded(isOAuth, respBody) respBody = unwrapIfNeeded(isOAuth, respBody)
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if contentType == "" { if contentType == "" {
@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac
} }
} }
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func sleepGeminiBackoff(attempt int) { func sleepGeminiBackoff(attempt int) {
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1)) delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
if delay > geminiRetryMaxDelay { if delay > geminiRetryMaxDelay {