// Package handler contains HTTP handlers for the gateway API endpoints.
package handler

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
	"github.com/Wei-Shaw/sub2api/internal/service"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"go.uber.org/zap"
)

// OpenAIGatewayHandler handles OpenAI API gateway requests.
type OpenAIGatewayHandler struct {
	gatewayService          *service.OpenAIGatewayService
	billingCacheService     *service.BillingCacheService
	apiKeyService           *service.APIKeyService
	usageRecordWorkerPool   *service.UsageRecordWorkerPool
	errorPassthroughService *service.ErrorPassthroughService
	concurrencyHelper       *ConcurrencyHelper
	// maxAccountSwitches caps how many times a single request may fail over
	// to another upstream account before giving up.
	maxAccountSwitches int
}

// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler.
//
// cfg may be nil; in that case the SSE ping interval defaults to 0 and the
// account-switch limit defaults to 3.
func NewOpenAIGatewayHandler(
	gatewayService *service.OpenAIGatewayService,
	concurrencyService *service.ConcurrencyService,
	billingCacheService *service.BillingCacheService,
	apiKeyService *service.APIKeyService,
	usageRecordWorkerPool *service.UsageRecordWorkerPool,
	errorPassthroughService *service.ErrorPassthroughService,
	cfg *config.Config,
) *OpenAIGatewayHandler {
	pingInterval := time.Duration(0)
	maxAccountSwitches := 3
	if cfg != nil {
		pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
		if cfg.Gateway.MaxAccountSwitches > 0 {
			maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
		}
	}
	return &OpenAIGatewayHandler{
		gatewayService:          gatewayService,
		billingCacheService:     billingCacheService,
		apiKeyService:           apiKeyService,
		usageRecordWorkerPool:   usageRecordWorkerPool,
		errorPassthroughService: errorPassthroughService,
		concurrencyHelper:       NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
		maxAccountSwitches:      maxAccountSwitches,
	}
}

// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
//
// Flow: authenticate (from middleware context) -> validate body -> acquire a
// user concurrency slot (fast path, then wait queue) -> re-check billing ->
// select an upstream account, acquire its slot, and forward; on an
// UpstreamFailoverError the loop retries with another account up to
// h.maxAccountSwitches times. Usage is recorded asynchronously on success.
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
	requestStart := time.Now()
	// Get apiKey and user from context (set by ApiKeyAuth middleware)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)
	// Read request body
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	setOpsRequestContext(c, "", false, body)
	// Validate that the request body is well-formed JSON.
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}
	// Use gjson read-only field extraction for validation, avoiding a full Unmarshal.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()
	streamResult := gjson.GetBytes(body, "stream")
	if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
		return
	}
	reqStream := streamResult.Bool()
	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
	setOpsRequestContext(c, reqModel, reqStream, body)
	// Validate up front that function_call_output items have linkable context,
	// to avoid an upstream 400. We require previous_response_id, or input items
	// that are tool_call/function_call entries carrying a call_id, or
	// item_reference entries whose id matches a call_id.
	// This path must walk the input array to correlate call_ids, so the full
	// Unmarshal is kept here.
	if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
		var reqBody map[string]any
		if err := json.Unmarshal(body, &reqBody); err == nil {
			c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
			if service.HasFunctionCallOutput(reqBody) {
				previousResponseID, _ := reqBody["previous_response_id"].(string)
				if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
					if service.HasFunctionCallOutputMissingCallID(reqBody) {
						reqLog.Warn("openai.request_validation_failed",
							zap.String("reason", "function_call_output_missing_call_id"),
						)
						h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
						return
					}
					callIDs := service.FunctionCallOutputCallIDs(reqBody)
					if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
						reqLog.Warn("openai.request_validation_failed",
							zap.String("reason", "function_call_output_missing_item_reference"),
						)
						h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
						return
					}
				}
			}
		}
	}
	// Track if we've started streaming (for error handling)
	streamStarted := false
	// Bind the error-passthrough service so the service layer can reuse the
	// rules in non-failover error scenarios as well.
	if h.errorPassthroughService != nil {
		service.BindErrorPassthroughService(c, h.errorPassthroughService)
	}
	// Get subscription info (may be nil)
	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
	routingStart := time.Now()
	// 0. First try to grab a user slot directly (fast path).
	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", streamStarted)
		return
	}
	waitCounted := false
	if !userAcquired {
		// Only enter the wait queue when the fast-path acquire misses, which
		// reduces Redis writes for the common case.
		maxWait := service.CalculateMaxWait(subject.Concurrency)
		canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
		if waitErr != nil {
			reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
			// Existing degradation semantics: if the wait counter misbehaves,
			// still let the slot-acquisition flow below proceed.
		} else if !canWait {
			reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
			h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
			return
		}
		if waitErr == nil && canWait {
			waitCounted = true
		}
		// Safety net: make sure the wait count is decremented on every exit
		// path; the happy path clears waitCounted right after acquiring.
		defer func() {
			if waitCounted {
				h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
			}
		}()
		userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
		if err != nil {
			reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
			h.handleConcurrencyError(c, err, "user", streamStarted)
			return
		}
	}
	// User slot acquired: leave the wait-queue count.
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
		waitCounted = false
	}
	// Ensure the slot is also released when the request is canceled, so a
	// passively dropped long connection cannot leak the slot.
	userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}
	// 2. Re-check billing eligibility after wait
	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}
	// Generate session hash (header first; fallback to prompt_cache_key)
	sessionHash := h.gatewayService.GenerateSessionHash(c, body)
	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	failedAccountIDs := make(map[int64]struct{})
	var lastFailoverErr *service.UpstreamFailoverError
	for {
		// Select account supporting the requested model
		reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
		selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
		if err != nil {
			reqLog.Warn("openai.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			if len(failedAccountIDs) == 0 {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
				return
			}
			// All candidates were exhausted by failover; surface the last
			// upstream error (or a generic 502 if none was captured).
			if lastFailoverErr != nil {
				h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
			} else {
				h.handleFailoverExhaustedSimple(c, 502, streamStarted)
			}
			return
		}
		account := selection.Account
		reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
		setOpsSelectedAccount(c, account.ID, account.Platform)
		// 3. Acquire account concurrency slot
		accountReleaseFunc := selection.ReleaseFunc
		if !selection.Acquired {
			if selection.WaitPlan == nil {
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
				return
			}
			// Quick one-shot try for the account slot first; on a hit we skip
			// the wait-count write entirely.
			fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
				c.Request.Context(),
				account.ID,
				selection.WaitPlan.MaxConcurrency,
			)
			if err != nil {
				reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
				h.handleConcurrencyError(c, err, "account", streamStarted)
				return
			}
			if fastAcquired {
				accountReleaseFunc = fastReleaseFunc
				if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
					reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
				}
			} else {
				accountWaitCounted := false
				canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
				if err != nil {
					// Same degradation semantics as the user wait counter:
					// log and continue without counting.
					reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
				} else if !canWait {
					reqLog.Info("openai.account_wait_queue_full",
						zap.Int64("account_id", account.ID),
						zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
					)
					h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
					return
				}
				if err == nil && canWait {
					accountWaitCounted = true
				}
				// releaseWait is idempotent: it decrements at most once.
				releaseWait := func() {
					if accountWaitCounted {
						h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
						accountWaitCounted = false
					}
				}
				accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
					c,
					account.ID,
					selection.WaitPlan.MaxConcurrency,
					selection.WaitPlan.Timeout,
					reqStream,
					&streamStarted,
				)
				if err != nil {
					reqLog.Warn("openai.account_slot_acquire_failed",
						zap.Int64("account_id", account.ID), zap.Error(err))
					releaseWait()
					h.handleConcurrencyError(c, err, "account", streamStarted)
					return
				}
				// Slot acquired: no longer waiting in queue.
				releaseWait()
				if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
					reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
				}
			}
		}
		// The account slot / wait count must be reclaimed safely on timeout or
		// client disconnect as well.
		accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
		// Forward request
		service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
		forwardStart := time.Now()
		result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
		forwardDurationMs := time.Since(forwardStart).Milliseconds()
		// Release the account slot as soon as the upstream call finishes,
		// regardless of outcome.
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}
		upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
		// Response latency = total forward time minus the upstream portion,
		// when the upstream latency was recorded and is consistent.
		responseLatencyMs := forwardDurationMs
		if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
			responseLatencyMs = forwardDurationMs - upstreamLatencyMs
		}
		service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
		if err == nil && result != nil && result.FirstTokenMs != nil {
			service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Exclude this account and retry with another, until the
				// switch budget is exhausted.
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, streamStarted)
					return
				}
				switchCount++
				reqLog.Warn("openai.upstream_failover_switching",
					zap.Int64("account_id", account.ID),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				)
				continue
			}
			// Non-failover error: write a fallback error response if nothing
			// has been written yet.
			wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
			reqLog.Error("openai.forward_failed",
				zap.Int64("account_id", account.ID),
				zap.Bool("fallback_error_response_written", wroteFallback),
				zap.Error(err),
			)
			return
		}
		// Capture request info now (for async recording) so the goroutine
		// below never touches gin.Context.
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)
		// Usage records go through a bounded worker pool to avoid spawning
		// unbounded goroutines on the request hot path.
		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:        result,
				APIKey:        apiKey,
				User:          apiKey.User,
				Account:       account,
				Subscription:  subscription,
				UserAgent:     userAgent,
				IPAddress:     clientIP,
				APIKeyService: h.apiKeyService,
			}); err != nil {
				logger.L().With(
					zap.String("component", "handler.openai_gateway.responses"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("openai.record_usage_failed", zap.Error(err))
			}
		})
		reqLog.Debug("openai.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}

// getContextInt64 reads an int64-compatible value stored under key in the gin
// context, converting int/int32/float64 as needed. Returns (0, false) when the
// key is absent or the stored value has an unsupported type.
func getContextInt64(c *gin.Context, key string) (int64, bool) {
	if c == nil || key == "" {
		return 0, false
	}
	v, ok := c.Get(key)
	if !ok {
		return 0, false
	}
	switch t := v.(type) {
	case int64:
		return t, true
	case int:
		return int64(t), true
	case int32:
		return int64(t), true
	case float64:
		return int64(t), true
	default:
		return 0, false
	}
}

// submitUsageRecordTask submits a usage-recording task to the worker pool.
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
	if task == nil {
		return
	}
	if h.usageRecordWorkerPool != nil {
		h.usageRecordWorkerPool.Submit(task)
		return
	}
	// Fallback path: when no worker pool is injected, run synchronously with a
	// timeout instead of regressing to unbounded goroutines.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	task(ctx)
}

// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
	h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
		fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}

// handleFailoverExhausted responds after all failover attempts are used up,
// first consulting error-passthrough rules for the upstream status/body and
// otherwise falling back to the default upstream-error mapping.
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
	statusCode := failoverErr.StatusCode
	responseBody := failoverErr.ResponseBody
	// Check passthrough rules first.
	if h.errorPassthroughService != nil && len(responseBody) > 0 {
		if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
			// Determine the response status code.
			respCode := statusCode
			if !rule.PassthroughCode && rule.ResponseCode != nil {
				respCode = *rule.ResponseCode
			}
			// Determine the response message.
			msg := service.ExtractUpstreamErrorMessage(responseBody)
			if !rule.PassthroughBody && rule.CustomMessage != nil {
				msg = *rule.CustomMessage
			}
			if rule.SkipMonitoring {
				c.Set(service.OpsSkipPassthroughKey, true)
			}
			h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
			return
		}
	}
	// Fall back to the default error mapping.
	status, errType, errMsg := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

// handleFailoverExhaustedSimple is a simplified variant for cases where no
// upstream response body is available.
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
	status, errType, errMsg := h.mapUpstreamError(statusCode)
	h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

// mapUpstreamError maps an upstream HTTP status code to the (status, error
// type, message) triple returned to the client. Upstream auth problems are
// reported as 502 so clients do not mistake them for their own credentials.
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
	switch statusCode {
	case 401:
		return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
	case 403:
		return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
	case 429:
		return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
	case 529:
		return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
	case 500, 502, 503, 504:
		return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
	default:
		return http.StatusBadGateway, "upstream_error", "Upstream request failed"
	}
}

// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
	if streamStarted {
		// Stream already started, send error as SSE event then close
		flusher, ok := c.Writer.(http.Flusher)
		if ok {
			// Send error event in OpenAI SSE format with proper JSON marshaling
			errorData := map[string]any{
				"error": map[string]string{
					"type":    errType,
					"message": message,
				},
			}
			jsonBytes, err := json.Marshal(errorData)
			if err != nil {
				_ = c.Error(err)
				return
			}
			errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
			if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
				_ = c.Error(err)
			}
			flusher.Flush()
		}
		return
	}
	// Normal case: return JSON response with proper status code
	h.errorResponse(c, status, errType, message)
}

// ensureForwardErrorResponse writes a unified fallback error response when
// Forward returned an error but nothing has been written to the client yet.
// It reports whether a fallback response was written.
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
	if c == nil || c.Writer == nil || c.Writer.Written() {
		return false
	}
	h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
	return true
}

// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
	c.JSON(status, gin.H{
		"error": gin.H{
			"type":    errType,
			"message": message,
		},
	})
}