feat: cc/codex/gemini 增加账号重试
This commit is contained in:
@@ -129,56 +129,80 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
if err != nil {
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||
@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
type pathParseError struct{ msg string }
|
||||
|
||||
func (e *pathParseError) Error() string { return e.msg }
|
||||
|
||||
@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
|
||||
@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func sleepGeminiBackoff(attempt int) {
|
||||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||
if delay > geminiRetryMaxDelay {
|
||||
|
||||
Reference in New Issue
Block a user