diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go index a897bc40..db29618a 100644 --- a/backend/internal/handler/endpoint.go +++ b/backend/internal/handler/endpoint.go @@ -15,10 +15,12 @@ import ( // ────────────────────────────────────────────────────────── const ( - EndpointMessages = "/v1/messages" - EndpointChatCompletions = "/v1/chat/completions" - EndpointResponses = "/v1/responses" - EndpointGeminiModels = "/v1beta/models" + EndpointMessages = "/v1/messages" + EndpointChatCompletions = "/v1/chat/completions" + EndpointResponses = "/v1/responses" + EndpointImagesGenerations = "/v1/images/generations" + EndpointImagesEdits = "/v1/images/edits" + EndpointGeminiModels = "/v1beta/models" ) // gin.Context keys used by the middleware and helpers below. @@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string { return EndpointChatCompletions case strings.Contains(path, EndpointMessages): return EndpointMessages + case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"): + return EndpointImagesGenerations + case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"): + return EndpointImagesEdits case strings.Contains(path, EndpointResponses): return EndpointResponses case strings.Contains(path, EndpointGeminiModels): @@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string { switch platform { case service.PlatformOpenAI: + if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits { + return inbound + } // OpenAI forwards everything to the Responses API. // Preserve subresource suffix (e.g. /v1/responses/compact). if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" { diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go index 1519bc9e..369c5fa7 100644 --- a/backend/internal/handler/endpoint_test.go +++ b/backend/internal/handler/endpoint_test.go @@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) { {"/v1/messages", EndpointMessages}, {"/v1/chat/completions", EndpointChatCompletions}, {"/v1/responses", EndpointResponses}, + {"/v1/images/generations", EndpointImagesGenerations}, + {"/v1/images/edits", EndpointImagesEdits}, {"/v1beta/models", EndpointGeminiModels}, // Prefixed paths (antigravity, openai). {"/antigravity/v1/messages", EndpointMessages}, {"/openai/v1/responses", EndpointResponses}, {"/openai/v1/responses/compact", EndpointResponses}, + {"/openai/v1/images/generations", EndpointImagesGenerations}, + {"/openai/v1/images/edits", EndpointImagesEdits}, {"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels}, // Gin route patterns with wildcards. @@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) { {"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"}, {"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses}, {"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses}, + {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations}, + {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits}, // Antigravity — uses inbound to pick Claude vs Gemini upstream. {"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages}, diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go new file mode 100644 index 00000000..8dbf8935 --- /dev/null +++ b/backend/internal/handler/openai_images.go @@ -0,0 +1,300 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Images handles OpenAI Images API requests. +// POST /v1/images/generations +// POST /v1/images/edits +func (h *OpenAIGatewayHandler) Images(c *gin.Context) { + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.images", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if isMultipartImagesContentType(c.GetHeader("Content-Type")) { + setOpsRequestContext(c, "", false, nil) + } else { + setOpsRequestContext(c, "", false, body) + } + + parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return + } + + reqLog = reqLog.With( + zap.String("model", parsed.Model), + zap.Bool("stream", parsed.Stream), + zap.Bool("multipart", parsed.Multipart), + zap.String("capability", string(parsed.RequiredCapability)), + ) + + if parsed.Multipart { + setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) + } else { + setOpsRequestContext(c, parsed.Model, parsed.Stream, body) + } + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false))) + + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model) + + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := "" + if parsed.Multipart { + sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed()) + } else { + sessionHash = h.gatewayService.GenerateSessionHash(c, body) + } + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages( + c.Request.Context(), + apiKey.GroupID, + sessionHash, + parsed.Model, + failedAccountIDs, + parsed.RequiredCapability, + ) + if err != nil { + reqLog.Warn("openai.images.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) + return + } + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } + return + } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) + return + } + + reqLog.Debug("openai.images.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) + + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.images.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.images.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.images.forward_failed", fields...) + return + } + reqLog.Error("openai.images.forward_failed", fields...) + return + } + + if result != nil { + if account.Type == service.AccountTypeOAuth { + h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + if parsed.Multipart { + requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) + } + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.images"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", parsed.Model), + zap.Int64("account_id", account.ID), + ).Error("openai.images.record_usage_failed", zap.Error(err)) + } + }) + + reqLog.Debug("openai.images.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} + +func isMultipartImagesContentType(contentType string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data") +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 90e90dd0..93554912 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string { return service.PlatformAntigravity case strings.HasPrefix(p, "/v1beta/"): return service.PlatformGemini - case strings.Contains(p, "/responses"): + case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"): return service.PlatformOpenAI default: return "" diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 980f058d..f023e32b 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -20,6 +20,7 @@ var DefaultModels = []Model{ {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, + {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"}, } // DefaultModelIDs returns the default model ID list diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 7021ab2e..398c0351 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -96,7 +96,8 @@ func isAPIRoutePath(c *gin.Context) bool { return strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/antigravity/") || - strings.HasPrefix(path, "/responses") + strings.HasPrefix(path, "/responses") || + strings.HasPrefix(path, "/images") } // enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index cbf98293..5982e1cc 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -88,6 +88,30 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + gateway.POST("/images/generations", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) + gateway.POST("/images/edits", func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -124,6 +148,30 @@ func RegisterGatewayRoutes( } h.Gateway.ChatCompletions(c) }) + r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) + r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) != service.PlatformOpenAI { + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Images API is not supported for this platform", + }, + }) + return + } + h.OpenAIGateway.Images(c) + }) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go index 4d65a626..87a77cbc 100644 --- a/backend/internal/server/routes/gateway_test.go +++ b/backend/internal/server/routes/gateway_test.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -24,6 +25,11 @@ func newGatewayRoutesTestRouter() *gin.Engine { OpenAIGateway: &handler.OpenAIGatewayHandler{}, }, servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) { + groupID := int64(1) + c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{ + GroupID: &groupID, + Group: &service.Group{Platform: service.PlatformOpenAI}, + }) c.Next() }), nil, @@ -48,3 +54,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path) } } + +func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) { + router := newGatewayRoutesTestRouter() + + for _, path := range []string{ + "/v1/images/generations", + "/v1/images/edits", + "/images/generations", + "/images/edits", + } { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index af686ae7..801eac1b 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -911,6 +911,34 @@ func (a *Account) GetChatGPTAccountID() string { return a.GetCredential("chatgpt_account_id") } +func (a *Account) GetOpenAIDeviceID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return strings.TrimSpace(a.GetExtraString("openai_device_id")) +} + +func (a *Account) GetOpenAISessionID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return strings.TrimSpace(a.GetExtraString("openai_session_id")) +} + +func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool { + if !a.IsOpenAI() { + return false + } + switch capability { + case OpenAIImagesCapabilityBasic: + return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey + case OpenAIImagesCapabilityNative: + return a.Type == AccountTypeAPIKey + default: + return true + } +} + func (a *Account) GetChatGPTUserID() string { if !a.IsOpenAIOAuth() { return "" diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go index b7ca4cb7..58089776 100644 --- a/backend/internal/service/model_pricing_resolver.go +++ b/backend/internal/service/model_pricing_resolver.go @@ -61,6 +61,25 @@ type PricingInput struct { // 1. 获取基础定价(LiteLLM → Fallback) // 2. 如果指定了 GroupID,查找渠道定价并覆盖 func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing { + var chPricing *ChannelModelPricing + if input.GroupID != nil && r.channelService != nil { + chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model) + if chPricing != nil { + mode := chPricing.BillingMode + if mode == "" { + mode = BillingModeToken + } + if mode == BillingModePerRequest || mode == BillingModeImage { + resolved := &ResolvedPricing{ + Mode: mode, + Source: PricingSourceChannel, + } + r.applyRequestTierOverrides(chPricing, resolved) + return resolved + } + } + } + // 1. 获取基础定价 basePricing, source := r.resolveBasePricing(input.Model) @@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) } // 2. 如果有 GroupID,尝试渠道覆盖 - if input.GroupID != nil { + if chPricing != nil { + resolved.Source = PricingSourceChannel + r.applyTokenOverrides(chPricing, resolved) + } else if input.GroupID != nil { r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved) } diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 5fda3abd..f3533ec4 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -38,13 +38,14 @@ var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSch var openAIAdvancedSchedulerSettingSF singleflight.Group type OpenAIAccountScheduleRequest struct { - GroupID *int64 - SessionHash string - StickyAccountID int64 - PreviousResponseID string - RequestedModel string - RequiredTransport OpenAIUpstreamTransport - ExcludedIDs map[int64]struct{} + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + RequiredImageCapability OpenAIImagesCapability + ExcludedIDs map[int64]struct{} } type OpenAIAccountScheduleDecision struct { @@ -340,7 +341,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + if !s.isAccountRequestCompatible(account, req) { return nil, nil } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -616,7 +617,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue } - if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + if !s.isAccountRequestCompatible(account, req) { continue } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -722,11 +723,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) @@ -749,7 +750,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { continue } return &AccountSelectionResult{ @@ -776,6 +777,16 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) } +func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool { + if account == nil { + return false + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return false + } + return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) +} + func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { if s == nil || s.stats == nil { return @@ -894,14 +905,59 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requestedModel string, excludedIDs map[int64]struct{}, requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "") +} + +func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredCapability OpenAIImagesCapability, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) +} + +func (s *OpenAIGatewayService) selectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, + requiredImageCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { decision.Layer = openAIAccountScheduleLayerLoadBalance if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { - selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) - return selection, decision, err + effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) + for { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs) + if err != nil { + return nil, decision, err + } + if selection == nil || selection.Account == nil { + return selection, decision, nil + } + if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) { + return selection, decision, nil + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if effectiveExcludedIDs == nil { + effectiveExcludedIDs = make(map[int64]struct{}) + } + if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists { + return nil, decision, ErrNoAvailableAccounts + } + effectiveExcludedIDs[selection.Account.ID] = struct{}{} + } } effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs) @@ -937,13 +993,14 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( } return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ - GroupID: groupID, - SessionHash: sessionHash, - StickyAccountID: stickyAccountID, - PreviousResponseID: previousResponseID, - RequestedModel: requestedModel, - RequiredTransport: requiredTransport, - ExcludedIDs: excludedIDs, + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + RequiredImageCapability: requiredImageCapability, + ExcludedIDs: excludedIDs, }) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 6fa8a5bd..95e1bffa 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1070,3 +1070,31 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, subRepo.incrementCalls) } + +func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_only_usage", + Model: "gpt-image-2", + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1007}, + User: &User{ID: 2007}, + Account: &Account{ID: 3007}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.ImageSize) + require.Equal(t, "1K", *usageRepo.lastLog.ImageSize) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 064191bd..a4a7ff1b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -233,6 +233,8 @@ type OpenAIForwardResult struct { ResponseHeaders http.Header Duration time.Duration FirstTokenMs *int + ImageCount int + ImageSize string } type OpenAIWSRetryMetricsSnapshot struct { @@ -3889,6 +3891,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) + usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int()) } func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { @@ -3900,11 +3903,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { "usage.input_tokens", "usage.output_tokens", "usage.input_tokens_details.cached_tokens", + "usage.output_tokens_details.image_tokens", ) return OpenAIUsage{ InputTokens: int(values[0].Int()), OutputTokens: int(values[1].Int()), CacheReadInputTokens: int(values[2].Int()), + ImageOutputTokens: int(values[3].Int()), }, true } @@ -4397,7 +4402,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && - result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 { + result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 && + result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 { return nil } @@ -4451,21 +4457,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - ServiceTier: serviceTier, - Resolver: s.resolver, - }) - } else { - cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) - } + cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -4505,6 +4497,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, ImageOutputTokens: result.Usage.ImageOutputTokens, + ImageCount: result.ImageCount, + ImageSize: optionalTrimmedStringPtr(result.ImageSize), } if cost != nil { usageLog.InputCost = cost.InputCost @@ -4530,6 +4524,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if cost != nil && cost.BillingMode != "" { billingMode := cost.BillingMode usageLog.BillingMode = &billingMode + } else if result.ImageCount > 0 { + billingMode := string(BillingModeImage) + usageLog.BillingMode = &billingMode } else { billingMode := string(BillingModeToken) usageLog.BillingMode = &billingMode @@ -4589,6 +4586,125 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec return nil } +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( + ctx context.Context, + result *OpenAIForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + if result != nil && result.ImageCount > 0 { + if hasOpenAIImageUsageTokens(result) { + cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize) + if err == nil { + return cost, nil + } + } + return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil + } + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + return s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } + return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) +} + +func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost( + ctx context.Context, + apiKey *APIKey, + billingModel string, + multiplier float64, + tokens UsageTokens, + serviceTier string, + sizeTier string, +) (*CostBreakdown, error) { + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + return s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + SizeTier: sizeTier, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } + return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) +} + +func (s *OpenAIGatewayService) calculateOpenAIImageCost( + ctx context.Context, + billingModel string, + apiKey *APIKey, + result *OpenAIForwardResult, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + RequestCount: 1, + SizeTier: result.ImageSize, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err == nil { + return cost + } + logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err) + } + + var groupConfig *ImagePriceConfig + if apiKey != nil && apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey == nil || apiKey.Group == nil { + return nil + } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + +func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool { + if result == nil { + return false + } + return result.Usage.InputTokens > 0 || + result.Usage.OutputTokens > 0 || + result.Usage.CacheCreationInputTokens > 0 || + result.Usage.CacheReadInputTokens > 0 || + result.Usage.ImageOutputTokens > 0 +} + // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. // Exported for use in ratelimit_service when handling OpenAI 429 responses. func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go new file mode 100644 index 00000000..396c0381 --- /dev/null +++ b/backend/internal/service/openai_images.go @@ -0,0 +1,2013 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "crypto/sha3" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "sort" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + openAIImagesGenerationsEndpoint = "/v1/images/generations" + openAIImagesEditsEndpoint = "/v1/images/edits" + + openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations" + openAIImagesEditsURL = "https://api.openai.com/v1/images/edits" + + openAIChatGPTStartURL = "https://chatgpt.com/" + openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files" + openAIChatGPTConversationInitURL = "https://chatgpt.com/backend-api/conversation/init" + openAIChatGPTConversationURL = "https://chatgpt.com/backend-api/f/conversation" + openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare" + openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements" + + openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + openAIImageRequirementsDiff = "0fffff" +) + +type OpenAIImagesCapability string + +const ( + OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic" + OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native" +) + +type OpenAIImagesUpload struct { + FieldName string + FileName string + ContentType string + Data []byte + Width int + Height int +} + +type OpenAIImagesRequest struct { + Endpoint string + ContentType string + Multipart bool + Model string + ExplicitModel bool + Prompt string + Stream bool + N int + Size string + ExplicitSize bool + SizeTier string + ResponseFormat string + HasMask bool + HasNativeOptions bool + RequiredCapability OpenAIImagesCapability + Uploads []OpenAIImagesUpload + Body []byte + bodyHash string +} + +func (r *OpenAIImagesRequest) IsEdits() bool { + return r != nil && r.Endpoint == openAIImagesEditsEndpoint +} + +func (r *OpenAIImagesRequest) StickySessionSeed() string { + if r == nil { + return "" + } + parts := []string{ + "openai-images", + strings.TrimSpace(r.Endpoint), + strings.TrimSpace(r.Model), + strings.TrimSpace(r.Size), + strings.TrimSpace(r.Prompt), + } + seed := strings.Join(parts, "|") + if strings.TrimSpace(r.Prompt) == "" && r.bodyHash != "" { + seed += "|body=" + r.bodyHash + } + return seed +} + +func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) { + if c == nil || c.Request == nil { + return nil, fmt.Errorf("missing request context") + } + endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path) + if endpoint == "" { + return nil, fmt.Errorf("unsupported images endpoint") + } + + contentType := strings.TrimSpace(c.GetHeader("Content-Type")) + req := &OpenAIImagesRequest{ + Endpoint: endpoint, + ContentType: contentType, + N: 1, + Body: body, + } + if len(body) > 0 { + sum := sha256.Sum256(body) + req.bodyHash = hex.EncodeToString(sum[:8]) + } + + mediaType, _, err := mime.ParseMediaType(contentType) + if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { + req.Multipart = true + if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil { + return nil, parseErr + } + } else { + if len(body) == 0 { + return nil, fmt.Errorf("request body is empty") + } + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("failed to parse request body") + } + if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil { + return nil, parseErr + } + } + + applyOpenAIImagesDefaults(req) + req.SizeTier = normalizeOpenAIImageSizeTier(req.Size) + req.RequiredCapability = classifyOpenAIImagesCapability(req) + return req, nil +} + +func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { + if modelResult := gjson.GetBytes(body, "model"); modelResult.Exists() { + req.Model = strings.TrimSpace(modelResult.String()) + req.ExplicitModel = req.Model != "" + } + req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String()) + + if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() { + if streamResult.Type != gjson.True && streamResult.Type != gjson.False { + return fmt.Errorf("invalid stream field type") + } + req.Stream = streamResult.Bool() + } + + if nResult := gjson.GetBytes(body, "n"); nResult.Exists() { + if nResult.Type != gjson.Number { + return fmt.Errorf("invalid n field type") + } + req.N = int(nResult.Int()) + if req.N <= 0 { + return fmt.Errorf("n must be greater than 0") + } + } + + if sizeResult := gjson.GetBytes(body, "size"); sizeResult.Exists() { + req.Size = strings.TrimSpace(sizeResult.String()) + req.ExplicitSize = req.Size != "" + } + req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) + req.HasMask = gjson.GetBytes(body, "mask").Exists() + req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { + return gjson.GetBytes(body, path).Exists() + }) + return nil +} + +func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error { + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid multipart content-type: %w", err) + } + boundary := strings.TrimSpace(params["boundary"]) + if boundary == "" { + return fmt.Errorf("multipart boundary is required") + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("read multipart body: %w", err) + } + name := strings.TrimSpace(part.FormName()) + if name == "" { + _ = part.Close() + continue + } + + data, err := io.ReadAll(part) + _ = part.Close() + if err != nil { + return fmt.Errorf("read multipart field %s: %w", name, err) + } + + fileName := strings.TrimSpace(part.FileName()) + if fileName != "" { + partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) + if name == "mask" && len(data) > 0 { + req.HasMask = true + } + if name == "image" || strings.HasPrefix(name, "image[") { + width, height := parseOpenAIImageDimensions(part.Header) + req.Uploads = append(req.Uploads, OpenAIImagesUpload{ + FieldName: name, + FileName: fileName, + ContentType: partContentType, + Data: data, + Width: width, + Height: height, + }) + } + continue + } + + value := strings.TrimSpace(string(data)) + switch name { + case "model": + req.Model = value + req.ExplicitModel = value != "" + case "prompt": + req.Prompt = value + case "size": + req.Size = value + req.ExplicitSize = value != "" + case "response_format": + req.ResponseFormat = strings.ToLower(value) + case "stream": + parsed, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid stream field value") + } + req.Stream = parsed + case "n": + n, err := strconv.Atoi(value) + if err != nil || n <= 0 { + return fmt.Errorf("n must be a positive integer") + } + req.N = n + default: + if isOpenAINativeImageOption(name) && value != "" { + req.HasNativeOptions = true + } + } + } + + if len(req.Uploads) == 0 && req.IsEdits() { + return fmt.Errorf("image file is required") + } + return nil +} + +func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) { + return 0, 0 +} + +func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) { + if req == nil { + return + } + if req.N <= 0 { + req.N = 1 + } + if strings.TrimSpace(req.Model) != "" { + req.Model = strings.TrimSpace(req.Model) + return + } + req.Model = "gpt-image-2" +} + +func normalizeOpenAIImagesEndpointPath(path string) string { + trimmed := strings.TrimSpace(path) + switch { + case strings.Contains(trimmed, "/images/generations"): + return openAIImagesGenerationsEndpoint + case strings.Contains(trimmed, "/images/edits"): + return openAIImagesEditsEndpoint + default: + return "" + } +} + +func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability { + if req == nil { + return OpenAIImagesCapabilityNative + } + if req.ExplicitModel || req.ExplicitSize { + return OpenAIImagesCapabilityNative + } + model := strings.ToLower(strings.TrimSpace(req.Model)) + if !strings.HasPrefix(model, "gpt-image-") { + return OpenAIImagesCapabilityNative + } + if req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions { + return OpenAIImagesCapabilityNative + } + if req.IsEdits() && !req.Multipart { + return OpenAIImagesCapabilityNative + } + if req.ResponseFormat != "" && req.ResponseFormat != "b64_json" { + return OpenAIImagesCapabilityNative + } + return OpenAIImagesCapabilityBasic +} + +func hasOpenAINativeImageOptions(exists func(path string) bool) bool { + for _, path := range []string{ + "background", + "quality", + "style", + "output_format", + "output_compression", + "moderation", + } { + if exists(path) { + return true + } + } + return false +} + +func isOpenAINativeImageOption(name string) bool { + switch strings.TrimSpace(strings.ToLower(name)) { + case "background", "quality", "style", "output_format", "output_compression", "moderation": + return true + default: + return false + } +} + +func normalizeOpenAIImageSizeTier(size string) string { + switch strings.ToLower(strings.TrimSpace(size)) { + case "1024x1024": + return "1K" + case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": + return "2K" + default: + return "2K" + } +} + +func (s *OpenAIGatewayService) ForwardImages( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + if parsed == nil { + return nil, fmt.Errorf("parsed images request is required") + } + switch account.Type { + case AccountTypeAPIKey: + return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel) + case AccountTypeOAuth: + return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel) + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + requestModel := strings.TrimSpace(parsed.Model) + if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { + requestModel = mapped + } + upstreamModel := account.GetMappedModel(requestModel) + forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel) + if err != nil { + return nil, err + } + if !parsed.Multipart { + setOpsUpstreamRequestBody(c, forwardBody) + } + + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "request_error", + Message: safeErr, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleErrorResponse(ctx, resp, c, account, forwardBody) + } + defer func() { _ = resp.Body.Close() }() + + var usage OpenAIUsage + imageCount := parsed.N + var firstTokenMs *int + if parsed.Stream { + streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) + if err != nil { + return nil, err + } + usage = streamUsage + imageCount = streamCount + firstTokenMs = ttft + } else { + nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c) + if err != nil { + return nil, err + } + usage = nonStreamUsage + if nonStreamCount > 0 { + imageCount = nonStreamCount + } + } + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, nil +} + +func (s *OpenAIGatewayService) buildOpenAIImagesRequest( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + contentType string, + token string, + endpoint string, +) (*http.Request, error) { + targetURL := openAIImagesGenerationsURL + if endpoint == openAIImagesEditsEndpoint { + targetURL = openAIImagesEditsURL + } + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIImagesURL(validatedURL, endpoint) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + for key, values := range c.Request.Header { + if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] { + continue + } + for _, value := range values { + req.Header.Add(key, value) + } + } + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("User-Agent", customUA) + } + if strings.TrimSpace(contentType) != "" { + req.Header.Set("Content-Type", contentType) + } + return req, nil +} + +func buildOpenAIImagesURL(base string, endpoint string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1") + if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + relative + } + return normalized + endpoint +} + +func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) { + model = strings.TrimSpace(model) + if model == "" { + return body, contentType, nil + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { + rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model) + return rewrittenBody, rewrittenType, rewriteErr + } + rewritten, err := sjson.SetBytes(body, "model", model) + if err != nil { + return nil, "", fmt.Errorf("rewrite image request model: %w", err) + } + return rewritten, contentType, nil +} + +func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) { + _, params, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, "", fmt.Errorf("parse multipart content-type: %w", err) + } + boundary := strings.TrimSpace(params["boundary"]) + if boundary == "" { + return nil, "", fmt.Errorf("multipart boundary is required") + } + + reader := multipart.NewReader(bytes.NewReader(body), boundary) + var buffer bytes.Buffer + writer := multipart.NewWriter(&buffer) + modelWritten := false + + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + return nil, "", fmt.Errorf("read multipart body: %w", err) + } + + formName := strings.TrimSpace(part.FormName()) + partHeader := cloneMultipartHeader(part.Header) + target, err := writer.CreatePart(partHeader) + if err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("create multipart part: %w", err) + } + + if formName == "model" && part.FileName() == "" { + if _, err := target.Write([]byte(model)); err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("rewrite multipart model: %w", err) + } + modelWritten = true + _ = part.Close() + continue + } + if _, err := io.Copy(target, part); err != nil { + _ = part.Close() + return nil, "", fmt.Errorf("copy multipart part: %w", err) + } + _ = part.Close() + } + + if !modelWritten { + if err := writer.WriteField("model", model); err != nil { + return nil, "", fmt.Errorf("append multipart model field: %w", err) + } + } + if err := writer.Close(); err != nil { + return nil, "", fmt.Errorf("finalize multipart body: %w", err) + } + return buffer.Bytes(), writer.FormDataContentType(), nil +} + +func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader { + dst := make(textproto.MIMEHeader, len(src)) + for key, values := range src { + copied := make([]string, len(values)) + copy(copied, values) + dst[key] = copied + } + return dst +} + +func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) { + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return OpenAIUsage{}, 0, err + } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := "application/json" + if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { + if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { + contentType = upstreamType + } + } + c.Data(resp.StatusCode, contentType, body) + + usage, _ := extractOpenAIUsageFromJSONBytes(body) + return usage, extractOpenAIImageCountFromJSONBytes(body), nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( + resp *http.Response, + c *gin.Context, + startTime time.Time, +) (OpenAIUsage, int, *int, error) { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" + } + c.Status(resp.StatusCode) + c.Header("Content-Type", contentType) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") + } + + reader := bufio.NewReader(resp.Body) + usage := OpenAIUsage{} + imageCount := 0 + var firstTokenMs *int + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if _, writeErr := c.Writer.Write(line); writeErr != nil { + return OpenAIUsage{}, 0, firstTokenMs, writeErr + } + flusher.Flush() + + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { + dataBytes := []byte(data) + mergeOpenAIUsage(&usage, dataBytes) + if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { + imageCount = count + } + } + } + if err == io.EOF { + break + } + if err != nil { + return OpenAIUsage{}, 0, firstTokenMs, err + } + } + return usage, imageCount, firstTokenMs, nil +} + +func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { + if dst == nil { + return + } + if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok { + if parsed.InputTokens > 0 { + dst.InputTokens = parsed.InputTokens + } + if parsed.OutputTokens > 0 { + dst.OutputTokens = parsed.OutputTokens + } + if parsed.CacheReadInputTokens > 0 { + dst.CacheReadInputTokens = parsed.CacheReadInputTokens + } + if parsed.ImageOutputTokens > 0 { + dst.ImageOutputTokens = parsed.ImageOutputTokens + } + } +} + +func extractOpenAIImageCountFromJSONBytes(body []byte) int { + if len(body) == 0 || !gjson.ValidBytes(body) { + return 0 + } + data := gjson.GetBytes(body, "data") + if data.Exists() && data.IsArray() { + return len(data.Array()) + } + return 0 +} + +func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + requestModel := strings.TrimSpace(parsed.Model) + if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { + requestModel = mapped + } + + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account)) + if err != nil { + return nil, err + } + headers, err := s.buildOpenAIBackendAPIHeaders(account, token) + if err != nil { + return nil, err + } + if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { + logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr) + } + + chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) + if err != nil { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) + } + if chatReqs.Arkose.Required { + return nil, s.wrapOpenAIImageBackendError( + ctx, + c, + account, + newOpenAIImageSyntheticStatusError( + http.StatusForbidden, + "chat-requirements requires unsupported challenge (arkose)", + openAIChatGPTChatRequirementsURL, + ), + ) + } + + parentMessageID := uuid.NewString() + proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) + _ = initializeOpenAIImageConversation(ctx, client, headers) + conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken) + if err != nil { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) + } + + uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads) + if err != nil { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) + } + + convReq := buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads) + if parsedContent, err := json.Marshal(convReq); err == nil { + setOpsUpstreamRequestBody(c, parsedContent) + } + convHeaders := cloneHTTPHeader(headers) + convHeaders.Set("Accept", "text/event-stream") + convHeaders.Set("Content-Type", "application/json") + convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) + if conduitToken != "" { + convHeaders.Set("x-conduit-token", conduitToken) + } + if proofToken != "" { + convHeaders.Set("openai-sentinel-proof-token", proofToken) + } + + resp, err := client.R(). + SetContext(ctx). + DisableAutoReadResponse(). + SetHeaders(headerToMap(convHeaders)). + SetBodyJsonMarshal(convReq). + Post(openAIChatGPTConversationURL) + if err != nil { + return nil, fmt.Errorf("openai image conversation request failed: %w", err) + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode >= 400 { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp)) + } + + conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime) + if err != nil { + return nil, err + } + pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) + if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { + polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) + if pollErr != nil { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr) + } + pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) + } + pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) + if len(pointerInfos) == 0 { + return nil, fmt.Errorf("openai image conversation returned no downloadable images") + } + + responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos) + if err != nil { + return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) + } + + c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody) + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, nil +} + +func resolveOpenAIProxyURL(account *Account) string { + if account != nil && account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" +} + +func newOpenAIBackendAPIClient(proxyURL string) (*req.Client, error) { + client := req.C(). + SetTimeout(180 * time.Second). + ImpersonateChrome() + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) + } + return client, nil +} + +func (s *OpenAIGatewayService) buildOpenAIBackendAPIHeaders(account *Account, token string) (http.Header, error) { + deviceID, sessionID := s.ensureOpenAIImageSessionCredentials(context.Background(), account) + headers := make(http.Header) + headers.Set("Authorization", "Bearer "+token) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://chatgpt.com") + headers.Set("Referer", "https://chatgpt.com/") + headers.Set("Sec-Fetch-Dest", "empty") + headers.Set("Sec-Fetch-Mode", "cors") + headers.Set("Sec-Fetch-Site", "same-origin") + headers.Set("User-Agent", openAIImageBackendUserAgent) + if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { + headers.Set("User-Agent", customUA) + } + if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if deviceID != "" { + headers.Set("oai-device-id", deviceID) + headers.Set("Cookie", "oai-did="+deviceID) + } + if sessionID != "" { + headers.Set("oai-session-id", sessionID) + } + return headers, nil +} + +func (s *OpenAIGatewayService) ensureOpenAIImageSessionCredentials(ctx context.Context, account *Account) (string, string) { + if account == nil { + return "", "" + } + deviceID := account.GetOpenAIDeviceID() + sessionID := account.GetOpenAISessionID() + if deviceID != "" && sessionID != "" { + return deviceID, sessionID + } + + updates := map[string]any{} + if deviceID == "" { + deviceID = uuid.NewString() + updates["openai_device_id"] = deviceID + } + if sessionID == "" { + sessionID = uuid.NewString() + updates["openai_session_id"] = sessionID + } + if account.Extra == nil { + account.Extra = map[string]any{} + } + for key, value := range updates { + account.Extra[key] = value + } + if len(updates) == 0 || s == nil || s.accountRepo == nil { + return deviceID, sessionID + } + + updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := s.accountRepo.UpdateExtra(updateCtx, account.ID, updates); err != nil { + logger.LegacyPrintf("service.openai_gateway", "persist openai image session creds failed: account=%d err=%v", account.ID, err) + } + return deviceID, sessionID +} + +func bootstrapOpenAIBackendAPI(ctx context.Context, client *req.Client, headers http.Header) error { + resp, err := client.R(). + SetContext(ctx). + DisableAutoReadResponse(). + SetHeaders(headerToMap(headers)). + Get(openAIChatGPTStartURL) + if err != nil { + return err + } + if resp != nil && resp.Body != nil { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + } + return nil +} + +func initializeOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header) error { + payload := map[string]any{ + "gizmo_id": nil, + "requested_default_model": nil, + "conversation_id": nil, + "timezone_offset_min": openAITimezoneOffsetMinutes(), + "system_hints": []string{"picture_v2"}, + } + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetBodyJsonMarshal(payload). + Post(openAIChatGPTConversationInitURL) + if err != nil { + return err + } + if !resp.IsSuccessState() { + return newOpenAIImageStatusError(resp, "conversation init failed") + } + return nil +} + +type openAIChatRequirements struct { + Token string `json:"token"` + Turnstile struct { + Required bool `json:"required"` + } `json:"turnstile"` + Arkose struct { + Required bool `json:"required"` + } `json:"arkose"` + ProofOfWork struct { + Required bool `json:"required"` + Seed string `json:"seed"` + Difficulty string `json:"difficulty"` + } `json:"proofofwork"` +} + +func fetchOpenAIChatRequirements(ctx context.Context, client *req.Client, headers http.Header) (*openAIChatRequirements, error) { + var lastErr error + for _, payload := range []map[string]any{ + {"p": nil}, + {"p": generateOpenAIRequirementsToken(headers.Get("User-Agent"))}, + } { + var result openAIChatRequirements + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetBodyJsonMarshal(payload). + SetSuccessResult(&result). + Post(openAIChatGPTChatRequirementsURL) + if err != nil { + lastErr = err + continue + } + if resp.IsSuccessState() && strings.TrimSpace(result.Token) != "" { + return &result, nil + } + lastErr = newOpenAIImageStatusError(resp, "chat-requirements failed") + } + if lastErr == nil { + lastErr = fmt.Errorf("chat-requirements failed") + } + return nil, lastErr +} + +func prepareOpenAIImageConversation( + ctx context.Context, + client *req.Client, + headers http.Header, + prompt string, + parentMessageID string, + chatToken string, + proofToken string, +) (string, error) { + messageID := uuid.NewString() + payload := map[string]any{ + "action": "next", + "client_prepare_state": "success", + "fork_from_shared_post": false, + "parent_message_id": parentMessageID, + "model": "auto", + "timezone_offset_min": openAITimezoneOffsetMinutes(), + "timezone": openAITimezoneName(), + "conversation_mode": map[string]any{"kind": "primary_assistant"}, + "system_hints": []string{"picture_v2"}, + "supports_buffering": true, + "supported_encodings": []string{"v1"}, + "partial_query": map[string]any{ + "id": messageID, + "author": map[string]any{"role": "user"}, + "content": map[string]any{ + "content_type": "text", + "parts": []string{coalesceOpenAIFileName(prompt, "Generate an image.")}, + }, + }, + "client_contextual_info": map[string]any{ + "app_name": "chatgpt.com", + }, + } + prepareHeaders := cloneHTTPHeader(headers) + prepareHeaders.Set("Accept", "*/*") + prepareHeaders.Set("Content-Type", "application/json") + if strings.TrimSpace(chatToken) != "" { + prepareHeaders.Set("openai-sentinel-chat-requirements-token", strings.TrimSpace(chatToken)) + } + if strings.TrimSpace(proofToken) != "" { + prepareHeaders.Set("openai-sentinel-proof-token", strings.TrimSpace(proofToken)) + } + var result struct { + ConduitToken string `json:"conduit_token"` + } + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(prepareHeaders)). + SetBodyJsonMarshal(payload). + SetSuccessResult(&result). + Post(openAIChatGPTConversationPrepareURL) + if err != nil { + return "", err + } + if !resp.IsSuccessState() { + return "", newOpenAIImageStatusError(resp, "conversation prepare failed") + } + return strings.TrimSpace(result.ConduitToken), nil +} + +type openAIUploadedImage struct { + FileID string + FileName string + FileSize int + MimeType string + Width int + Height int +} + +func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers http.Header, uploads []OpenAIImagesUpload) ([]openAIUploadedImage, error) { + if len(uploads) == 0 { + return nil, nil + } + results := make([]openAIUploadedImage, 0, len(uploads)) + for i := range uploads { + item := uploads[i] + fileName := coalesceOpenAIFileName(item.FileName, "image.png") + payload := map[string]any{ + "file_name": fileName, + "file_size": len(item.Data), + "use_case": "multimodal", + } + var created struct { + FileID string `json:"file_id"` + UploadURL string `json:"upload_url"` + } + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetBodyJsonMarshal(payload). + SetSuccessResult(&created). + Post(openAIChatGPTFilesURL) + if err != nil { + return nil, err + } + if !resp.IsSuccessState() || strings.TrimSpace(created.FileID) == "" || strings.TrimSpace(created.UploadURL) == "" { + return nil, newOpenAIImageStatusError(resp, "create upload slot failed") + } + + uploadHeaders := map[string]string{ + "Content-Type": coalesceOpenAIFileName(item.ContentType, "application/octet-stream"), + "Origin": "https://chatgpt.com", + "x-ms-blob-type": "BlockBlob", + "x-ms-version": "2020-04-08", + "User-Agent": headers.Get("User-Agent"), + } + putResp, err := client.R(). + SetContext(ctx). + SetHeaders(uploadHeaders). + SetBody(item.Data). + DisableAutoReadResponse(). + Put(created.UploadURL) + if err != nil { + return nil, err + } + if putResp.Response != nil && putResp.Response.Body != nil { + _, _ = io.Copy(io.Discard, putResp.Response.Body) + _ = putResp.Response.Body.Close() + } + if putResp.StatusCode < 200 || putResp.StatusCode >= 300 { + return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed") + } + + uploadedResp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetBodyJsonMarshal(map[string]any{}). + Post(fmt.Sprintf("%s/%s/uploaded", openAIChatGPTFilesURL, created.FileID)) + if err != nil { + return nil, err + } + if !uploadedResp.IsSuccessState() { + return nil, newOpenAIImageStatusError(uploadedResp, "mark upload complete failed") + } + + results = append(results, openAIUploadedImage{ + FileID: created.FileID, + FileName: fileName, + FileSize: len(item.Data), + MimeType: coalesceOpenAIFileName(item.ContentType, "application/octet-stream"), + Width: item.Width, + Height: item.Height, + }) + } + return results, nil +} + +func coalesceOpenAIFileName(value string, fallback string) string { + value = strings.TrimSpace(value) + if value == "" { + return fallback + } + return value +} + +func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMessageID string, uploads []openAIUploadedImage) map[string]any { + parts := []any{coalesceOpenAIFileName(parsed.Prompt, "Generate an image.")} + attachments := make([]map[string]any, 0, len(uploads)) + if len(uploads) > 0 { + parts = make([]any, 0, len(uploads)+1) + for _, upload := range uploads { + parts = append(parts, map[string]any{ + "content_type": "image_asset_pointer", + "asset_pointer": "file-service://" + upload.FileID, + "size_bytes": upload.FileSize, + "width": upload.Width, + "height": upload.Height, + }) + attachment := map[string]any{ + "id": upload.FileID, + "mimeType": upload.MimeType, + "name": upload.FileName, + "size": upload.FileSize, + } + if upload.Width > 0 { + attachment["width"] = upload.Width + } + if upload.Height > 0 { + attachment["height"] = upload.Height + } + attachments = append(attachments, attachment) + } + parts = append(parts, coalesceOpenAIFileName(parsed.Prompt, "Edit this image.")) + } + + contentType := "text" + if len(uploads) > 0 { + contentType = "multimodal_text" + } + metadata := map[string]any{ + "developer_mode_connector_ids": []any{}, + "selected_github_repos": []any{}, + "selected_all_github_repos": false, + "system_hints": []string{"picture_v2"}, + "serialization_metadata": map[string]any{ + "custom_symbol_offsets": []any{}, + }, + } + message := map[string]any{ + "id": uuid.NewString(), + "author": map[string]any{"role": "user"}, + "content": map[string]any{ + "content_type": contentType, + "parts": parts, + }, + "metadata": metadata, + "create_time": float64(time.Now().UnixMilli()) / 1000, + } + if len(attachments) > 0 { + metadata["attachments"] = attachments + } + + return map[string]any{ + "action": "next", + "client_prepare_state": "sent", + "parent_message_id": parentMessageID, + "model": "auto", + "timezone_offset_min": openAITimezoneOffsetMinutes(), + "timezone": openAITimezoneName(), + "conversation_mode": map[string]any{"kind": "primary_assistant"}, + "enable_message_followups": true, + "system_hints": []string{"picture_v2"}, + "supports_buffering": true, + "supported_encodings": []string{"v1"}, + "paragen_cot_summary_display_override": "allow", + "force_parallel_switch": "auto", + "client_contextual_info": map[string]any{ + "is_dark_mode": false, + "time_since_loaded": 200, + "page_height": 900, + "page_width": 1440, + "pixel_ratio": 1, + "screen_height": 1080, + "screen_width": 1920, + "app_name": "chatgpt.com", + }, + "messages": []any{message}, + } +} + +type openAIImagePointerInfo struct { + Pointer string + Prompt string +} + +type openAIImageToolMessage struct { + MessageID string + CreateTime float64 + PointerInfos []openAIImagePointerInfo +} + +func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) { + if resp == nil || resp.Response == nil || resp.Response.Body == nil { + return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response") + } + reader := bufio.NewReader(resp.Response.Body) + var ( + conversationID string + firstTokenMs *int + usage OpenAIUsage + pointers []openAIImagePointerInfo + ) + + for { + line, err := reader.ReadString('\n') + if strings.TrimSpace(line) != "" && firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if data, ok := extractOpenAISSEDataLine(strings.TrimRight(line, "\r\n")); ok && data != "" && data != "[DONE]" { + dataBytes := []byte(data) + if conversationID == "" { + conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "v.conversation_id").String()) + if conversationID == "" { + conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "conversation_id").String()) + } + } + mergeOpenAIUsage(&usage, dataBytes) + pointers = mergeOpenAIImagePointerInfos(pointers, collectOpenAIImagePointers(dataBytes)) + } + if err == io.EOF { + break + } + if err != nil { + return "", nil, OpenAIUsage{}, firstTokenMs, err + } + } + return conversationID, pointers, usage, firstTokenMs, nil +} + +func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { + if len(body) == 0 { + return nil + } + matches := openAIImagePointerMatches(body) + if len(matches) == 0 { + return nil + } + prompt := "" + for _, path := range []string{ + "message.metadata.dalle.prompt", + "metadata.dalle.prompt", + "revised_prompt", + } { + if value := strings.TrimSpace(gjson.GetBytes(body, path).String()); value != "" { + prompt = value + break + } + } + out := make([]openAIImagePointerInfo, 0, len(matches)) + for _, pointer := range matches { + out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}) + } + return out +} + +func openAIImagePointerMatches(body []byte) []string { + raw := string(body) + matches := make([]string, 0, 4) + for _, prefix := range []string{"file-service://", "sediment://"} { + start := 0 + for { + idx := strings.Index(raw[start:], prefix) + if idx < 0 { + break + } + idx += start + end := idx + len(prefix) + for end < len(raw) { + ch := raw[end] + if ch != '-' && ch != '_' && + (ch < '0' || ch > '9') && + (ch < 'a' || ch > 'z') && + (ch < 'A' || ch > 'Z') { + break + } + end++ + } + matches = append(matches, raw[idx:end]) + start = end + } + } + return dedupeStrings(matches) +} + +func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo { + if len(next) == 0 { + return existing + } + seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next)) + out := make([]openAIImagePointerInfo, 0, len(existing)+len(next)) + for _, item := range existing { + seen[item.Pointer] = item + out = append(out, item) + } + for _, item := range next { + if existingItem, ok := seen[item.Pointer]; ok { + if existingItem.Prompt == "" && item.Prompt != "" { + for i := range out { + if out[i].Pointer == item.Pointer { + out[i].Prompt = item.Prompt + break + } + } + } + continue + } + seen[item.Pointer] = item + out = append(out, item) + } + return out +} + +func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { + for _, item := range items { + if strings.HasPrefix(item.Pointer, "file-service://") { + return true + } + } + return false +} + +func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo { + if !hasOpenAIFileServicePointerInfos(items) { + return items + } + out := make([]openAIImagePointerInfo, 0, len(items)) + for _, item := range items { + if strings.HasPrefix(item.Pointer, "file-service://") { + out = append(out, item) + } + } + return out +} + +func extractOpenAIImageToolMessages(mapping map[string]any) []openAIImageToolMessage { + if len(mapping) == 0 { + return nil + } + out := make([]openAIImageToolMessage, 0, 4) + for messageID, raw := range mapping { + node, _ := raw.(map[string]any) + if node == nil { + continue + } + message, _ := node["message"].(map[string]any) + if message == nil { + continue + } + author, _ := message["author"].(map[string]any) + metadata, _ := message["metadata"].(map[string]any) + content, _ := message["content"].(map[string]any) + if author == nil || metadata == nil || content == nil { + continue + } + if role, _ := author["role"].(string); role != "tool" { + continue + } + if asyncTaskType, _ := metadata["async_task_type"].(string); asyncTaskType != "image_gen" { + continue + } + if contentType, _ := content["content_type"].(string); contentType != "multimodal_text" { + continue + } + prompt := "" + if title, _ := metadata["image_gen_title"].(string); strings.TrimSpace(title) != "" { + prompt = strings.TrimSpace(title) + } + item := openAIImageToolMessage{MessageID: messageID} + if createTime, ok := message["create_time"].(float64); ok { + item.CreateTime = createTime + } + parts, _ := content["parts"].([]any) + for _, part := range parts { + switch value := part.(type) { + case map[string]any: + if assetPointer, _ := value["asset_pointer"].(string); strings.TrimSpace(assetPointer) != "" { + for _, pointer := range openAIImagePointerMatches([]byte(assetPointer)) { + item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{ + Pointer: pointer, + Prompt: prompt, + }) + } + } + case string: + for _, pointer := range openAIImagePointerMatches([]byte(value)) { + item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{ + Pointer: pointer, + Prompt: prompt, + }) + } + } + } + if len(item.PointerInfos) == 0 { + continue + } + item.PointerInfos = mergeOpenAIImagePointerInfos(nil, item.PointerInfos) + out = append(out, item) + } + sort.Slice(out, func(i, j int) bool { + return out[i].CreateTime < out[j].CreateTime + }) + return out +} + +func pollOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header, conversationID string) ([]openAIImagePointerInfo, error) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return nil, nil + } + deadline := time.Now().Add(90 * time.Second) + interval := 3 * time.Second + previewWait := 15 * time.Second + var ( + lastErr error + firstToolAt time.Time + ) + for time.Now().Before(deadline) { + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + DisableAutoReadResponse(). + Get(fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s", conversationID)) + if err != nil { + lastErr = err + } else { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, readErr := io.ReadAll(resp.Response.Body) + _ = resp.Response.Body.Close() + if readErr != nil { + lastErr = readErr + goto waitNextPoll + } + pointers := mergeOpenAIImagePointerInfos(nil, collectOpenAIImagePointers(body)) + var decoded map[string]any + if err := json.Unmarshal(body, &decoded); err == nil { + if mapping, _ := decoded["mapping"].(map[string]any); len(mapping) > 0 { + toolMessages := extractOpenAIImageToolMessages(mapping) + if len(toolMessages) > 0 && firstToolAt.IsZero() { + firstToolAt = time.Now() + } + for _, msg := range toolMessages { + pointers = mergeOpenAIImagePointerInfos(pointers, msg.PointerInfos) + } + } + } + if hasOpenAIFileServicePointerInfos(pointers) { + return preferOpenAIFileServicePointerInfos(pointers), nil + } + if len(pointers) > 0 && !firstToolAt.IsZero() && time.Since(firstToolAt) >= previewWait { + return pointers, nil + } + } else { + statusErr := newOpenAIImageStatusError(resp, "conversation poll failed") + if isOpenAIImageTransientConversationNotFoundError(statusErr) { + lastErr = statusErr + goto waitNextPoll + } + return nil, statusErr + } + } + + waitNextPoll: + timer := time.NewTimer(interval) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func buildOpenAIImageResponse( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointers []openAIImagePointerInfo, +) ([]byte, int, error) { + type responseItem struct { + B64JSON string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + } + items := make([]responseItem, 0, len(pointers)) + for _, pointer := range pointers { + downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) + if err != nil { + return nil, 0, err + } + data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + if err != nil { + return nil, 0, err + } + items = append(items, responseItem{ + B64JSON: base64.StdEncoding.EncodeToString(data), + RevisedPrompt: pointer.Prompt, + }) + } + payload := map[string]any{ + "created": time.Now().Unix(), + "data": items, + } + body, err := json.Marshal(payload) + if err != nil { + return nil, 0, err + } + return body, len(items), nil +} + +func fetchOpenAIImageDownloadURL( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointer string, +) (string, error) { + url := "" + allowConversationRetry := false + switch { + case strings.HasPrefix(pointer, "file-service://"): + fileID := strings.TrimPrefix(pointer, "file-service://") + url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID) + case strings.HasPrefix(pointer, "sediment://"): + attachmentID := strings.TrimPrefix(pointer, "sediment://") + url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID) + allowConversationRetry = true + default: + return "", fmt.Errorf("unsupported image pointer: %s", pointer) + } + + var lastErr error + for attempt := 0; attempt < 8; attempt++ { + var result struct { + DownloadURL string `json:"download_url"` + } + resp, err := client.R(). + SetContext(ctx). + SetHeaders(headerToMap(headers)). + SetSuccessResult(&result). + Get(url) + if err != nil { + lastErr = err + } else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" { + return strings.TrimSpace(result.DownloadURL), nil + } else { + statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed") + if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) { + return "", statusErr + } + lastErr = statusErr + } + if attempt == 7 { + break + } + timer := time.NewTimer(750 * time.Millisecond) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return "", ctx.Err() + case <-timer.C: + } + } + if lastErr == nil { + lastErr = fmt.Errorf("fetch image download url failed") + } + return "", lastErr +} + +func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) { + request := client.R(). + SetContext(ctx). + DisableAutoReadResponse() + + if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) { + downloadHeaders := cloneHTTPHeader(headers) + downloadHeaders.Set("Accept", "image/*,*/*;q=0.8") + downloadHeaders.Del("Content-Type") + request.SetHeaders(headerToMap(downloadHeaders)) + } else { + userAgent := strings.TrimSpace(headers.Get("User-Agent")) + if userAgent == "" { + userAgent = openAIImageBackendUserAgent + } + request.SetHeader("User-Agent", userAgent) + } + + resp, err := request.Get(downloadURL) + if err != nil { + return nil, err + } + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, newOpenAIImageStatusError(resp, "download image bytes failed") + } + return io.ReadAll(resp.Body) +} + +func handleOpenAIImageBackendError(resp *req.Response) error { + return newOpenAIImageStatusError(resp, "backend-api request failed") +} + +type openAIImageStatusError struct { + StatusCode int + Message string + ResponseBody []byte + ResponseHeaders http.Header + RequestID string + URL string +} + +func (e *openAIImageStatusError) Error() string { + if e == nil { + return "openai image backend request failed" + } + if e.Message != "" { + return e.Message + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai image backend request failed: status %d", e.StatusCode) + } + return "openai image backend request failed" +} + +func newOpenAIImageStatusError(resp *req.Response, fallback string) error { + if resp == nil { + if strings.TrimSpace(fallback) == "" { + fallback = "openai image backend request failed" + } + return fmt.Errorf("%s", fallback) + } + + statusCode := resp.StatusCode + headers := http.Header(nil) + requestID := "" + requestURL := "" + body := []byte(nil) + + if resp.Response != nil { + headers = resp.Response.Header.Clone() + requestID = strings.TrimSpace(resp.Response.Header.Get("x-request-id")) + if resp.Response.Request != nil && resp.Response.Request.URL != nil { + requestURL = resp.Response.Request.URL.String() + } + if resp.Response.Body != nil { + body, _ = io.ReadAll(io.LimitReader(resp.Response.Body, 2<<20)) + _ = resp.Response.Body.Close() + } + } + + message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body)) + if message == "" { + prefix := strings.TrimSpace(fallback) + if prefix == "" { + prefix = "openai image backend request failed" + } + message = fmt.Sprintf("%s: status %d", prefix, statusCode) + } + + return &openAIImageStatusError{ + StatusCode: statusCode, + Message: message, + ResponseBody: body, + ResponseHeaders: headers, + RequestID: requestID, + URL: requestURL, + } +} + +func newOpenAIImageSyntheticStatusError(statusCode int, message string, requestURL string) *openAIImageStatusError { + message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) + if message == "" { + message = "openai image backend request failed" + } + var body []byte + if payload, err := json.Marshal(map[string]string{"detail": message}); err == nil { + body = payload + } + return &openAIImageStatusError{ + StatusCode: statusCode, + Message: message, + ResponseBody: body, + URL: strings.TrimSpace(requestURL), + } +} + +func isOpenAIImageTransientConversationNotFoundError(err error) bool { + statusErr, ok := err.(*openAIImageStatusError) + if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(statusErr.Message)) + if strings.Contains(msg, "conversation_not_found") { + return true + } + if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") { + return true + } + bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody))) + if strings.Contains(bodyMsg, "conversation_not_found") { + return true + } + return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found") +} + +func (s *OpenAIGatewayService) wrapOpenAIImageBackendError( + ctx context.Context, + c *gin.Context, + account *Account, + err error, +) error { + var statusErr *openAIImageStatusError + if !errors.As(err, &statusErr) || statusErr == nil { + return err + } + + upstreamMsg := sanitizeUpstreamErrorMessage(statusErr.Message) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusErr.StatusCode, + UpstreamRequestID: statusErr.RequestID, + UpstreamURL: safeUpstreamURL(statusErr.URL), + Kind: "request_error", + Message: upstreamMsg, + }) + setOpsUpstreamError(c, statusErr.StatusCode, upstreamMsg, "") + + if s.shouldFailoverOpenAIUpstreamResponse(statusErr.StatusCode, upstreamMsg, statusErr.ResponseBody) { + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, statusErr.StatusCode, statusErr.ResponseHeaders, statusErr.ResponseBody) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusErr.StatusCode, + UpstreamRequestID: statusErr.RequestID, + UpstreamURL: safeUpstreamURL(statusErr.URL), + Kind: "failover", + Message: upstreamMsg, + }) + retryableOnSameAccount := account.IsPoolMode() && isPoolModeRetryableStatus(statusErr.StatusCode) + if strings.Contains(strings.ToLower(statusErr.Message), "unsupported challenge") { + retryableOnSameAccount = false + } + return &UpstreamFailoverError{ + StatusCode: statusErr.StatusCode, + ResponseBody: statusErr.ResponseBody, + RetryableOnSameAccount: retryableOnSameAccount, + } + } + + return statusErr +} + +func cloneHTTPHeader(src http.Header) http.Header { + dst := make(http.Header, len(src)) + for key, values := range src { + copied := make([]string, len(values)) + copy(copied, values) + dst[key] = copied + } + return dst +} + +func headerToMap(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(header)) + for key, values := range header { + if len(values) == 0 { + continue + } + result[key] = values[0] + } + return result +} + +func openAITimezoneOffsetMinutes() int { + _, offset := time.Now().Zone() + return offset / 60 +} + +func openAITimezoneName() string { + return time.Now().Location().String() +} + +func generateOpenAIRequirementsToken(userAgent string) string { + config := []any{ + "core" + strconv.Itoa(3008), + time.Now().UTC().Format(time.RFC1123), + nil, + 0.123456, + coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent), + nil, + "prod-openai-images", + "en-US", + "en-US,en", + 0, + "navigator.webdriver", + "location", + "document.body", + float64(time.Now().UnixMilli()) / 1000, + uuid.NewString(), + "", + 8, + time.Now().Unix(), + } + answer, solved := generateOpenAIChallengeAnswer(strconv.FormatInt(time.Now().UnixNano(), 10), openAIImageRequirementsDiff, config) + if solved { + return "gAAAAAC" + answer + } + return "" +} + +func generateOpenAIChallengeAnswer(seed string, difficulty string, config []any) (string, bool) { + diffBytes, err := hex.DecodeString(difficulty) + if err != nil { + return "", false + } + p1 := []byte(jsonCompactSlice(config[:3], true)) + p2 := []byte(jsonCompactSlice(config[4:9], false)) + p3 := []byte(jsonCompactSlice(config[10:], false)) + seedBytes := []byte(seed) + + for i := 0; i < 100000; i++ { + payload := fmt.Sprintf("%s%d,%s,%d,%s", p1, i, p2, i>>1, p3) + encoded := base64.StdEncoding.EncodeToString([]byte(payload)) + sum := sha3.Sum512(append(seedBytes, []byte(encoded)...)) + if bytes.Compare(sum[:len(diffBytes)], diffBytes) <= 0 { + return encoded, true + } + } + return "", false +} + +func jsonCompactSlice(values []any, trimSuffixComma bool) string { + raw, _ := json.Marshal(values) + text := string(raw) + if trimSuffixComma { + return strings.TrimSuffix(text, "]") + } + return strings.TrimPrefix(text, "[") +} + +func generateOpenAIProofToken(required bool, seed string, difficulty string, userAgent string) string { + if !required || strings.TrimSpace(seed) == "" || strings.TrimSpace(difficulty) == "" { + return "" + } + screen := 3008 + if len(seed)%2 == 0 { + screen = 4010 + } + proofToken := []any{ + screen, + time.Now().UTC().Format(time.RFC1123), + nil, + 0, + coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent), + "https://chatgpt.com/", + "dpl=openai-images", + "en", + "en-US", + nil, + "plugins[object PluginArray]", + "_reactListening", + "alert", + } + diffLen := len(difficulty) + for i := 0; i < 100000; i++ { + proofToken[3] = i + raw, _ := json.Marshal(proofToken) + encoded := base64.StdEncoding.EncodeToString(raw) + sum := sha3.Sum512([]byte(seed + encoded)) + if strings.Compare(hex.EncodeToString(sum[:])[:diffLen], difficulty) <= 0 { + return "gAAAAAB" + encoded + } + } + fallbackBase := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%q", seed))) + return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallbackBase +} + +func dedupeStrings(values []string) []string { + if len(values) == 0 { + return nil + } + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, value := range values { + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go new file mode 100644 index 00000000..173d69ba --- /dev/null +++ b/backend/internal/service/openai_images_test.go @@ -0,0 +1,105 @@ +package service + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "/v1/images/generations", parsed.Endpoint) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, "draw a cat", parsed.Prompt) + require.True(t, parsed.Stream) + require.Equal(t, "1024x1024", parsed.Size) + require.Equal(t, "1K", parsed.SizeTier) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) + require.False(t, parsed.Multipart) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace background")) + require.NoError(t, writer.WriteField("size", "1536x1024")) + part, err := writer.CreateFormFile("image", "source.png") + require.NoError(t, err) + _, err = part.Write([]byte("fake-image-bytes")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "/v1/images/edits", parsed.Endpoint) + require.True(t, parsed.Multipart) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, "replace background", parsed.Prompt) + require.Equal(t, "1536x1024", parsed.Size) + require.Equal(t, "2K", parsed.SizeTier) + require.Len(t, parsed.Uploads, 1) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"prompt":"draw a cat"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "gpt-image-2", parsed.Model) + require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability) +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index c0e814ab..bd40d389 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta func detectOpsRetryType(path string) opsRetryRequestType { p := strings.ToLower(strings.TrimSpace(path)) switch { - case strings.Contains(p, "/responses"): + case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"): return opsRetryTypeOpenAI case strings.Contains(p, "/v1beta/"): return opsRetryTypeGeminiV1B diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 89d09eef..5f3719be 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -305,7 +305,8 @@ func shouldBypassEmbeddedFrontend(path string) bool { strings.HasPrefix(trimmed, "/setup/") || trimmed == "/health" || trimmed == "/responses" || - strings.HasPrefix(trimmed, "/responses/") + strings.HasPrefix(trimmed, "/responses/") || + strings.HasPrefix(trimmed, "/images/") } func serveIndexHTML(c *gin.Context, fsys fs.FS) {