feat: cc/codex support account retry
This commit is contained in:
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -127,66 +128,134 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
platform = apiKey.Group.Platform
|
platform = apiKey.Group.Platform
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择支持该模型的账号
|
|
||||||
var account *service.Account
|
|
||||||
if platform == service.PlatformGemini {
|
if platform == service.PlatformGemini {
|
||||||
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||||
} else {
|
if err != nil {
|
||||||
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
}
|
return
|
||||||
if err != nil {
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
|
||||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
|
||||||
if req.Stream {
|
|
||||||
sendMockWarmupStream(c, req.Model)
|
|
||||||
} else {
|
|
||||||
sendMockWarmupResponse(c, req.Model)
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 获取账号并发槽位
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
if err != nil {
|
if req.Stream {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
sendMockWarmupStream(c, req.Model)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
} else {
|
||||||
return
|
sendMockWarmupResponse(c, req.Model)
|
||||||
}
|
}
|
||||||
if accountReleaseFunc != nil {
|
return
|
||||||
defer accountReleaseFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转发请求
|
|
||||||
var result *service.ForwardResult
|
|
||||||
if platform == service.PlatformGemini {
|
|
||||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
} else {
|
|
||||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
|
||||||
log.Printf("Forward request failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 异步记录使用量(subscription已在函数开头获取)
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
ApiKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
// 3. 获取账号并发槽位
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
defer accountReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发请求
|
||||||
|
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
if err != nil {
|
||||||
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
|
log.Printf("Forward request failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxAccountSwitches = 3
|
||||||
|
switchCount := 0
|
||||||
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
lastFailoverStatus := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
// 选择支持该模型的账号
|
||||||
|
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||||
|
if err != nil {
|
||||||
|
if len(failedAccountIDs) == 0 {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||||
|
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||||
|
if req.Stream {
|
||||||
|
sendMockWarmupStream(c, req.Model)
|
||||||
|
} else {
|
||||||
|
sendMockWarmupResponse(c, req.Model)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 获取账号并发槽位
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发请求
|
||||||
|
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 错误响应已在Forward中处理,这里只记录日志
|
||||||
|
log.Printf("Forward request failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步记录使用量(subscription已在函数开头获取)
|
||||||
|
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Models handles listing available models
|
// Models handles listing available models
|
||||||
@@ -314,6 +383,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||||
|
switch statusCode {
|
||||||
|
case 401:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||||
|
case 403:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||||
|
case 429:
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||||
|
case 529:
|
||||||
|
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
|
||||||
|
case 500, 502, 503, 504:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||||
|
default:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
if streamStarted {
|
if streamStarted {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// Generate session hash (from header for OpenAI)
|
// Generate session hash (from header for OpenAI)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||||
|
|
||||||
// Select account supporting the requested model
|
const maxAccountSwitches = 3
|
||||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
switchCount := 0
|
||||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
if err != nil {
|
lastFailoverStatus := 0
|
||||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
|
||||||
|
|
||||||
// 3. Acquire account concurrency slot
|
for {
|
||||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
// Select account supporting the requested model
|
||||||
if err != nil {
|
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
if err != nil {
|
||||||
return
|
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||||
}
|
if len(failedAccountIDs) == 0 {
|
||||||
if accountReleaseFunc != nil {
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
defer accountReleaseFunc()
|
return
|
||||||
}
|
}
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
// Forward request
|
return
|
||||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
|
||||||
if err != nil {
|
|
||||||
// Error response already handled in Forward, just log
|
|
||||||
log.Printf("Forward request failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Async record usage
|
|
||||||
go func() {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
|
||||||
Result: result,
|
|
||||||
ApiKey: apiKey,
|
|
||||||
User: apiKey.User,
|
|
||||||
Account: account,
|
|
||||||
Subscription: subscription,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("Record usage failed: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||||
|
|
||||||
|
// 3. Acquire account concurrency slot
|
||||||
|
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward request
|
||||||
|
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastFailoverStatus = failoverErr.StatusCode
|
||||||
|
switchCount++
|
||||||
|
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Error response already handled in Forward, just log
|
||||||
|
log.Printf("Forward request failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Async record usage
|
||||||
|
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
ApiKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: usedAccount,
|
||||||
|
Subscription: subscription,
|
||||||
|
}); err != nil {
|
||||||
|
log.Printf("Record usage failed: %v", err)
|
||||||
|
}
|
||||||
|
}(result, account)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||||
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
|||||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||||
|
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||||
|
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||||
|
switch statusCode {
|
||||||
|
case 401:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||||
|
case 403:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||||
|
case 429:
|
||||||
|
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||||
|
case 529:
|
||||||
|
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||||
|
case 500, 502, 503, 504:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||||
|
default:
|
||||||
|
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||||
if streamStarted {
|
if streamStarted {
|
||||||
|
|||||||
@@ -81,6 +81,15 @@ type ForwardResult struct {
|
|||||||
FirstTokenMs *int // 首字时间(流式请求)
|
FirstTokenMs *int // 首字时间(流式请求)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||||
|
type UpstreamFailoverError struct {
|
||||||
|
StatusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UpstreamFailoverError) Error() string {
|
||||||
|
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// GatewayService handles API gateway operations
|
// GatewayService handles API gateway operations
|
||||||
type GatewayService struct {
|
type GatewayService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
@@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
|||||||
|
|
||||||
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
||||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||||
|
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
|
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 同时检查模型支持
|
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
// 同时检查模型支持
|
||||||
// 续期粘性会话
|
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
// 续期粘性会话
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
}
|
}
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 检查模型支持
|
// 检查模型支持
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i
|
|||||||
return !account.ShouldHandleErrorCode(statusCode)
|
return !account.ShouldHandleErrorCode(statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
|
||||||
|
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||||
|
switch statusCode {
|
||||||
|
case 401, 403, 429, 529:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return statusCode >= 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Forward 转发请求到Claude API
|
// Forward 转发请求到Claude API
|
||||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 处理重试耗尽的情况
|
// 处理重试耗尽的情况
|
||||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||||
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理可切换账号的错误
|
||||||
|
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
|
|
||||||
// 处理错误响应(不可重试的错误)
|
// 处理错误响应(不可重试的错误)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return s.handleErrorResponse(ctx, resp, c, account)
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
|||||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
// OAuth 403:标记账号异常
|
|
||||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
|
||||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
statusCode := resp.StatusCode
|
statusCode := resp.StatusCode
|
||||||
|
|
||||||
@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
|||||||
// API Key 未配置错误码:不标记账号状态
|
// API Key 未配置错误码:不标记账号状态
|
||||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||||
|
// OAuth 403:标记账号异常
|
||||||
|
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||||
|
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||||
|
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||||
|
|
||||||
// 返回统一的重试耗尽错误响应
|
// 返回统一的重试耗尽错误响应
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode)
|
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamingResult 流式响应结果
|
// streamingResult 流式响应结果
|
||||||
|
|||||||
@@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64
|
|||||||
|
|
||||||
// SelectAccountForModel selects an account supporting the requested model
|
// SelectAccountForModel selects an account supporting the requested model
|
||||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||||
|
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
|
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 1. Check sticky session
|
// 1. Check sticky session
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// Refresh sticky session TTL
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
// Refresh sticky session TTL
|
||||||
return account, nil
|
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
|||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
acc := &accounts[i]
|
acc := &accounts[i]
|
||||||
|
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Check model support
|
// Check model support
|
||||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||||
continue
|
continue
|
||||||
@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||||
|
switch statusCode {
|
||||||
|
case 401, 403, 429, 529:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return statusCode >= 500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
|
}
|
||||||
|
|
||||||
// Forward forwards request to OpenAI API
|
// Forward forwards request to OpenAI API
|
||||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
|
|
||||||
// Handle error response
|
// Handle error response
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
|
s.handleFailoverSideEffects(ctx, resp, account)
|
||||||
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||||
|
}
|
||||||
return s.handleErrorResponse(ctx, resp, c, account)
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user