package handler

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"runtime/debug"
	"strconv"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"
	coderws "github.com/coder/websocket"
	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"go.uber.org/zap"
)

// OpenAIGatewayHandler handles OpenAI API gateway requests.
type OpenAIGatewayHandler struct {
	gatewayService          *service.OpenAIGatewayService
	billingCacheService     *service.BillingCacheService
	apiKeyService           *service.APIKeyService
	usageRecordWorkerPool   *service.UsageRecordWorkerPool
	errorPassthroughService *service.ErrorPassthroughService
	concurrencyHelper       *ConcurrencyHelper
	maxAccountSwitches      int
}

// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler.
//
// cfg may be nil; in that case the SSE ping interval defaults to zero and the
// failover account-switch limit defaults to 3. A configured
// Gateway.MaxAccountSwitches only overrides the default when it is positive.
func NewOpenAIGatewayHandler(
	gatewayService *service.OpenAIGatewayService,
	concurrencyService *service.ConcurrencyService,
	billingCacheService *service.BillingCacheService,
	apiKeyService *service.APIKeyService,
	usageRecordWorkerPool *service.UsageRecordWorkerPool,
	errorPassthroughService *service.ErrorPassthroughService,
	cfg *config.Config,
) *OpenAIGatewayHandler {
	// Defaults used when no configuration is supplied.
	var pingEvery time.Duration
	switchLimit := 3
	if cfg != nil {
		pingEvery = time.Duration(cfg.Concurrency.PingInterval) * time.Second
		if cfg.Gateway.MaxAccountSwitches > 0 {
			switchLimit = cfg.Gateway.MaxAccountSwitches
		}
	}
	return &OpenAIGatewayHandler{
		gatewayService:          gatewayService,
		billingCacheService:     billingCacheService,
		apiKeyService:           apiKeyService,
		usageRecordWorkerPool:   usageRecordWorkerPool,
		errorPassthroughService: errorPassthroughService,
		concurrencyHelper:       NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingEvery),
		maxAccountSwitches:      switchLimit,
	}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
	// Local safety net: make sure no panic inside this handler can escape to
	// the process level.
	streamStarted := false
	defer h.recoverResponsesPanic(c, &streamStarted)
	setOpenAIClientTransportHTTP(c)
	requestStart := time.Now()

	// Get apiKey and user from context (set by ApiKeyAuth middleware)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}

	// Read request body
	body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	setOpsRequestContext(c, "", false, body)

	// Validate that the request body is well-formed JSON.
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// Use gjson for read-only field extraction, avoiding a full Unmarshal.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()
	streamResult := gjson.GetBytes(body, "stream")
	// "stream", when present, must be a JSON boolean.
	if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
		return
	}
	reqStream := streamResult.Bool()
	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
	previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
	if previousResponseID != "" {
		previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
		reqLog = reqLog.With(
			zap.Bool("has_previous_response_id", true),
			zap.String("previous_response_id_kind", previousResponseIDKind),
			zap.Int("previous_response_id_len", len(previousResponseID)),
		)
		// Reject message ids early; upstream expects a response.id (resp_*).
		if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
			reqLog.Warn("openai.request_validation_failed",
				zap.String("reason", "previous_response_id_looks_like_message_id"),
			)
			h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
			return
		}
	}
	setOpsRequestContext(c, reqModel, reqStream, body)

	// Validate early that function_call_output items carry a usable
	// correlation context, avoiding an upstream 400.
	if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
		return
	}

	// Bind the error-passthrough service so the service layer can reuse the
	// rules in non-failover error scenarios.
	if h.errorPassthroughService != nil {
		service.BindErrorPassthroughService(c, h.errorPassthroughService)
	}

	// Get subscription info (may be nil)
	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
	routingStart := time.Now()
	userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
	if !acquired {
		return
	}
	// Ensure the slot is also released when the request is cancelled, so a
	// passively dropped long connection cannot leak it.
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	// 2. Re-check billing eligibility after wait
	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}

	// Generate session hash (header first; fallback to prompt_cache_key)
	sessionHash := h.gatewayService.GenerateSessionHash(c, body)
	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	failedAccountIDs := make(map[int64]struct{})
	var lastFailoverErr *service.UpstreamFailoverError
	// Failover loop: each iteration selects an account (excluding ones that
	// already failed), forwards, and either finishes or switches accounts.
	for {
		// Select account supporting the requested model
		reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
		selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
			c.Request.Context(),
			apiKey.GroupID,
			previousResponseID,
			sessionHash,
			reqModel,
			failedAccountIDs,
			service.OpenAIUpstreamTransportAny,
		)
		if err != nil {
			reqLog.Warn("openai.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			// No prior failover attempts: plain unavailability. Otherwise
			// report the exhausted failover using the last upstream error.
			if len(failedAccountIDs) == 0 {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
				return
			}
			if lastFailoverErr != nil {
				h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
			} else {
				h.handleFailoverExhaustedSimple(c, 502, streamStarted)
			}
			return
		}
		if selection == nil || selection.Account == nil {
			h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
			return
		}
		// NOTE(review): the selection/Account nil checks below are redundant
		// with the guard just above; kept as-is.
		if previousResponseID != "" && selection != nil && selection.Account != nil {
			reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
		}
		reqLog.Debug("openai.account_schedule_decision",
			zap.String("layer", scheduleDecision.Layer),
			zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
			zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
			zap.Int("candidate_count", scheduleDecision.CandidateCount),
			zap.Int("top_k", scheduleDecision.TopK),
			zap.Int64("latency_ms", scheduleDecision.LatencyMs),
			zap.Float64("load_skew", scheduleDecision.LoadSkew),
		)
		account := selection.Account
		reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
		setOpsSelectedAccount(c, account.ID, account.Platform)
		accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
		if !acquired {
			return
		}

		// Forward request
		service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
		forwardStart := time.Now()
		result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
		forwardDurationMs := time.Since(forwardStart).Milliseconds()
		// Release the account slot as soon as forwarding finishes, even on
		// error, so it is not held across usage recording or failover retry.
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}
		// Response latency = total forward time minus upstream time, when the
		// upstream latency was recorded and is smaller than the total.
		upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
		responseLatencyMs := forwardDurationMs
		if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
			responseLatencyMs = forwardDurationMs - upstreamLatencyMs
		}
		service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
		if err == nil && result != nil && result.FirstTokenMs != nil {
			service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Failover-eligible upstream error: exclude this account and
				// retry with another one, up to maxAccountSwitches times.
				h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
				h.gatewayService.RecordOpenAIAccountSwitch()
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, streamStarted)
					return
				}
				switchCount++
				reqLog.Warn("openai.upstream_failover_switching",
					zap.Int64("account_id", account.ID),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				)
				continue
			}
			// Non-failover error: ensure the client got some error response,
			// then log at warn/error depending on whether a response was
			// already (or could be) written.
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
			wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
			fields := []zap.Field{
				zap.Int64("account_id", account.ID),
				zap.Bool("fallback_error_response_written", wroteFallback),
				zap.Error(err),
			}
			if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
				reqLog.Warn("openai.forward_failed", fields...)
				return
			}
			reqLog.Error("openai.forward_failed", fields...)
			return
		}
		if result != nil {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
		} else {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
		}

		// Capture request info for async recording (avoids touching
		// gin.Context from a goroutine).
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)
		// Usage records go through a bounded worker pool so the request hot
		// path never spawns unbounded goroutines.
		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:        result,
				APIKey:        apiKey,
				User:          apiKey.User,
				Account:       account,
				Subscription:  subscription,
				UserAgent:     userAgent,
				IPAddress:     clientIP,
				APIKeyService: h.apiKeyService,
			}); err != nil {
				logger.L().With(
					zap.String("component", "handler.openai_gateway.responses"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("openai.record_usage_failed", zap.Error(err))
			}
		})
		reqLog.Debug("openai.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}

// validateFunctionCallOutputRequest pre-validates requests containing
// function_call_output items, rejecting those with no way to correlate the
// output back to a tool call. Returns true when the request may proceed;
// returns false after writing a 400 response.
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
	// Fast path: no function_call_output item present at all.
	if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
		return true
	}
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		// Keep the original tolerant semantics: skip pre-validation on parse
		// failure and rely on the later upstream validation.
		return true
	}
	// Cache the parsed body for later consumers in the service layer.
	c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
	validation := service.ValidateFunctionCallOutputContext(reqBody)
	if !validation.HasFunctionCallOutput {
		return true
	}
	// A previous_response_id or an in-request tool-call context is enough to
	// correlate the output upstream.
	previousResponseID, _ := reqBody["previous_response_id"].(string)
	if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
		return true
	}
	if validation.HasFunctionCallOutputMissingCallID {
		reqLog.Warn("openai.request_validation_failed",
			zap.String("reason", "function_call_output_missing_call_id"),
		)
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
		return false
	}
	if validation.HasItemReferenceForAllCallIDs {
		return true
	}
	reqLog.Warn("openai.request_validation_failed",
		zap.String("reason", "function_call_output_missing_item_reference"),
	)
	h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
	return false
}

