diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index de3cbad9..7d1eab28 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -88,6 +88,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + setOpsRequestContext(c, "", false, body) + parsedReq, err := service.ParseGatewayRequest(body) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") @@ -96,6 +98,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqModel := parsedReq.Model reqStream := parsedReq.Stream + setOpsRequestContext(c, reqModel, reqStream, body) + // 验证 model 必填 if reqModel == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") @@ -111,6 +115,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 0. 检查wait队列是否已满 maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false if err != nil { log.Printf("Increment wait count failed: %v", err) // On error, allow request to proceed @@ -118,8 +123,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") return } - // 确保在函数退出时减少wait计数 - defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + if err == nil && canWait { + waitCounted = true + } + // Ensure we decrement if we exit before acquiring the user slot. + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() // 1. 首先获取用户并发槽位 userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) @@ -128,6 +140,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // User slot acquired: no longer waiting in the queue. + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } // 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏 userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { @@ -174,6 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } account := selection.Account + setOpsSelectedAccount(c, account.ID) // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { @@ -190,12 +208,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 3. 获取账号并发槽位 accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() if !selection.Acquired { if selection.WaitPlan == nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } + accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { log.Printf("Increment account wait count failed: %v", err) @@ -203,12 +221,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { + } + if err == nil && canWait { + accountWaitCounted = true + } + // Ensure the wait counter is decremented if we exit before acquiring the slot. + defer func() { + if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) } - } + }() accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -219,20 +241,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) return } + // Slot acquired: no longer waiting in queue. + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult @@ -244,9 +267,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -301,6 +321,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } account := selection.Account + setOpsSelectedAccount(c, account.ID) // 检查预热请求拦截(在账号选择后、转发前检查) if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { @@ -317,12 +338,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 3. 获取账号并发槽位 accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() if !selection.Acquired { if selection.WaitPlan == nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } + accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { log.Printf("Increment account wait count failed: %v", err) @@ -330,12 +351,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) } - } + }() accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -346,20 +370,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) return } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult @@ -371,9 +395,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -672,6 +693,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + setOpsRequestContext(c, "", false, body) + parsedReq, err := service.ParseGatewayRequest(body) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") @@ -684,6 +707,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) + // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) @@ -704,6 +729,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } + setOpsSelectedAccount(c, account.ID) // 转发请求(不记录使用量) if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index aaf651e9..73550575 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -161,6 +161,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } + setOpsRequestContext(c, modelName, stream, body) + // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) @@ -170,13 +172,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 0) wait queue check maxWait := service.CalculateMaxWait(authSubject.Concurrency) canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) + waitCounted := false if err != nil { log.Printf("Increment wait count failed: %v", err) } else if !canWait { googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } - defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID) + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID) + } + }() // 1) user concurrency slot streamStarted := false @@ -185,6 +195,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusTooManyRequests, err.Error()) return } + if waitCounted { + geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID) + waitCounted = false + } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { @@ -221,15 +235,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } account := selection.Account + setOpsSelectedAccount(c, account.ID) // 4) account concurrency slot accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() if !selection.Acquired { if selection.WaitPlan == nil { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") return } + accountWaitCounted := false canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { log.Printf("Increment account wait count failed: %v", err) @@ -237,12 +252,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) } - } + }() accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout( c, @@ -253,19 +271,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { &streamStarted, ) if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } googleError(c, http.StatusTooManyRequests, err.Error()) return } + if accountWaitCounted { + geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 5) forward (根据平台分流) var result *service.ForwardResult @@ -277,9 +295,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 817b71d3..030ebd68 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -17,6 +17,7 @@ type AdminHandlers struct { Proxy *admin.ProxyHandler Redeem *admin.RedeemHandler Setting *admin.SettingHandler + Ops *admin.OpsHandler System *admin.SystemHandler Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 04d268a5..2ddf77ed 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -75,6 +75,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } + setOpsRequestContext(c, "", false, body) + // Parse request body to map for potential modification var reqBody map[string]any if err := json.Unmarshal(body, &reqBody); err != nil { @@ -104,6 +106,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } + setOpsRequestContext(c, reqModel, reqStream, body) + // Track if we've started streaming (for error handling) streamStarted := false @@ -113,6 +117,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 0. Check if wait queue is full maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false if err != nil { log.Printf("Increment wait count failed: %v", err) // On error, allow request to proceed @@ -120,8 +125,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") return } - // Ensure wait count is decremented when function exits - defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() // 1. First acquire user concurrency slot userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) @@ -130,6 +141,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // User slot acquired: no longer waiting. + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { @@ -167,15 +183,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } account := selection.Account log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) + setOpsSelectedAccount(c, account.ID) // 3. Acquire account concurrency slot accountReleaseFunc := selection.ReleaseFunc - var accountWaitRelease func() if !selection.Acquired { if selection.WaitPlan == nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } + accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { log.Printf("Increment account wait count failed: %v", err) @@ -183,12 +200,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { log.Printf("Account wait queue full: account=%d", account.ID) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return - } else { - // Only set release function if increment succeeded - accountWaitRelease = func() { + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) } - } + }() accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -199,29 +219,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { &streamStarted, ) if err != nil { - if accountWaitRelease != nil { - accountWaitRelease() - } log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) return } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // Forward request result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) if accountReleaseFunc != nil { accountReleaseFunc() } - if accountWaitRelease != nil { - accountWaitRelease() - } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) {