feat: cc/codex/gemini 增加账号重试
This commit is contained in:
@@ -129,56 +129,80 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
const maxAccountSwitches = 3
|
||||||
if err != nil {
|
switchCount := 0
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
return
|
lastFailoverStatus := 0
|
||||||
}
|
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
for {
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||||
if req.Stream {
|
if err != nil {
|
||||||
sendMockWarmupStream(c, req.Model)
|
if len(failedAccountIDs) == 0 {
|
||||||
} else {
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
sendMockWarmupResponse(c, req.Model)
|
return
|
||||||
|
}
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
if err != nil {
|
if req.Stream {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
sendMockWarmupStream(c, req.Model)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
} else {
|
||||||
return
|
sendMockWarmupResponse(c, req.Model)
|
||||||
}
|
}
|
||||||
if accountReleaseFunc != nil {
|
return
|
||||||
defer accountReleaseFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转发请求
|
|
||||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
if err != nil {
|
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
|
||||||
log.Printf("Forward request failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
|
||||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
ApiKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: usedAccount,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
}
|
||||||
}(result, account)
|
|
||||||
return
|
// 3. 获取账号并发槽位
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发请求
|
||||||
|
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
|
log.Printf("Forward request failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxAccountSwitches = 3
|
const maxAccountSwitches = 3
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 3) select account (sticky session based on request body)
|
// 3) select account (sticky session based on request body)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||||
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
|
const maxAccountSwitches = 3
|
||||||
if err != nil {
|
switchCount := 0
|
||||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
return
|
lastFailoverStatus := 0
|
||||||
}
|
|
||||||
|
|
||||||
// 4) account concurrency slot
|
for {
|
||||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
if len(failedAccountIDs) == 0 {
|
||||||
return
|
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||||
}
|
return
|
||||||
if accountReleaseFunc != nil {
|
}
|
||||||
defer accountReleaseFunc()
|
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||||
}
|
return
|
||||||
|
|
||||||
// 5) forward (writes response to client)
|
|
||||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
|
||||||
if err != nil {
|
|
||||||
// ForwardNative already wrote the response
|
|
||||||
log.Printf("Gemini native forward failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 6) record usage async
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
ApiKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
// 4) account concurrency slot
|
||||||
|
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) forward (writes response to client)
|
||||||
|
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// ForwardNative already wrote the response
|
||||||
|
log.Printf("Gemini native forward failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6) record usage async
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||||
@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
|
|||||||
return "", "", &pathParseError{"invalid model action path"}
|
return "", "", &pathParseError{"invalid model action path"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||||
|
status, message := mapGeminiUpstreamError(statusCode)
|
||||||
|
googleError(c, status, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||||
|
switch statusCode {
|
||||||
|
case 401:
|
||||||
|
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||||
|
case 403:
|
||||||
|
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||||
|
case 429:
|
||||||
|
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||||
|
case 529:
|
||||||
|
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||||
|
case 500, 502, 503, 504:
|
||||||
|
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||||
|
default:
|
||||||
|
return http.StatusBadGateway, "Upstream request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type pathParseError struct{ msg string }
|
type pathParseError struct{ msg string }
|
||||||
|
|
||||||
func (e *pathParseError) Error() string { return e.msg }
|
func (e *pathParseError) Error() string { return e.msg }
|
||||||
|
|||||||
@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||||
|
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
cacheKey := "gemini:" + sessionHash
|
cacheKey := "gemini:" + sessionHash
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
return account, nil
|
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
|||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||||
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
|
|
||||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
|
||||||
|
switch statusCode {
|
||||||
|
case 401, 403, 429, 529:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return statusCode >= 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func sleepGeminiBackoff(attempt int) {
|
func sleepGeminiBackoff(attempt int) {
|
||||||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||||
if delay > geminiRetryMaxDelay {
|
if delay > geminiRetryMaxDelay {
|
||||||
|
|||||||
Reference in New Issue
Block a user