feat: support account retry for cc/codex
@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -127,66 +128,134 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
platform = apiKey.Group.Platform
}

// Select account supporting the requested model
var account *service.Account
if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}

// Check warmup request interception (after account selection, before forwarding)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
return
}

// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}

// Forward request
var result *service.ForwardResult
if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
}
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}

// Async record usage (subscription already fetched at the start of the function)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
// Check warmup request interception (after account selection, before forwarding)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}
}()

// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}

// Forward request
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}

// Async record usage (subscription already fetched at the start of the function)
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}

const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0

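// Retry loop: on an upstream failover error, exclude the failed account and select another, up to maxAccountSwitches switches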
for {
// Select account supporting the requested model
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}

// Check warmup request interception (after account selection, before forwarding)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
} else {
sendMockWarmupResponse(c, req.Model)
}
return
}

// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}

// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
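// Release the account slot as soon as Forward returns (not deferred), so it is not held across a retry with another account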
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}

// Async record usage (subscription already fetched at the start of the function)
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}

// Models handles listing available models
@@ -314,6 +383,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}

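// handleFailoverExhausted returns the last upstream failure to the client once no further account switch will be attempted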
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

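// mapUpstreamError maps an upstream status code to the client-facing HTTP status, error type and message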
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}

// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {

@@ -3,6 +3,7 @@ package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Generate session hash (from header for OpenAI)
sessionHash := h.gatewayService.GenerateSessionHash(c)

// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0

// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}

// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}

// Async record usage
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
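// Retry loop: on an upstream failover error, exclude the failed account and select another, up to maxAccountSwitches switches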
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
}()
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)

// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}

// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
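// Release the account slot as soon as Forward returns (not deferred), so it is not held across a retry with another account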
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// Error response already handled in Forward, just log
log.Printf("Forward request failed: %v", err)
return
}

// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}

// handleConcurrencyError handles concurrency-related errors with proper 429 response
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}

func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}

// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {