feat: cc/codex support account retry

This commit is contained in:
daodao97
2025-12-27 11:44:00 +08:00
parent 254f12543c
commit 95583fce83
4 changed files with 330 additions and 115 deletions

View File

@@ -81,6 +81,15 @@ type ForwardResult struct {
FirstTokenMs *int // 首字时间(流式请求)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
@@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
return account, nil
}
}
}
@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查模型支持
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i
return !account.ShouldHandleErrorCode(statusCode)
}
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account)
@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
}
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
},
})
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode)
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
}
// streamingResult 流式响应结果

View File

@@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64
// SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
}
}
@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
}
}
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account)
}