feat: cc/codex support account retry

This commit is contained in:
daodao97
2025-12-27 11:44:00 +08:00
parent 254f12543c
commit 95583fce83
4 changed files with 330 additions and 115 deletions

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -127,13 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
// 选择支持该模型的账号
var account *service.Account
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
}
if err != nil { if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
@@ -161,12 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 转发请求 // 转发请求
var result *service.ForwardResult result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if err != nil { if err != nil {
// 错误响应已在Forward中处理这里只记录日志 // 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err) log.Printf("Forward request failed: %v", err)
@@ -174,19 +165,97 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 异步记录使用量subscription已在函数开头获取 // 异步记录使用量subscription已在函数开头获取
go func() { go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}() }(result, account)
return
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
} }
// Models handles listing available models // Models handles listing available models
@@ -314,6 +383,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
// handleStreamingAwareError handles errors that may occur after streaming has started // handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -127,14 +128,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (from header for OpenAI) // Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c) sessionHash := h.gatewayService.GenerateSessionHash(c)
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
// Select account supporting the requested model // Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel) account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return return
} }
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
@@ -144,32 +155,47 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// Forward request // Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// Error response already handled in Forward, just log // Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err) log.Printf("Forward request failed: %v", err)
return return
} }
// Async record usage // Async record usage
go func() { go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}() }(result, account)
return
}
} }
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
} }
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
// handleStreamingAwareError handles errors that may occur after streaming has started // handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted { if streamStarted {

View File

@@ -81,6 +81,15 @@ type ForwardResult struct {
FirstTokenMs *int // 首字时间(流式请求) FirstTokenMs *int // 首字时间(流式请求)
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations // GatewayService handles API gateway operations
type GatewayService struct { type GatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository
@@ -274,10 +283,16 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中 // 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持 // 同时检查模型支持
@@ -290,6 +305,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
} }
} }
} }
}
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
var accounts []Account var accounts []Account
@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查模型支持 // 检查模型支持
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i
return !account.ShouldHandleErrorCode(statusCode) return !account.ShouldHandleErrorCode(statusCode)
} }
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理重试耗尽的情况 // 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account) return s.handleRetryExhaustedError(ctx, resp, c, account)
} }
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
} }
// handleRetryExhaustedError 处理重试耗尽后的错误 func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode statusCode := resp.StatusCode
@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// API Key 未配置错误码:不标记账号状态 // API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
} }
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
// 返回统一的重试耗尽错误响应 // 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{ c.JSON(http.StatusBadGateway, gin.H{
@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
}, },
}) })
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode) return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
} }
// streamingResult 流式响应结果 // streamingResult 流式响应结果

View File

@@ -129,10 +129,16 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64
// SelectAccountForModel selects an account supporting the requested model // SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session // 1. Check sticky session
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL // Refresh sticky session TTL
@@ -141,6 +147,7 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
} }
} }
}
// 2. Get schedulable OpenAI accounts // 2. Get schedulable OpenAI accounts
var accounts []Account var accounts []Account
@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Check model support // Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
} }
} }
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// Forward forwards request to OpenAI API // Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now() startTime := time.Now()
@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle error response // Handle error response
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }