This reverts commit 8d252303fc.
This commit is contained in:
@@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
} else if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if platform == service.PlatformGemini && sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
@@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
@@ -178,46 +170,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
@@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
@@ -302,46 +252,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
@@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
|
||||
@@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
|
||||
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait count for an account
|
||||
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait count for an account
|
||||
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
@@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
||||
defer cancel()
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
@@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
|
||||
@@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -216,48 +212,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
stream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) forward (根据平台分流)
|
||||
@@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
@@ -156,50 +156,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward request
|
||||
@@ -207,9 +171,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
|
||||
Reference in New Issue
Block a user