package handler import ( "context" "errors" "net/http" "strings" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "go.uber.org/zap" ) // Images handles OpenAI Images API requests. // POST /v1/images/generations // POST /v1/images/edits func (h *OpenAIGatewayHandler) Images(c *gin.Context) { streamStarted := false defer h.recoverResponsesPanic(c, &streamStarted) requestStart := time.Now() apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return } subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } reqLog := requestLogger( c, "handler.openai_gateway.images", zap.Int64("user_id", subject.UserID), zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) if !h.ensureResponsesDependencies(c, reqLog) { return } body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) return } h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") return } if len(body) == 0 { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return } if isMultipartImagesContentType(c.GetHeader("Content-Type")) { setOpsRequestContext(c, "", false, nil) } else { setOpsRequestContext(c, "", false, body) } parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) return } reqLog = reqLog.With( zap.String("model", parsed.Model), zap.Bool("stream", parsed.Stream), zap.Bool("multipart", parsed.Multipart), zap.String("capability", string(parsed.RequiredCapability)), ) if parsed.Multipart { setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) } else { setOpsRequestContext(c, parsed.Model, parsed.Stream, body) } setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false))) channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model) if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) } subscription, _ := middleware2.GetSubscriptionFromContext(c) service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) routingStart := time.Now() userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog) if !acquired { return } if userReleaseFunc != nil { defer userReleaseFunc() } if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return } sessionHash := "" if parsed.Multipart { sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed()) } else { sessionHash = h.gatewayService.GenerateSessionHash(c, body) } maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) sameAccountRetryCount := make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError for { reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages( c.Request.Context(), apiKey.GroupID, sessionHash, parsed.Model, failedAccountIDs, parsed.RequiredCapability, ) if err != nil { reqLog.Warn("openai.images.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) return } if lastFailoverErr != nil { h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) } else { h.handleFailoverExhaustedSimple(c, 502, streamStarted) } return } if selection == nil || selection.Account == nil { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) return } reqLog.Debug("openai.images.account_schedule_decision", zap.String("layer", scheduleDecision.Layer), zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), zap.Int("candidate_count", scheduleDecision.CandidateCount), zap.Int("top_k", scheduleDecision.TopK), zap.Int64("latency_ms", scheduleDecision.LatencyMs), zap.Float64("load_skew", scheduleDecision.LoadSkew), ) account := selection.Account sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) setOpsSelectedAccount(c, account.ID, account.Platform) accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog) if !acquired { return } service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { responseLatencyMs = forwardDurationMs - upstreamLatencyMs } service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) if err == nil && result != nil && result.FirstTokenMs != nil { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) if failoverErr.RetryableOnSameAccount { retryLimit := account.GetPoolModeRetryCount() if sameAccountRetryCount[account.ID] < retryLimit { sameAccountRetryCount[account.ID]++ reqLog.Warn("openai.images.pool_mode_same_account_retry", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("retry_limit", retryLimit), zap.Int("retry_count", sameAccountRetryCount[account.ID]), ) select { case <-c.Request.Context().Done(): return case <-time.After(sameAccountRetryDelay): } continue } } h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { h.handleFailoverExhausted(c, failoverErr, streamStarted) return } switchCount++ reqLog.Warn("openai.images.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), zap.Int("switch_count", switchCount), zap.Int("max_switches", maxAccountSwitches), ) continue } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) fields := []zap.Field{ zap.Int64("account_id", account.ID), zap.Bool("fallback_error_response_written", wroteFallback), zap.Error(err), } if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { reqLog.Warn("openai.images.forward_failed", fields...) return } reqLog.Error("openai.images.forward_failed", fields...) return } if result != nil { if account.Type == service.AccountTypeOAuth { h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) } else { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) } userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) if parsed.Multipart { requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) } h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, InboundEndpoint: GetInboundEndpoint(c), UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.images"), zap.Int64("user_id", subject.UserID), zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), zap.String("model", parsed.Model), zap.Int64("account_id", account.ID), ).Error("openai.images.record_usage_failed", zap.Error(err)) } }) reqLog.Debug("openai.images.request_completed", zap.Int64("account_id", account.ID), zap.Int("switch_count", switchCount), ) return } } func isMultipartImagesContentType(contentType string) bool { return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data") }