// acquireResponsesUserSlot acquires a per-user concurrency slot, waiting in a
// bounded queue if the fast path fails. Returns (release, true) on success;
// on failure it writes the error response itself and returns (nil, false).
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
	c *gin.Context,
	userID int64,
	userConcurrency int,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	ctx := c.Request.Context()
	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	}
	if userAcquired {
		return wrapReleaseOnDone(ctx, userReleaseFunc), true
	}
	maxWait := service.CalculateMaxWait(userConcurrency)
	canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
	if waitErr != nil {
		reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
		// Per existing degradation semantics: on wait-counter errors, let the
		// request proceed to the slot-acquisition flow anyway.
	} else if !canWait {
		reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
		h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return nil, false
	}
	// Track whether we hold a wait-queue entry; the deferred decrement only
	// fires if the explicit decrement below did not already run.
	waitCounted := waitErr == nil && canWait
	defer func() {
		if waitCounted {
			h.concurrencyHelper.DecrementWaitCount(ctx, userID)
		}
	}()
	userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	}
	// Leave the wait queue immediately once the slot is acquired.
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(ctx, userID)
		waitCounted = false
	}
	return wrapReleaseOnDone(ctx, userReleaseFunc), true
}

// acquireResponsesAccountSlot acquires a per-account concurrency slot for the
// selected account (fast path first, then a bounded timed wait) and binds the
// sticky session on success. Returns (release, true) on success; on failure
// it writes the error response itself and returns (nil, false).
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
	c *gin.Context,
	groupID *int64,
	sessionHash string,
	selection *service.AccountSelectionResult,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	if selection == nil || selection.Account == nil {
		h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
		return nil, false
	}
	ctx := c.Request.Context()
	account := selection.Account
	// The scheduler may have already acquired the slot during selection.
	if selection.Acquired {
		return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
	}
	if selection.WaitPlan == nil {
		h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
		return nil, false
	}
	fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
		ctx,
		account.ID,
		selection.WaitPlan.MaxConcurrency,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}
	if fastAcquired {
		// Sticky-session binding failure is logged but non-fatal.
		if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
			reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		}
		return wrapReleaseOnDone(ctx, fastReleaseFunc), true
	}
	canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
	if waitErr != nil {
		// Wait-counter errors degrade to proceeding without a queue entry.
		reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
	} else if !canWait {
		reqLog.Info("openai.account_wait_queue_full",
			zap.Int64("account_id", account.ID),
			zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
		)
		h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
		return nil, false
	}
	// releaseWait is idempotent: it decrements the queue counter at most once,
	// whether called explicitly below or via the defer on an error path.
	accountWaitCounted := waitErr == nil && canWait
	releaseWait := func() {
		if accountWaitCounted {
			h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
			accountWaitCounted = false
		}
	}
	defer releaseWait()
	accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
		c,
		account.ID,
		selection.WaitPlan.MaxConcurrency,
		selection.WaitPlan.Timeout,
		reqStream,
		streamStarted,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}
	// Slot acquired: no longer waiting in queue.
	releaseWait()
	if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
		reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
	}
	return wrapReleaseOnDone(ctx, accountReleaseFunc), true
}

// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
// GET /openai/v1/responses (Upgrade: websocket)
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
	if !isOpenAIWSUpgradeRequest(c.Request) {
		h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
		return
	}
	setOpenAIClientTransportWS(c)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses_ws",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
		zap.Bool("openai_ws_mode", true),
	)
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}
	reqLog.Info("openai.websocket_ingress_started")
	clientIP := ip.GetClientIP(c)
	userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
	wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
		CompressionMode: coderws.CompressionContextTakeover,
	})
	if err != nil {
		reqLog.Warn("openai.websocket_accept_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("request_user_agent", userAgent),
			zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
			zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
			zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
			zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
		)
		return
	}
	defer func() { _ = wsConn.CloseNow() }()
	wsConn.SetReadLimit(16 * 1024 * 1024)
	ctx := c.Request.Context()
	// The first frame must arrive within 30s and carry the response.create
	// payload (model, optional previous_response_id).
	readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	msgType, firstMessage, err := wsConn.Read(readCtx)
	cancel()
	if err != nil {
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_read_first_message_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
			zap.Duration("read_timeout", 30*time.Second),
		)
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
		return
	}
	if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
		return
	}
	if !gjson.ValidBytes(firstMessage) {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
		return
	}
	reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
	if reqModel == "" {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
		return
	}
	previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
	previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
	if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
		return
	}
	reqLog = reqLog.With(
		zap.Bool("ws_ingress", true),
		zap.String("model", reqModel),
		zap.Bool("has_previous_response_id", previousResponseID != ""),
		zap.String("previous_response_id_kind", previousResponseIDKind),
	)
	setOpsRequestContext(c, reqModel, true, firstMessage)
	// Per-turn slot tracking: releaseTurnSlots releases whatever is currently
	// held (account slot first, then user slot) and nils the holders so it is
	// safe to call repeatedly.
	var currentUserRelease func()
	var currentAccountRelease func()
	releaseTurnSlots := func() {
		if currentAccountRelease != nil {
			currentAccountRelease()
			currentAccountRelease = nil
		}
		if currentUserRelease != nil {
			currentUserRelease()
			currentUserRelease = nil
		}
	}
	// Must be registered as early as possible so any early return releases
	// already-acquired concurrency slots.
	defer releaseTurnSlots()
	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
	if err != nil {
		reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
		return
	}
	if !userAcquired {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
		return
	}
	currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
		return
	}
	sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
		c,
		firstMessage,
		openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
	)
	selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
		ctx,
		apiKey.GroupID,
		previousResponseID,
		sessionHash,
		reqModel,
		nil,
		service.OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	if err != nil {
		reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}
	if selection == nil || selection.Account == nil {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}
	account := selection.Account
	// Prefer the wait plan's concurrency limit when it is set; used again for
	// slot re-acquisition on later turns.
	accountMaxConcurrency := account.Concurrency
	if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
		accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
	}
	accountReleaseFunc := selection.ReleaseFunc
	if !selection.Acquired {
		// WS ingress does not queue: only the fast path is attempted.
		if selection.WaitPlan == nil {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
			ctx,
			account.ID,
			selection.WaitPlan.MaxConcurrency,
		)
		if err != nil {
			reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
			closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
			return
		}
		if !fastAcquired {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		accountReleaseFunc = fastReleaseFunc
	}
	currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
	if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
		reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
	}
	token, _, err := h.gatewayService.GetAccessToken(ctx, account)
	if err != nil {
		reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
		return
	}
	reqLog.Debug("openai.websocket_account_selected",
		zap.Int64("account_id", account.ID),
		zap.String("account_name", account.Name),
		zap.String("schedule_layer", scheduleDecision.Layer),
		zap.Int("candidate_count", scheduleDecision.CandidateCount),
	)
	hooks := &service.OpenAIWSIngressHooks{
		BeforeTurn: func(turn int) error {
			// Turn 1 reuses the slots acquired above.
			if turn == 1 {
				return nil
			}
			// Defensive cleanup: prevent stale slots from being overwritten
			// (and thus leaked) on abnormal paths.
			releaseTurnSlots()
			// Non-first turns must re-acquire concurrency slots so an idle
			// long-lived connection does not pin them between turns.
			userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
			if err != nil {
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
			}
			if !userAcquired {
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
			}
			accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
			if err != nil {
				// Roll back the user slot before reporting failure.
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
			}
			if !accountAcquired {
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
			}
			currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
			currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
			return nil
		},
		AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
			releaseTurnSlots()
			if turnErr != nil || result == nil {
				return
			}
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
			h.submitUsageRecordTask(func(taskCtx context.Context) {
				if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
					Result:        result,
					APIKey:        apiKey,
					User:          apiKey.User,
					Account:       account,
					Subscription:  subscription,
					UserAgent:     userAgent,
					IPAddress:     clientIP,
					APIKeyService: h.apiKeyService,
				}); err != nil {
					reqLog.Error("openai.websocket_record_usage_failed",
						zap.Int64("account_id", account.ID),
						zap.String("request_id", result.RequestID),
						zap.Error(err),
					)
				}
			})
		},
	}
	if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
		h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_proxy_failed",
			zap.Int64("account_id", account.ID),
			zap.Error(err),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
		)
		// Propagate an explicit close status/reason when the service layer
		// produced one; otherwise fall back to a generic internal error.
		var closeErr *service.OpenAIWSClientCloseError
		if errors.As(err, &closeErr) {
			closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
			return
		}
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
		return
	}
	reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}

// recoverResponsesPanic is the deferred panic barrier for Responses: it
// recovers, writes a fallback error response when possible, and logs the
// panic with its stack trace.
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
	recovered := recover()
	if recovered == nil {
		return
	}
	started := false
	if streamStarted != nil {
		started = *streamStarted
	}
	wroteFallback := h.ensureForwardErrorResponse(c, started)
	requestLogger(c, "handler.openai_gateway.responses").Error(
		"openai.responses_panic_recovered",
		zap.Bool("fallback_error_response_written", wroteFallback),
		zap.Any("panic", recovered),
		zap.ByteString("stack", debug.Stack()),
	)
}

// ensureResponsesDependencies verifies all required collaborators are wired.
// On a missing dependency it logs, writes a 503 (if nothing was written yet),
// and returns false.
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
	missing := h.missingResponsesDependencies()
	if len(missing) == 0 {
		return true
	}
	if reqLog == nil {
		reqLog = requestLogger(c, "handler.openai_gateway.responses")
	}
	reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
	if c != nil && c.Writer != nil && !c.Writer.Written() {
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"error": gin.H{
				"type":    "api_error",
				"message": "Service temporarily unavailable",
			},
		})
	}
	return false
}

// missingResponsesDependencies returns the names of required collaborators
// that are nil (or "handler" when the receiver itself is nil).
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
	missing := make([]string, 0, 5)
	if h == nil {
		return append(missing, "handler")
	}
	if h.gatewayService == nil {
		missing = append(missing, "gatewayService")
	}
	if h.billingCacheService == nil {
		missing = append(missing, "billingCacheService")
	}
	if h.apiKeyService == nil {
		missing = append(missing, "apiKeyService")
	}
	if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
		missing = append(missing, "concurrencyHelper")
	}
	return missing
}

// getContextInt64 reads a gin context value as int64, coercing the common
// numeric types stored by middleware. Returns (0, false) when absent or of an
// unsupported type.
func getContextInt64(c *gin.Context, key string) (int64, bool) {
	if c == nil || key == "" {
		return 0, false
	}
	v, ok := c.Get(key)
	if !ok {
		return 0, false
	}
	switch t := v.(type) {
	case int64:
		return t, true
	case int:
		return int64(t), true
	case int32:
		return int64(t), true
	case float64:
		return int64(t), true
	default:
		return 0, false
	}
}

// submitUsageRecordTask runs a usage-record task on the bounded worker pool,
// or synchronously (with a 10s timeout and panic recovery) when no pool was
// injected.
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if h.usageRecordWorkerPool != nil {
		h.usageRecordWorkerPool.Submit(task)
		return
	}
	// Fallback path: execute synchronously when the worker pool is not
	// injected, rather than reverting to unbounded goroutines.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	defer func() {
		if recovered := recover(); recovered != nil {
			logger.L().With(
				zap.String("component", "handler.openai_gateway.responses"),
				zap.Any("panic", recovered),
			).Error("openai.usage_record_task_panic_recovered")
		}
	}()
	task(ctx)
}

// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
	h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}

// handleFailoverExhausted writes the final error once all failover attempts
// are spent, consulting the error-passthrough rules before falling back to
// the default upstream-status mapping.
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
	statusCode := failoverErr.StatusCode
	responseBody := failoverErr.ResponseBody
	// Check passthrough rules first.
	if h.errorPassthroughService != nil && len(responseBody) > 0 {
		if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
			// Decide the response status code.
			respCode := statusCode
			if !rule.PassthroughCode && rule.ResponseCode != nil {
				respCode = *rule.ResponseCode
			}
			// Decide the response message.
			msg := service.ExtractUpstreamErrorMessage(responseBody)
			if !rule.PassthroughBody && rule.CustomMessage != nil {
				msg = *rule.CustomMessage
			}
			if rule.SkipMonitoring {
				c.Set(service.OpsSkipPassthroughKey, true)
			}
			h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
			return
		}
	}
	// Fall back to the default error mapping.
	status, errType, errMsg := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

