package service import ( "bufio" "bytes" "context" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "io" "mime" "mime/multipart" "net/http" "net/textproto" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" "github.com/imroc/req/v3" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) const ( openAIImagesGenerationsEndpoint = "/v1/images/generations" openAIImagesEditsEndpoint = "/v1/images/edits" openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations" openAIImagesEditsURL = "https://api.openai.com/v1/images/edits" openAIChatGPTStartURL = "https://chatgpt.com/" openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files" openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part openAIImagesResponsesMainModel = "gpt-5.4-mini" ) type OpenAIImagesCapability string const ( OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic" OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native" ) type OpenAIImagesUpload struct { FieldName string FileName string ContentType string Data []byte Width int Height int } type OpenAIImagesRequest struct { Endpoint string ContentType string Multipart bool Model string ExplicitModel bool Prompt string Stream bool N int Size string ExplicitSize bool SizeTier string ResponseFormat string Quality string Background string OutputFormat string Moderation string InputFidelity string Style string OutputCompression *int PartialImages *int HasMask bool HasNativeOptions bool RequiredCapability OpenAIImagesCapability InputImageURLs []string MaskImageURL string Uploads []OpenAIImagesUpload MaskUpload *OpenAIImagesUpload Body []byte bodyHash string } func (r *OpenAIImagesRequest) IsEdits() bool { return r != nil && r.Endpoint == openAIImagesEditsEndpoint } func (r *OpenAIImagesRequest) StickySessionSeed() string { if r == nil { return "" } parts := []string{ "openai-images", strings.TrimSpace(r.Endpoint), strings.TrimSpace(r.Model), strings.TrimSpace(r.Size), strings.TrimSpace(r.Prompt), } seed := strings.Join(parts, "|") if strings.TrimSpace(r.Prompt) == "" && r.bodyHash != "" { seed += "|body=" + r.bodyHash } return seed } func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) { if c == nil || c.Request == nil { return nil, fmt.Errorf("missing request context") } endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path) if endpoint == "" { return nil, fmt.Errorf("unsupported images endpoint") } contentType := strings.TrimSpace(c.GetHeader("Content-Type")) req := &OpenAIImagesRequest{ Endpoint: endpoint, ContentType: contentType, N: 1, Body: body, } if len(body) > 0 { sum := sha256.Sum256(body) req.bodyHash = hex.EncodeToString(sum[:8]) } mediaType, _, err := mime.ParseMediaType(contentType) if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { req.Multipart = true if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil { return nil, parseErr } } else { if len(body) == 0 { return nil, fmt.Errorf("request body is empty") } if !gjson.ValidBytes(body) { return nil, fmt.Errorf("failed to parse request body") } if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil { return nil, parseErr } } applyOpenAIImagesDefaults(req) if err := validateOpenAIImagesModel(req.Model); err != nil { return nil, err } req.SizeTier = normalizeOpenAIImageSizeTier(req.Size) req.RequiredCapability = classifyOpenAIImagesCapability(req) return req, nil } func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { if modelResult := gjson.GetBytes(body, "model"); modelResult.Exists() { req.Model = strings.TrimSpace(modelResult.String()) req.ExplicitModel = req.Model != "" } req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String()) if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() { if streamResult.Type != gjson.True && streamResult.Type != gjson.False { return fmt.Errorf("invalid stream field type") } req.Stream = streamResult.Bool() } if nResult := gjson.GetBytes(body, "n"); nResult.Exists() { if nResult.Type != gjson.Number { return fmt.Errorf("invalid n field type") } req.N = int(nResult.Int()) if req.N <= 0 { return fmt.Errorf("n must be greater than 0") } } if sizeResult := gjson.GetBytes(body, "size"); sizeResult.Exists() { req.Size = strings.TrimSpace(sizeResult.String()) req.ExplicitSize = req.Size != "" } req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String()) req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String()) req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String()) req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String()) req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String()) req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String()) req.HasMask = gjson.GetBytes(body, "mask").Exists() if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() { if outputCompression.Type != gjson.Number { return fmt.Errorf("invalid output_compression field type") } v := int(outputCompression.Int()) req.OutputCompression = &v } if partialImages := gjson.GetBytes(body, "partial_images"); partialImages.Exists() { if partialImages.Type != gjson.Number { return fmt.Errorf("invalid partial_images field type") } v := int(partialImages.Int()) req.PartialImages = &v } if req.IsEdits() { images := gjson.GetBytes(body, "images") if images.Exists() { if !images.IsArray() { return fmt.Errorf("invalid images field type") } for _, item := range images.Array() { if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" { req.InputImageURLs = append(req.InputImageURLs, imageURL) continue } if item.Get("file_id").Exists() { return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)") } } } if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" { req.MaskImageURL = maskImageURL req.HasMask = true } if gjson.GetBytes(body, "mask.file_id").Exists() { return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)") } if len(req.InputImageURLs) == 0 { return fmt.Errorf("images[].image_url is required") } } req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { return gjson.GetBytes(body, path).Exists() }) return nil } func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error { _, params, err := mime.ParseMediaType(contentType) if err != nil { return fmt.Errorf("invalid multipart content-type: %w", err) } boundary := strings.TrimSpace(params["boundary"]) if boundary == "" { return fmt.Errorf("multipart boundary is required") } reader := multipart.NewReader(bytes.NewReader(body), boundary) for { part, err := reader.NextPart() if err == io.EOF { break } if err != nil { return fmt.Errorf("read multipart body: %w", err) } name := strings.TrimSpace(part.FormName()) if name == "" { _ = part.Close() continue } data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize)) _ = part.Close() if err != nil { return fmt.Errorf("read multipart field %s: %w", name, err) } fileName := strings.TrimSpace(part.FileName()) if fileName != "" { partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) if name == "mask" && len(data) > 0 { req.HasMask = true width, height := parseOpenAIImageDimensions(part.Header) maskUpload := OpenAIImagesUpload{ FieldName: name, FileName: fileName, ContentType: partContentType, Data: data, Width: width, Height: height, } req.MaskUpload = &maskUpload } if name == "image" || strings.HasPrefix(name, "image[") { width, height := parseOpenAIImageDimensions(part.Header) req.Uploads = append(req.Uploads, OpenAIImagesUpload{ FieldName: name, FileName: fileName, ContentType: partContentType, Data: data, Width: width, Height: height, }) } continue } value := strings.TrimSpace(string(data)) switch name { case "model": req.Model = value req.ExplicitModel = value != "" case "prompt": req.Prompt = value case "size": req.Size = value req.ExplicitSize = value != "" case "response_format": req.ResponseFormat = strings.ToLower(value) case "stream": parsed, err := strconv.ParseBool(value) if err != nil { return fmt.Errorf("invalid stream field value") } req.Stream = parsed case "n": n, err := strconv.Atoi(value) if err != nil || n <= 0 { return fmt.Errorf("n must be a positive integer") } req.N = n case "quality": req.Quality = value req.HasNativeOptions = true case "background": req.Background = value req.HasNativeOptions = true case "output_format": req.OutputFormat = value req.HasNativeOptions = true case "moderation": req.Moderation = value req.HasNativeOptions = true case "input_fidelity": req.InputFidelity = value req.HasNativeOptions = true case "style": req.Style = value req.HasNativeOptions = true case "output_compression": n, err := strconv.Atoi(value) if err != nil { return fmt.Errorf("invalid output_compression field value") } req.OutputCompression = &n req.HasNativeOptions = true case "partial_images": n, err := strconv.Atoi(value) if err != nil { return fmt.Errorf("invalid partial_images field value") } req.PartialImages = &n req.HasNativeOptions = true default: if isOpenAINativeImageOption(name) && value != "" { req.HasNativeOptions = true } } } if len(req.Uploads) == 0 && req.IsEdits() { return fmt.Errorf("image file is required") } return nil } func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) { return 0, 0 } func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) { if req == nil { return } if req.N <= 0 { req.N = 1 } if strings.TrimSpace(req.Model) != "" { req.Model = strings.TrimSpace(req.Model) return } req.Model = "gpt-image-2" } func isOpenAIImageGenerationModel(model string) bool { return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-") } func validateOpenAIImagesModel(model string) error { model = strings.TrimSpace(model) if isOpenAIImageGenerationModel(model) { return nil } if model == "" { return fmt.Errorf("images endpoint requires an image model") } return fmt.Errorf("images endpoint requires an image model, got %q", model) } func normalizeOpenAIImagesEndpointPath(path string) string { trimmed := strings.TrimSpace(path) switch { case strings.Contains(trimmed, "/images/generations"): return openAIImagesGenerationsEndpoint case strings.Contains(trimmed, "/images/edits"): return openAIImagesEditsEndpoint default: return "" } } func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability { if req == nil { return OpenAIImagesCapabilityNative } if req.ExplicitModel || req.ExplicitSize { return OpenAIImagesCapabilityNative } model := strings.ToLower(strings.TrimSpace(req.Model)) if !strings.HasPrefix(model, "gpt-image-") { return OpenAIImagesCapabilityNative } if req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions { return OpenAIImagesCapabilityNative } if req.IsEdits() && !req.Multipart { return OpenAIImagesCapabilityNative } if req.ResponseFormat != "" && req.ResponseFormat != "b64_json" { return OpenAIImagesCapabilityNative } return OpenAIImagesCapabilityBasic } func hasOpenAINativeImageOptions(exists func(path string) bool) bool { for _, path := range []string{ "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images", } { if exists(path) { return true } } return false } func isOpenAINativeImageOption(name string) bool { switch strings.TrimSpace(strings.ToLower(name)) { case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images": return true default: return false } } func normalizeOpenAIImageSizeTier(size string) string { switch strings.ToLower(strings.TrimSpace(size)) { case "1024x1024": return "1K" case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": return "2K" default: return "2K" } } func (s *OpenAIGatewayService) ForwardImages( ctx context.Context, c *gin.Context, account *Account, body []byte, parsed *OpenAIImagesRequest, channelMappedModel string, ) (*OpenAIForwardResult, error) { if parsed == nil { return nil, fmt.Errorf("parsed images request is required") } switch account.Type { case AccountTypeAPIKey: return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel) case AccountTypeOAuth: return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel) default: return nil, fmt.Errorf("unsupported account type: %s", account.Type) } } func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( ctx context.Context, c *gin.Context, account *Account, body []byte, parsed *OpenAIImagesRequest, channelMappedModel string, ) (*OpenAIForwardResult, error) { startTime := time.Now() requestModel := strings.TrimSpace(parsed.Model) if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { requestModel = mapped } if err := validateOpenAIImagesModel(requestModel); err != nil { return nil, err } upstreamModel := account.GetMappedModel(requestModel) if err := validateOpenAIImagesModel(upstreamModel); err != nil { return nil, err } logger.LegacyPrintf( "service.openai_gateway", "[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s", strings.TrimSpace(parsed.Model), upstreamModel, parsed.Endpoint, account.Type, ) forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel) if err != nil { return nil, err } if !parsed.Multipart { setOpsUpstreamRequestBody(c, forwardBody) } token, _, err := s.GetAccessToken(ctx, account) if err != nil { return nil, err } upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) if err != nil { return nil, err } proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { safeErr := sanitizeUpstreamErrorMessage(err.Error()) setOpsUpstreamError(c, 0, safeErr, "") appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) return nil, fmt.Errorf("upstream request failed: %s", safeErr) } if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "failover", Message: upstreamMsg, }) s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } return s.handleErrorResponse(ctx, resp, c, account, forwardBody) } defer func() { _ = resp.Body.Close() }() var usage OpenAIUsage imageCount := parsed.N var firstTokenMs *int if parsed.Stream { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { return nil, err } usage = streamUsage imageCount = streamCount firstTokenMs = ttft } else { nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c) if err != nil { return nil, err } usage = nonStreamUsage if nonStreamCount > 0 { imageCount = nonStreamCount } } return &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: usage, Model: requestModel, UpstreamModel: upstreamModel, Stream: parsed.Stream, ResponseHeaders: resp.Header.Clone(), Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ImageCount: imageCount, ImageSize: parsed.SizeTier, }, nil } func (s *OpenAIGatewayService) buildOpenAIImagesRequest( ctx context.Context, c *gin.Context, account *Account, body []byte, contentType string, token string, endpoint string, ) (*http.Request, error) { targetURL := openAIImagesGenerationsURL if endpoint == openAIImagesEditsEndpoint { targetURL = openAIImagesEditsURL } baseURL := account.GetOpenAIBaseURL() if baseURL != "" { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err } targetURL = buildOpenAIImagesURL(validatedURL, endpoint) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+token) for key, values := range c.Request.Header { if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] { continue } for _, value := range values { req.Header.Add(key, value) } } customUA := account.GetOpenAIUserAgent() if customUA != "" { req.Header.Set("User-Agent", customUA) } if strings.TrimSpace(contentType) != "" { req.Header.Set("Content-Type", contentType) } return req, nil } func buildOpenAIImagesURL(base string, endpoint string) string { normalized := strings.TrimRight(strings.TrimSpace(base), "/") relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1") if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) { return normalized } if strings.HasSuffix(normalized, "/v1") { return normalized + relative } return normalized + endpoint } func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) { model = strings.TrimSpace(model) if model == "" { return body, contentType, nil } mediaType, _, err := mime.ParseMediaType(contentType) if err == nil && strings.EqualFold(mediaType, "multipart/form-data") { rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model) return rewrittenBody, rewrittenType, rewriteErr } rewritten, err := sjson.SetBytes(body, "model", model) if err != nil { return nil, "", fmt.Errorf("rewrite image request model: %w", err) } return rewritten, contentType, nil } func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) { _, params, err := mime.ParseMediaType(contentType) if err != nil { return nil, "", fmt.Errorf("parse multipart content-type: %w", err) } boundary := strings.TrimSpace(params["boundary"]) if boundary == "" { return nil, "", fmt.Errorf("multipart boundary is required") } reader := multipart.NewReader(bytes.NewReader(body), boundary) var buffer bytes.Buffer writer := multipart.NewWriter(&buffer) modelWritten := false for { part, err := reader.NextPart() if err == io.EOF { break } if err != nil { return nil, "", fmt.Errorf("read multipart body: %w", err) } formName := strings.TrimSpace(part.FormName()) partHeader := cloneMultipartHeader(part.Header) target, err := writer.CreatePart(partHeader) if err != nil { _ = part.Close() return nil, "", fmt.Errorf("create multipart part: %w", err) } if formName == "model" && part.FileName() == "" { if _, err := target.Write([]byte(model)); err != nil { _ = part.Close() return nil, "", fmt.Errorf("rewrite multipart model: %w", err) } modelWritten = true _ = part.Close() continue } if _, err := io.Copy(target, part); err != nil { _ = part.Close() return nil, "", fmt.Errorf("copy multipart part: %w", err) } _ = part.Close() } if !modelWritten { if err := writer.WriteField("model", model); err != nil { return nil, "", fmt.Errorf("append multipart model field: %w", err) } } if err := writer.Close(); err != nil { return nil, "", fmt.Errorf("finalize multipart body: %w", err) } return buffer.Bytes(), writer.FormDataContentType(), nil } func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader { dst := make(textproto.MIMEHeader, len(src)) for key, values := range src { copied := make([]string, len(values)) copy(copied, values) dst[key] = copied } return dst } func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { return OpenAIUsage{}, 0, err } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { contentType = upstreamType } } c.Data(resp.StatusCode, contentType, body) usage, _ := extractOpenAIUsageFromJSONBytes(body) return usage, extractOpenAIImageCountFromJSONBytes(body), nil } func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( resp *http.Response, c *gin.Context, startTime time.Time, ) (OpenAIUsage, int, *int, error) { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { contentType = "text/event-stream" } c.Status(resp.StatusCode) c.Header("Content-Type", contentType) flusher, ok := c.Writer.(http.Flusher) if !ok { return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") } reader := bufio.NewReader(resp.Body) usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int for { line, err := reader.ReadBytes('\n') if len(line) > 0 { if firstTokenMs == nil { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } if _, writeErr := c.Writer.Write(line); writeErr != nil { return OpenAIUsage{}, 0, firstTokenMs, writeErr } flusher.Flush() if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" { dataBytes := []byte(data) mergeOpenAIUsage(&usage, dataBytes) if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount { imageCount = count } } } if err == io.EOF { break } if err != nil { return OpenAIUsage{}, 0, firstTokenMs, err } } return usage, imageCount, firstTokenMs, nil } func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { if dst == nil { return } if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok { if parsed.InputTokens > 0 { dst.InputTokens = parsed.InputTokens } if parsed.OutputTokens > 0 { dst.OutputTokens = parsed.OutputTokens } if parsed.CacheReadInputTokens > 0 { dst.CacheReadInputTokens = parsed.CacheReadInputTokens } if parsed.ImageOutputTokens > 0 { dst.ImageOutputTokens = parsed.ImageOutputTokens } } } func extractOpenAIImageCountFromJSONBytes(body []byte) int { if len(body) == 0 || !gjson.ValidBytes(body) { return 0 } data := gjson.GetBytes(body, "data") if data.Exists() && data.IsArray() { return len(data.Array()) } return 0 } type openAIImagePointerInfo struct { Pointer string DownloadURL string B64JSON string MimeType string Prompt string } func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { if len(body) == 0 { return nil } prompt := "" for _, path := range []string{ "message.metadata.dalle.prompt", "metadata.dalle.prompt", "revised_prompt", } { if value := strings.TrimSpace(gjson.GetBytes(body, path).String()); value != "" { prompt = value break } } matches := openAIImagePointerMatches(body) out := make([]openAIImagePointerInfo, 0, len(matches)) for _, pointer := range matches { out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}) } return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt)) } func openAIImagePointerMatches(body []byte) []string { raw := string(body) matches := make([]string, 0, 4) for _, prefix := range []string{"file-service://", "sediment://"} { start := 0 for { idx := strings.Index(raw[start:], prefix) if idx < 0 { break } idx += start end := idx + len(prefix) for end < len(raw) { ch := raw[end] if ch != '-' && ch != '_' && (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') { break } end++ } matches = append(matches, raw[idx:end]) start = end } } return dedupeStrings(matches) } func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo { if len(next) == 0 { return existing } seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next)) out := make([]openAIImagePointerInfo, 0, len(existing)+len(next)) for _, item := range existing { if key := item.identityKey(); key != "" { seen[key] = item } out = append(out, item) } for _, item := range next { key := item.identityKey() if key == "" { continue } if existingItem, ok := seen[key]; ok { merged := mergeOpenAIImagePointerInfo(existingItem, item) if merged != existingItem { for i := range out { if out[i].identityKey() == key { out[i] = merged break } } seen[key] = merged } continue } seen[key] = item out = append(out, item) } return out } func (i openAIImagePointerInfo) identityKey() string { switch { case strings.TrimSpace(i.Pointer) != "": return "pointer:" + strings.TrimSpace(i.Pointer) case strings.TrimSpace(i.DownloadURL) != "": return "download:" + strings.TrimSpace(i.DownloadURL) case strings.TrimSpace(i.B64JSON) != "": b64 := strings.TrimSpace(i.B64JSON) if len(b64) > 64 { b64 = b64[:64] } return "b64:" + b64 default: return "" } } func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo { merged := existing if strings.TrimSpace(merged.Pointer) == "" { merged.Pointer = next.Pointer } if strings.TrimSpace(merged.DownloadURL) == "" { merged.DownloadURL = next.DownloadURL } if strings.TrimSpace(merged.B64JSON) == "" { merged.B64JSON = next.B64JSON } if strings.TrimSpace(merged.MimeType) == "" { merged.MimeType = next.MimeType } if strings.TrimSpace(merged.Prompt) == "" { merged.Prompt = next.Prompt } return merged } func resolveOpenAIImageBytes( ctx context.Context, client *req.Client, headers http.Header, conversationID string, pointer openAIImagePointerInfo, ) ([]byte, error) { if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" { return base64.StdEncoding.DecodeString(normalized) } if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" { return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) } if strings.TrimSpace(pointer.Pointer) == "" { return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data") } downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) if err != nil { return nil, err } return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) } func normalizeOpenAIImageBase64(raw string) string { raw = strings.TrimSpace(raw) if raw == "" { return "" } if strings.HasPrefix(strings.ToLower(raw), "data:") { if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) { raw = raw[idx+1:] } } raw = strings.TrimSpace(raw) raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4) if raw == "" { return "" } if _, err := base64.StdEncoding.DecodeString(raw); err != nil { return "" } return raw } func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo { if len(body) == 0 || !gjson.ValidBytes(body) { return nil } var decoded any if err := json.Unmarshal(body, &decoded); err != nil { return nil } var out []openAIImagePointerInfo walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out) return out } func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) { switch value := node.(type) { case map[string]any: localPrompt := prompt for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} { if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" { localPrompt = strings.TrimSpace(v) break } } item := openAIImagePointerInfo{ Prompt: localPrompt, Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]), DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]), B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]), MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]), } switch { case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"), strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"), isLikelyOpenAIImageDownloadURL(item.DownloadURL), normalizeOpenAIImageBase64(item.B64JSON) != "": *out = append(*out, item) } for _, child := range value { walkOpenAIImageInlineAssets(child, localPrompt, out) } case []any: for _, child := range value { walkOpenAIImageInlineAssets(child, prompt, out) } } } func firstNonEmptyString(values ...any) string { for _, value := range values { if s, ok := value.(string); ok && strings.TrimSpace(s) != "" { return strings.TrimSpace(s) } } return "" } func isLikelyOpenAIImageDownloadURL(raw string) bool { raw = strings.TrimSpace(raw) if raw == "" { return false } if strings.HasPrefix(strings.ToLower(raw), "data:image/") { return true } if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") { return false } lower := strings.ToLower(raw) return strings.Contains(lower, "/download") || strings.Contains(lower, ".png") || strings.Contains(lower, ".jpg") || strings.Contains(lower, ".jpeg") || strings.Contains(lower, ".webp") } func fetchOpenAIImageDownloadURL( ctx context.Context, client *req.Client, headers http.Header, conversationID string, pointer string, ) (string, error) { url := "" allowConversationRetry := false switch { case strings.HasPrefix(pointer, "file-service://"): fileID := strings.TrimPrefix(pointer, "file-service://") url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID) case strings.HasPrefix(pointer, "sediment://"): attachmentID := strings.TrimPrefix(pointer, "sediment://") url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID) allowConversationRetry = true default: return "", fmt.Errorf("unsupported image pointer: %s", pointer) } var lastErr error for attempt := 0; attempt < 8; attempt++ { var result struct { DownloadURL string `json:"download_url"` } resp, err := client.R(). SetContext(ctx). SetHeaders(headerToMap(headers)). SetSuccessResult(&result). Get(url) if err != nil { lastErr = err } else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" { return strings.TrimSpace(result.DownloadURL), nil } else { statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed") if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) { return "", statusErr } lastErr = statusErr } if attempt == 7 { break } timer := time.NewTimer(750 * time.Millisecond) select { case <-ctx.Done(): if !timer.Stop() { <-timer.C } return "", ctx.Err() case <-timer.C: } } if lastErr == nil { lastErr = fmt.Errorf("fetch image download url failed") } return "", lastErr } func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) { request := client.R(). SetContext(ctx). DisableAutoReadResponse() if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) { downloadHeaders := cloneHTTPHeader(headers) downloadHeaders.Set("Accept", "image/*,*/*;q=0.8") downloadHeaders.Del("Content-Type") request.SetHeaders(headerToMap(downloadHeaders)) } else { userAgent := strings.TrimSpace(headers.Get("User-Agent")) if userAgent == "" { userAgent = openAIImageBackendUserAgent } request.SetHeader("User-Agent", userAgent) } resp, err := request.Get(downloadURL) if err != nil { return nil, err } defer func() { if resp != nil && resp.Body != nil { _ = resp.Body.Close() } }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, newOpenAIImageStatusError(resp, "download image bytes failed") } return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes)) } type openAIImageStatusError struct { StatusCode int Message string ResponseBody []byte ResponseHeaders http.Header RequestID string URL string } func (e *openAIImageStatusError) Error() string { if e == nil { return "openai image backend request failed" } if e.Message != "" { return e.Message } if e.StatusCode > 0 { return fmt.Sprintf("openai image backend request failed: status %d", e.StatusCode) } return "openai image backend request failed" } func newOpenAIImageStatusError(resp *req.Response, fallback string) error { if resp == nil { if strings.TrimSpace(fallback) == "" { fallback = "openai image backend request failed" } return fmt.Errorf("%s", fallback) } statusCode := resp.StatusCode headers := http.Header(nil) requestID := "" requestURL := "" body := []byte(nil) if resp.Response != nil { headers = resp.Header.Clone() requestID = strings.TrimSpace(resp.Header.Get("x-request-id")) if resp.Request != nil && resp.Request.URL != nil { requestURL = resp.Request.URL.String() } if resp.Body != nil { body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() } } message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body)) if message == "" { prefix := strings.TrimSpace(fallback) if prefix == "" { prefix = "openai image backend request failed" } message = fmt.Sprintf("%s: status %d", prefix, statusCode) } return &openAIImageStatusError{ StatusCode: statusCode, Message: message, ResponseBody: body, ResponseHeaders: headers, RequestID: requestID, URL: requestURL, } } func isOpenAIImageTransientConversationNotFoundError(err error) bool { statusErr, ok := err.(*openAIImageStatusError) if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound { return false } msg := strings.ToLower(strings.TrimSpace(statusErr.Message)) if strings.Contains(msg, "conversation_not_found") { return true } if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") { return true } bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody))) if strings.Contains(bodyMsg, "conversation_not_found") { return true } return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found") } func cloneHTTPHeader(src http.Header) http.Header { dst := make(http.Header, len(src)) for key, values := range src { copied := make([]string, len(values)) copy(copied, values) dst[key] = copied } return dst } func headerToMap(header http.Header) map[string]string { if len(header) == 0 { return nil } result := make(map[string]string, len(header)) for key, values := range header { if len(values) == 0 { continue } result[key] = values[0] } return result } func dedupeStrings(values []string) []string { if len(values) == 0 { return nil } seen := make(map[string]struct{}, len(values)) out := make([]string, 0, len(values)) for _, value := range values { if _, ok := seen[value]; ok { continue } seen[value] = struct{}{} out = append(out, value) } return out }