// handleFailoverExhaustedSimple is the simplified variant for cases without a
// captured upstream response body.
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
	status, errType, errMsg := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

// mapUpstreamError maps an upstream HTTP status to the (status, error type,
// message) triple returned to the client.
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	switch statusCode {
	case 401:
		return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
	case 403:
		return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
	case 429:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529:
		return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
	case 500, 502, 503, 504:
		return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
	default:
		return http.StatusBadGateway, "upstream_error", "Upstream request failed"
	}
}

// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
	if streamStarted {
		// Stream already started, send error as SSE event then close
		flusher, ok := c.Writer.(http.Flusher)
		if ok {
			// The SSE error event has a fixed schema; direct concatenation
			// with strconv.Quote avoids an extra Marshal allocation.
			errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
			if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
				_ = c.Error(err)
			}
			flusher.Flush()
		}
		return
	}
	// Normal case: return JSON response with proper status code
	h.errorResponse(c, status, errType, message)
}

// ensureForwardErrorResponse writes a unified error response when Forward
// returned an error but nothing was written yet. Returns true if it wrote
// the fallback response.
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
	if c == nil || c.Writer == nil || c.Writer.Written() {
		return false
	}
	h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
	return true
}

// shouldLogOpenAIForwardFailureAsWarn reports whether a forward failure should
// be logged at warn level: only when no fallback was needed because a response
// had already been written (the client was informed mid-stream).
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
	if wroteFallback {
		return false
	}
	if c == nil || c.Writer == nil {
		return false
	}
	return c.Writer.Written()
}

// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
	c.JSON(status, gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	})
}

// setOpenAIClientTransportHTTP tags the request context as plain HTTP ingress.
func setOpenAIClientTransportHTTP(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}

// setOpenAIClientTransportWS tags the request context as WebSocket ingress.
func setOpenAIClientTransportWS(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}

// openAIWSIngressFallbackSessionSeed builds a deterministic session-hash seed
// for WS ingress when the payload carries no usable session key. A nil
// groupID is encoded as 0.
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
	gid := int64(0)
	if groupID != nil {
		gid = *groupID
	}
	return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}

// isOpenAIWSUpgradeRequest reports whether the request carries the WebSocket
// upgrade headers (Upgrade: websocket plus Connection containing "upgrade").
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
	if r == nil {
		return false
	}
	if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
		return false
	}
	return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
}

// closeOpenAIClientWS closes the client connection with the given status and
// a trimmed reason, then force-closes the underlying connection.
// NOTE(review): the 120-byte truncation is byte-wise and could split a UTF-8
// rune for non-ASCII reasons — all current call sites pass ASCII.
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
	if conn == nil {
		return
	}
	reason = strings.TrimSpace(reason)
	if len(reason) > 120 {
		reason = reason[:120]
	}
	_ = conn.Close(status, reason)
	_ = conn.CloseNow()
}

// summarizeWSCloseErrorForLog extracts a loggable close status string and
// reason from a WebSocket error; returns ("-", "-") when no close status is
// present.
func summarizeWSCloseErrorForLog(err error) (string, string) {
	if err == nil {
		return "-", "-"
	}
	statusCode := coderws.CloseStatus(err)
	// CloseStatus returns -1 when err is not a close error.
	if statusCode == -1 {
		return "-", "-"
	}
	closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
	closeReason := "-"
	var closeErr coderws.CloseError
	if errors.As(err, &closeErr) {
		reason := strings.TrimSpace(closeErr.Reason)
		if reason != "" {
			closeReason = reason
		}
	}
	return closeStatus, closeReason
}