Revert payment/wechat, sora/claude-max cleanup, fork-only migrations, and cosmetic changes that were brought in by the release sync commit. Keep only channel-monitor related improvements: - PublicSettingsInjectionPayload named struct with drift test - ChannelMonitorRunner graceful shutdown in wire - image_output_price in SupportedModelChip - Simplified buildSelfNavItems in AppSidebar - Gateway WARN logs for 503 branches
1347 lines
38 KiB
Go
1347 lines
38 KiB
Go
package service
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/imroc/req/v3"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// Client-facing images endpoint paths, as produced by
// normalizeOpenAIImagesEndpointPath.
const (
	openAIImagesGenerationsEndpoint = "/v1/images/generations"
	openAIImagesEditsEndpoint       = "/v1/images/edits"
)

// Default upstream URLs, used by buildOpenAIImagesRequest when the account
// has no custom OpenAI base URL.
const (
	openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations"
	openAIImagesEditsURL       = "https://api.openai.com/v1/images/edits"
)

// ChatGPT web-backend endpoints and a browser User-Agent string.
// NOTE(review): these are not referenced in this part of the file; presumably
// they serve the OAuth (ChatGPT backend) image flow — confirm at call sites.
const (
	openAIChatGPTStartURL       = "https://chatgpt.com/"
	openAIChatGPTFilesURL       = "https://chatgpt.com/backend-api/files"
	openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
)

// Byte limits for image payloads.
const (
	openAIImageMaxDownloadBytes  = 20 << 20 // 20MB per image download
	openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
)

// openAIImagesResponsesMainModel names a model used by an images flow that is
// not visible here (presumably a Responses-API based path) — confirm usage.
const openAIImagesResponsesMainModel = "gpt-5.4-mini"
|
|
|
|
// OpenAIImagesCapability is the level of Images API support an account must
// offer to serve a given request; see classifyOpenAIImagesCapability.
type OpenAIImagesCapability string

// OpenAIImagesCapabilityBasic is enough for the simple request subset that
// classifyOpenAIImagesCapability accepts as basic.
const OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic"

// OpenAIImagesCapabilityNative is required for every other images request.
const OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native"
|
|
|
|
// OpenAIImagesUpload is one file part extracted from a multipart images
// request (an input image or the mask).
type OpenAIImagesUpload struct {
	// FieldName is the form field name, FileName the declared file name and
	// ContentType the part's Content-Type header (all whitespace-trimmed by
	// the parser).
	FieldName, FileName, ContentType string
	// Data is the part body as read by the multipart parser.
	Data []byte
	// Width and Height are pixel dimensions when known; they are filled from
	// parseOpenAIImageDimensions, which currently reports 0 for both.
	Width, Height int
}
|
|
|
|
type OpenAIImagesRequest struct {
|
|
Endpoint string
|
|
ContentType string
|
|
Multipart bool
|
|
Model string
|
|
ExplicitModel bool
|
|
Prompt string
|
|
Stream bool
|
|
N int
|
|
Size string
|
|
ExplicitSize bool
|
|
SizeTier string
|
|
ResponseFormat string
|
|
Quality string
|
|
Background string
|
|
OutputFormat string
|
|
Moderation string
|
|
InputFidelity string
|
|
Style string
|
|
OutputCompression *int
|
|
PartialImages *int
|
|
HasMask bool
|
|
HasNativeOptions bool
|
|
RequiredCapability OpenAIImagesCapability
|
|
InputImageURLs []string
|
|
MaskImageURL string
|
|
Uploads []OpenAIImagesUpload
|
|
MaskUpload *OpenAIImagesUpload
|
|
Body []byte
|
|
bodyHash string
|
|
}
|
|
|
|
func (r *OpenAIImagesRequest) IsEdits() bool {
|
|
return r != nil && r.Endpoint == openAIImagesEditsEndpoint
|
|
}
|
|
|
|
// StickySessionSeed returns a deterministic string built from the request's
// endpoint, model, size and prompt (each whitespace-trimmed, joined by "|"
// after the "openai-images" tag). Presumably the caller uses it as a
// sticky-session key so that identical requests route consistently; confirm
// at call sites.
//
// When the prompt is blank, those fields distinguish little, so the body
// fingerprint (if any) is appended as "|body=<hash>". A nil request yields "".
func (r *OpenAIImagesRequest) StickySessionSeed() string {
	if r == nil {
		return ""
	}
	prompt := strings.TrimSpace(r.Prompt)
	seed := strings.Join([]string{
		"openai-images",
		strings.TrimSpace(r.Endpoint),
		strings.TrimSpace(r.Model),
		strings.TrimSpace(r.Size),
		prompt,
	}, "|")
	if prompt == "" && r.bodyHash != "" {
		seed += "|body=" + r.bodyHash
	}
	return seed
}
|
|
|
|
func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) {
|
|
if c == nil || c.Request == nil {
|
|
return nil, fmt.Errorf("missing request context")
|
|
}
|
|
endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path)
|
|
if endpoint == "" {
|
|
return nil, fmt.Errorf("unsupported images endpoint")
|
|
}
|
|
|
|
contentType := strings.TrimSpace(c.GetHeader("Content-Type"))
|
|
req := &OpenAIImagesRequest{
|
|
Endpoint: endpoint,
|
|
ContentType: contentType,
|
|
N: 1,
|
|
Body: body,
|
|
}
|
|
if len(body) > 0 {
|
|
sum := sha256.Sum256(body)
|
|
req.bodyHash = hex.EncodeToString(sum[:8])
|
|
}
|
|
|
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
|
if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
|
|
req.Multipart = true
|
|
if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
} else {
|
|
if len(body) == 0 {
|
|
return nil, fmt.Errorf("request body is empty")
|
|
}
|
|
if !gjson.ValidBytes(body) {
|
|
return nil, fmt.Errorf("failed to parse request body")
|
|
}
|
|
if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
}
|
|
|
|
applyOpenAIImagesDefaults(req)
|
|
if err := validateOpenAIImagesModel(req.Model); err != nil {
|
|
return nil, err
|
|
}
|
|
req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
|
|
req.RequiredCapability = classifyOpenAIImagesCapability(req)
|
|
return req, nil
|
|
}
|
|
|
|
// parseOpenAIImagesJSONRequest fills req from a JSON images request body.
// ParseOpenAIImagesRequest only calls it after checking the body is valid JSON.
//
// Returns an error when stream, n, output_compression, partial_images or
// images have the wrong JSON type, or when n is not positive. For the edits
// endpoint it also rejects file_id references (only image_url is supported)
// and requires at least one images[].image_url.
func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
	get := func(path string) gjson.Result { return gjson.GetBytes(body, path) }
	str := func(path string) string { return strings.TrimSpace(get(path).String()) }

	if model := get("model"); model.Exists() {
		req.Model = strings.TrimSpace(model.String())
		req.ExplicitModel = req.Model != ""
	}
	req.Prompt = str("prompt")

	if stream := get("stream"); stream.Exists() {
		if stream.Type != gjson.True && stream.Type != gjson.False {
			return fmt.Errorf("invalid stream field type")
		}
		req.Stream = stream.Bool()
	}

	if count := get("n"); count.Exists() {
		if count.Type != gjson.Number {
			return fmt.Errorf("invalid n field type")
		}
		req.N = int(count.Int())
		if req.N <= 0 {
			return fmt.Errorf("n must be greater than 0")
		}
	}

	if size := get("size"); size.Exists() {
		req.Size = strings.TrimSpace(size.String())
		req.ExplicitSize = req.Size != ""
	}

	req.ResponseFormat = strings.ToLower(str("response_format"))
	req.Quality = str("quality")
	req.Background = str("background")
	req.OutputFormat = str("output_format")
	req.Moderation = str("moderation")
	req.InputFidelity = str("input_fidelity")
	req.Style = str("style")
	req.HasMask = get("mask").Exists()

	if compression := get("output_compression"); compression.Exists() {
		if compression.Type != gjson.Number {
			return fmt.Errorf("invalid output_compression field type")
		}
		value := int(compression.Int())
		req.OutputCompression = &value
	}
	if partial := get("partial_images"); partial.Exists() {
		if partial.Type != gjson.Number {
			return fmt.Errorf("invalid partial_images field type")
		}
		value := int(partial.Int())
		req.PartialImages = &value
	}

	if req.IsEdits() {
		if images := get("images"); images.Exists() {
			if !images.IsArray() {
				return fmt.Errorf("invalid images field type")
			}
			for _, item := range images.Array() {
				imageURL := strings.TrimSpace(item.Get("image_url").String())
				switch {
				case imageURL != "":
					req.InputImageURLs = append(req.InputImageURLs, imageURL)
				case item.Get("file_id").Exists():
					return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)")
				}
			}
		}
		if maskURL := str("mask.image_url"); maskURL != "" {
			req.MaskImageURL = maskURL
			req.HasMask = true
		}
		if get("mask.file_id").Exists() {
			return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)")
		}
		if len(req.InputImageURLs) == 0 {
			return fmt.Errorf("images[].image_url is required")
		}
	}

	req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool {
		return get(path).Exists()
	})
	return nil
}
|
|
|
|
// parseOpenAIImagesMultipartRequest fills req from a multipart/form-data
// images request body.
//
// contentType must carry a non-empty boundary parameter. File parts named
// "image" or "image[...]" are appended to req.Uploads. A non-empty file part
// named "mask" becomes req.MaskUpload and sets HasMask. Other file parts are
// ignored. Text parts set the matching request fields. A repeated field is
// simply overwritten, so the last occurrence wins. Setting any native-only
// option marks req.HasNativeOptions.
//
// Returns an error for a malformed content type or body, invalid
// stream/n/output_compression/partial_images values, or an edits request
// without any image file part.
func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return fmt.Errorf("invalid multipart content-type: %w", err)
	}
	boundary := strings.TrimSpace(params["boundary"])
	if boundary == "" {
		return fmt.Errorf("multipart boundary is required")
	}

	reader := multipart.NewReader(bytes.NewReader(body), boundary)
	for {
		part, err := reader.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("read multipart body: %w", err)
		}
		name := strings.TrimSpace(part.FormName())
		if name == "" {
			// Parts without a form name cannot map to any field; skip them.
			_ = part.Close()
			continue
		}

		// NOTE(review): LimitReader silently truncates parts larger than
		// openAIImageMaxUploadPartSize, and the truncated bytes are what end up
		// in Uploads/MaskUpload. This function does not modify body itself —
		// confirm whether consumers of the parsed Data need an explicit error.
		data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize))
		_ = part.Close()
		if err != nil {
			return fmt.Errorf("read multipart field %s: %w", name, err)
		}

		fileName := strings.TrimSpace(part.FileName())
		if fileName != "" {
			// File part: only mask and image fields are retained.
			partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
			if name == "mask" && len(data) > 0 {
				req.HasMask = true
				width, height := parseOpenAIImageDimensions(part.Header)
				maskUpload := OpenAIImagesUpload{
					FieldName:   name,
					FileName:    fileName,
					ContentType: partContentType,
					Data:        data,
					Width:       width,
					Height:      height,
				}
				req.MaskUpload = &maskUpload
			}
			if name == "image" || strings.HasPrefix(name, "image[") {
				width, height := parseOpenAIImageDimensions(part.Header)
				req.Uploads = append(req.Uploads, OpenAIImagesUpload{
					FieldName:   name,
					FileName:    fileName,
					ContentType: partContentType,
					Data:        data,
					Width:       width,
					Height:      height,
				})
			}
			continue
		}

		// Text part: the trimmed value sets the matching request field.
		value := strings.TrimSpace(string(data))
		switch name {
		case "model":
			req.Model = value
			req.ExplicitModel = value != ""
		case "prompt":
			req.Prompt = value
		case "size":
			req.Size = value
			req.ExplicitSize = value != ""
		case "response_format":
			req.ResponseFormat = strings.ToLower(value)
		case "stream":
			parsed, err := strconv.ParseBool(value)
			if err != nil {
				return fmt.Errorf("invalid stream field value")
			}
			req.Stream = parsed
		case "n":
			n, err := strconv.Atoi(value)
			if err != nil || n <= 0 {
				return fmt.Errorf("n must be a positive integer")
			}
			req.N = n
		case "quality":
			req.Quality = value
			req.HasNativeOptions = true
		case "background":
			req.Background = value
			req.HasNativeOptions = true
		case "output_format":
			req.OutputFormat = value
			req.HasNativeOptions = true
		case "moderation":
			req.Moderation = value
			req.HasNativeOptions = true
		case "input_fidelity":
			req.InputFidelity = value
			req.HasNativeOptions = true
		case "style":
			req.Style = value
			req.HasNativeOptions = true
		case "output_compression":
			n, err := strconv.Atoi(value)
			if err != nil {
				return fmt.Errorf("invalid output_compression field value")
			}
			req.OutputCompression = &n
			req.HasNativeOptions = true
		case "partial_images":
			n, err := strconv.Atoi(value)
			if err != nil {
				return fmt.Errorf("invalid partial_images field value")
			}
			req.PartialImages = &n
			req.HasNativeOptions = true
		default:
			// The exact option names are handled above. isOpenAINativeImageOption
			// lower-cases its input, so this only fires for differently-cased
			// option names (e.g. "Quality") with a non-empty value.
			if isOpenAINativeImageOption(name) && value != "" {
				req.HasNativeOptions = true
			}
		}
	}

	if len(req.Uploads) == 0 && req.IsEdits() {
		return fmt.Errorf("image file is required")
	}
	return nil
}
|
|
|
|
// parseOpenAIImageDimensions is a placeholder for deriving pixel dimensions
// from a multipart part header. It ignores the header and always reports
// width 0 and height 0, meaning "unknown".
func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (width, height int) {
	return width, height
}
|
|
|
|
// applyOpenAIImagesDefaults normalizes req in place. A non-positive N becomes
// 1. The model is whitespace-trimmed and falls back to "gpt-image-2" when
// empty. A nil req is ignored.
func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
	if req == nil {
		return
	}
	if req.N < 1 {
		req.N = 1
	}
	model := strings.TrimSpace(req.Model)
	if model == "" {
		model = "gpt-image-2"
	}
	req.Model = model
}
|
|
|
|
// isOpenAIImageGenerationModel reports whether model names a GPT image model,
// i.e. starts with "gpt-image-" ignoring case and surrounding whitespace.
func isOpenAIImageGenerationModel(model string) bool {
	normalized := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalized, "gpt-image-")
}
|
|
|
|
// validateOpenAIImagesModel accepts only GPT image models (see
// isOpenAIImageGenerationModel). It returns a descriptive error for an empty
// or non-image model name, quoting the trimmed name when one was given.
func validateOpenAIImagesModel(model string) error {
	trimmed := strings.TrimSpace(model)
	switch {
	case isOpenAIImageGenerationModel(trimmed):
		return nil
	case trimmed == "":
		return fmt.Errorf("images endpoint requires an image model")
	default:
		return fmt.Errorf("images endpoint requires an image model, got %q", trimmed)
	}
}
|
|
|
|
// normalizeOpenAIImagesEndpointPath maps a request path to the canonical
// images endpoint it contains. Any path containing "/images/generations" or
// "/images/edits" matches, so prefixed routes are accepted. A path with
// neither segment yields "".
func normalizeOpenAIImagesEndpointPath(path string) string {
	p := strings.TrimSpace(path)
	if strings.Contains(p, "/images/generations") {
		return openAIImagesGenerationsEndpoint
	}
	if strings.Contains(p, "/images/edits") {
		return openAIImagesEditsEndpoint
	}
	return ""
}
|
|
|
|
// classifyOpenAIImagesCapability decides whether req can be served by an
// account with only basic images support.
//
// A request is basic only when all of the following hold: no explicit model
// or size was given, the (defaulted) model is a gpt-image-* model, it is not
// streamed, it asks for exactly one image, it has no mask and no native-only
// options, an edits request arrives as multipart, and the response format is
// empty or "b64_json". Everything else, including a nil request, needs
// native support.
func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability {
	if req == nil {
		return OpenAIImagesCapabilityNative
	}
	model := strings.ToLower(strings.TrimSpace(req.Model))
	needsNative := req.ExplicitModel || req.ExplicitSize ||
		!strings.HasPrefix(model, "gpt-image-") ||
		req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions ||
		(req.IsEdits() && !req.Multipart) ||
		(req.ResponseFormat != "" && req.ResponseFormat != "b64_json")
	if needsNative {
		return OpenAIImagesCapabilityNative
	}
	return OpenAIImagesCapabilityBasic
}
|
|
|
|
// hasOpenAINativeImageOptions reports whether any native-only image option is
// present according to exists. Paths are probed in a fixed order and probing
// stops at the first hit.
func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
	paths := [...]string{
		"background", "quality", "style", "output_format",
		"output_compression", "moderation", "input_fidelity", "partial_images",
	}
	for i := range paths {
		if exists(paths[i]) {
			return true
		}
	}
	return false
}
|
|
|
|
// isOpenAINativeImageOption reports whether name (case-insensitive, trimmed)
// is one of the native-only image option field names.
func isOpenAINativeImageOption(name string) bool {
	switch key := strings.TrimSpace(strings.ToLower(name)); key {
	case "background", "quality", "style",
		"output_format", "output_compression",
		"moderation", "input_fidelity", "partial_images":
		return true
	}
	return false
}
|
|
|
|
// normalizeOpenAIImageSizeTier maps a requested size to its tier label.
// Only "1024x1024" (ignoring case and surrounding whitespace) is "1K". Every
// other value is "2K": 1536x1024, 1024x1536, 1792x1024, 1024x1792, "auto",
// empty, or unrecognized.
func normalizeOpenAIImageSizeTier(size string) string {
	if strings.ToLower(strings.TrimSpace(size)) == "1024x1024" {
		return "1K"
	}
	return "2K"
}
|
|
|
|
// ForwardImages forwards a parsed images request through account. API-key
// accounts go through forwardOpenAIImagesAPIKey and OAuth accounts through
// forwardOpenAIImagesOAuth. channelMappedModel is passed through to the
// per-type forwarder.
//
// Returns an error for a nil parsed request, a nil account, or an unsupported
// account type.
func (s *OpenAIGatewayService) ForwardImages(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	if parsed == nil {
		return nil, fmt.Errorf("parsed images request is required")
	}
	// account.Type is dereferenced below; report a nil account as an error
	// instead of panicking.
	if account == nil {
		return nil, fmt.Errorf("account is required")
	}
	switch account.Type {
	case AccountTypeAPIKey:
		return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel)
	case AccountTypeOAuth:
		return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel)
	default:
		return nil, fmt.Errorf("unsupported account type: %s", account.Type)
	}
}
|
|
|
|
// forwardOpenAIImagesAPIKey proxies an images request upstream with an API-key
// account and relays the response to the client.
//
// The request model is channelMappedModel when non-empty, else parsed.Model.
// It is mapped through account.GetMappedModel, and both names must be GPT
// image models. The body's model field is rewritten to the upstream model
// before sending.
//
// Upstream status >= 400 either becomes an *UpstreamFailoverError (when
// shouldFailoverOpenAIUpstreamResponse says so) or is delegated to
// handleErrorResponse. Success is relayed in streaming or non-streaming form
// and summarized in the returned OpenAIForwardResult.
func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	requestModel := strings.TrimSpace(parsed.Model)
	if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
		requestModel = mapped
	}
	if err := validateOpenAIImagesModel(requestModel); err != nil {
		return nil, err
	}
	upstreamModel := account.GetMappedModel(requestModel)
	if err := validateOpenAIImagesModel(upstreamModel); err != nil {
		return nil, err
	}
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s",
		strings.TrimSpace(parsed.Model),
		upstreamModel,
		parsed.Endpoint,
		account.Type,
	)
	forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
	if err != nil {
		return nil, err
	}
	// Only JSON bodies are recorded for ops; multipart bodies are skipped
	// (presumably because they carry binary image uploads).
	if !parsed.Multipart {
		setOpsUpstreamRequestBody(c, forwardBody)
	}

	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
	if err != nil {
		return nil, err
	}

	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	upstreamStart := time.Now()
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
	if err != nil {
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
			Kind:               "request_error",
			Message:            safeErr,
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	if resp.StatusCode >= 400 {
		// Buffer up to 2MiB of the error body so it can be inspected here and
		// re-read by the failover/error handlers.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
				Kind:               "failover",
				Message:            upstreamMsg,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
	}
	defer func() { _ = resp.Body.Close() }()

	var usage OpenAIUsage
	// Default to the requested n; replaced only by a positive observed count.
	imageCount := parsed.N
	var firstTokenMs *int
	if parsed.Stream {
		streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
		if err != nil {
			return nil, err
		}
		usage = streamUsage
		// The stream handler only counts payloads with a top-level "data"
		// array and reports 0 otherwise. Keep the requested n in that case,
		// as the non-streaming branch does, instead of reporting 0 images.
		if streamCount > 0 {
			imageCount = streamCount
		}
		firstTokenMs = ttft
	} else {
		nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c)
		if err != nil {
			return nil, err
		}
		usage = nonStreamUsage
		if nonStreamCount > 0 {
			imageCount = nonStreamCount
		}
	}
	return &OpenAIForwardResult{
		RequestID:       resp.Header.Get("x-request-id"),
		Usage:           usage,
		Model:           requestModel,
		UpstreamModel:   upstreamModel,
		Stream:          parsed.Stream,
		ResponseHeaders: resp.Header.Clone(),
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
		ImageCount:      imageCount,
		ImageSize:       parsed.SizeTier,
	}, nil
}
|
|
|
|
func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
|
|
ctx context.Context,
|
|
c *gin.Context,
|
|
account *Account,
|
|
body []byte,
|
|
contentType string,
|
|
token string,
|
|
endpoint string,
|
|
) (*http.Request, error) {
|
|
targetURL := openAIImagesGenerationsURL
|
|
if endpoint == openAIImagesEditsEndpoint {
|
|
targetURL = openAIImagesEditsURL
|
|
}
|
|
baseURL := account.GetOpenAIBaseURL()
|
|
if baseURL != "" {
|
|
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
targetURL = buildOpenAIImagesURL(validatedURL, endpoint)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
for key, values := range c.Request.Header {
|
|
if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] {
|
|
continue
|
|
}
|
|
for _, value := range values {
|
|
req.Header.Add(key, value)
|
|
}
|
|
}
|
|
customUA := account.GetOpenAIUserAgent()
|
|
if customUA != "" {
|
|
req.Header.Set("User-Agent", customUA)
|
|
}
|
|
if strings.TrimSpace(contentType) != "" {
|
|
req.Header.Set("Content-Type", contentType)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
// buildOpenAIImagesURL joins a custom base URL with an images endpoint path
// ("/v1/images/...").
//
// A base that already ends with the full endpoint, or with the endpoint minus
// its "/v1" prefix, is returned unchanged (after trimming trailing slashes).
// A base ending in "/v1" gets only the un-prefixed part. Any other base gets
// the full endpoint appended.
func buildOpenAIImagesURL(base string, endpoint string) string {
	root := strings.TrimRight(strings.TrimSpace(base), "/")
	suffix := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1")
	switch {
	case strings.HasSuffix(root, endpoint), strings.HasSuffix(root, suffix):
		return root
	case strings.HasSuffix(root, "/v1"):
		return root + suffix
	default:
		return root + endpoint
	}
}
|
|
|
|
// rewriteOpenAIImagesModel sets the request body's model to model (trimmed)
// and returns the new body and the content type to send with it.
//
// A blank model leaves the body untouched. Multipart bodies are re-encoded by
// rewriteOpenAIImagesMultipartModel, which yields a new boundary and content
// type. Anything else is treated as JSON and patched with sjson, keeping
// contentType.
func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {
	model = strings.TrimSpace(model)
	if model == "" {
		return body, contentType, nil
	}
	if mediaType, _, err := mime.ParseMediaType(contentType); err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
		return rewriteOpenAIImagesMultipartModel(body, contentType, model)
	}
	patched, err := sjson.SetBytes(body, "model", model)
	if err != nil {
		return nil, "", fmt.Errorf("rewrite image request model: %w", err)
	}
	return patched, contentType, nil
}
|
|
|
|
// rewriteOpenAIImagesMultipartModel re-encodes a multipart body, replacing the
// value of every non-file "model" field with model. A model field is appended
// when the body had none. All other parts are copied verbatim under cloned
// headers.
//
// The output uses a fresh boundary, so callers must send the returned content
// type rather than the original one.
func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return nil, "", fmt.Errorf("parse multipart content-type: %w", err)
	}
	boundary := strings.TrimSpace(params["boundary"])
	if boundary == "" {
		return nil, "", fmt.Errorf("multipart boundary is required")
	}

	src := multipart.NewReader(bytes.NewReader(body), boundary)
	var out bytes.Buffer
	dst := multipart.NewWriter(&out)
	sawModel := false

	for {
		part, nextErr := src.NextPart()
		if nextErr == io.EOF {
			break
		}
		if nextErr != nil {
			return nil, "", fmt.Errorf("read multipart body: %w", nextErr)
		}

		fieldName := strings.TrimSpace(part.FormName())
		target, createErr := dst.CreatePart(cloneMultipartHeader(part.Header))
		if createErr != nil {
			_ = part.Close()
			return nil, "", fmt.Errorf("create multipart part: %w", createErr)
		}

		if fieldName == "model" && part.FileName() == "" {
			_, writeErr := target.Write([]byte(model))
			_ = part.Close()
			if writeErr != nil {
				return nil, "", fmt.Errorf("rewrite multipart model: %w", writeErr)
			}
			sawModel = true
			continue
		}

		_, copyErr := io.Copy(target, part)
		_ = part.Close()
		if copyErr != nil {
			return nil, "", fmt.Errorf("copy multipart part: %w", copyErr)
		}
	}

	if !sawModel {
		if err := dst.WriteField("model", model); err != nil {
			return nil, "", fmt.Errorf("append multipart model field: %w", err)
		}
	}
	if err := dst.Close(); err != nil {
		return nil, "", fmt.Errorf("finalize multipart body: %w", err)
	}
	return out.Bytes(), dst.FormDataContentType(), nil
}
|
|
|
|
// cloneMultipartHeader deep-copies a MIME header, so the result shares no
// value slices with src. Every key maps to a non-nil slice, including keys
// whose source slice was empty.
func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	out := make(textproto.MIMEHeader, len(src))
	for key, vals := range src {
		out[key] = append(make([]string, 0, len(vals)), vals...)
	}
	return out
}
|
|
|
|
// handleOpenAIImagesNonStreamingResponse reads the upstream body via
// ReadUpstreamResponseBody and relays it to the client with the upstream
// status and filtered headers. It returns the usage parsed from the body and
// the length of its "data" array (0 if absent).
//
// The response is sent as application/json unless response-header filtering
// is disabled in config and upstream supplied its own Content-Type.
func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) {
	payload, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if err != nil {
		return OpenAIUsage{}, 0, err
	}
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)

	respType := "application/json"
	echoUpstreamType := s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled
	if echoUpstreamType {
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			respType = upstreamType
		}
	}
	c.Data(resp.StatusCode, respType, payload)

	usage, _ := extractOpenAIUsageFromJSONBytes(payload)
	count := extractOpenAIImageCountFromJSONBytes(payload)
	return usage, count, nil
}
|
|
|
|
// handleOpenAIImagesStreamingResponse relays an upstream SSE images stream to
// the client line by line, flushing after every line.
//
// Each non-empty, non-"[DONE]" data payload is merged into the usage, and the
// largest "data" array length seen across payloads becomes the image count.
// Returns the usage, the image count (0 when no payload had a "data" array),
// and milliseconds from startTime to the first received line. A write or read
// error aborts the relay and discards the usage and count gathered so far.
func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	startTime time.Time,
) (OpenAIUsage, int, *int, error) {
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	streamType := strings.TrimSpace(resp.Header.Get("Content-Type"))
	if streamType == "" {
		streamType = "text/event-stream"
	}
	c.Status(resp.StatusCode)
	c.Header("Content-Type", streamType)

	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
	}

	lines := bufio.NewReader(resp.Body)
	var (
		usage        OpenAIUsage
		imageCount   int
		firstTokenMs *int
	)
	for {
		line, readErr := lines.ReadBytes('\n')
		if len(line) > 0 {
			if firstTokenMs == nil {
				elapsed := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &elapsed
			}
			if _, writeErr := c.Writer.Write(line); writeErr != nil {
				return OpenAIUsage{}, 0, firstTokenMs, writeErr
			}
			flusher.Flush()

			payload, isData := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n"))
			if isData && payload != "" && payload != "[DONE]" {
				raw := []byte(payload)
				mergeOpenAIUsage(&usage, raw)
				if seen := extractOpenAIImageCountFromJSONBytes(raw); seen > imageCount {
					imageCount = seen
				}
			}
		}
		// A final line without a trailing newline arrives together with EOF,
		// so the line is handled before the error is inspected.
		if readErr == io.EOF {
			return usage, imageCount, firstTokenMs, nil
		}
		if readErr != nil {
			return OpenAIUsage{}, 0, firstTokenMs, readErr
		}
	}
}
|
|
|
|
func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
|
|
if dst == nil {
|
|
return
|
|
}
|
|
if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok {
|
|
if parsed.InputTokens > 0 {
|
|
dst.InputTokens = parsed.InputTokens
|
|
}
|
|
if parsed.OutputTokens > 0 {
|
|
dst.OutputTokens = parsed.OutputTokens
|
|
}
|
|
if parsed.CacheReadInputTokens > 0 {
|
|
dst.CacheReadInputTokens = parsed.CacheReadInputTokens
|
|
}
|
|
if parsed.ImageOutputTokens > 0 {
|
|
dst.ImageOutputTokens = parsed.ImageOutputTokens
|
|
}
|
|
}
|
|
}
|
|
|
|
func extractOpenAIImageCountFromJSONBytes(body []byte) int {
|
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
|
return 0
|
|
}
|
|
data := gjson.GetBytes(body, "data")
|
|
if data.Exists() && data.IsArray() {
|
|
return len(data.Array())
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// openAIImagePointerInfo describes one generated image asset found in an
// upstream payload, through whichever references were available for it.
//
// Pointer is an asset pointer (collectOpenAIImagePointers extracts
// "file-service://" and "sediment://" ones). DownloadURL is a direct URL.
// B64JSON is inline base64, optionally as a data: URL. MimeType is the
// declared content type. Prompt is the prompt text associated with the asset.
type openAIImagePointerInfo struct {
	Pointer, DownloadURL, B64JSON, MimeType, Prompt string
}
|
|
|
|
// collectOpenAIImagePointers gathers every image reference in an upstream
// body. Asset pointers found anywhere in the raw text are combined with
// inline assets found by walking the decoded JSON, and the two sets are merged.
//
// Pointer matches are tagged with the first non-blank prompt among
// message.metadata.dalle.prompt, metadata.dalle.prompt and revised_prompt.
// That prompt is also the fallback for inline assets. An empty body yields nil.
func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
	if len(body) == 0 {
		return nil
	}
	prompt := ""
	for _, path := range [...]string{
		"message.metadata.dalle.prompt",
		"metadata.dalle.prompt",
		"revised_prompt",
	} {
		prompt = strings.TrimSpace(gjson.GetBytes(body, path).String())
		if prompt != "" {
			break
		}
	}

	pointers := openAIImagePointerMatches(body)
	found := make([]openAIImagePointerInfo, len(pointers))
	for i, pointer := range pointers {
		found[i] = openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}
	}
	return mergeOpenAIImagePointerInfos(found, collectOpenAIImageInlineAssets(body, prompt))
}
|
|
|
|
// openAIImagePointerMatches scans the raw body text for "file-service://" and
// "sediment://" asset pointers. Each pointer runs until the first byte that is
// not an ASCII letter, digit, '-' or '_'. All file-service matches come first,
// then the sediment matches, and the result is passed through dedupeStrings.
func openAIImagePointerMatches(body []byte) []string {
	text := string(body)
	isIDByte := func(b byte) bool {
		return b == '-' || b == '_' ||
			('0' <= b && b <= '9') ||
			('a' <= b && b <= 'z') ||
			('A' <= b && b <= 'Z')
	}

	found := make([]string, 0, 4)
	for _, scheme := range [...]string{"file-service://", "sediment://"} {
		for offset := 0; ; {
			rel := strings.Index(text[offset:], scheme)
			if rel < 0 {
				break
			}
			begin := offset + rel
			stop := begin + len(scheme)
			for stop < len(text) && isIDByte(text[stop]) {
				stop++
			}
			found = append(found, text[begin:stop])
			offset = stop
		}
	}
	return dedupeStrings(found)
}
|
|
|
|
func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo {
|
|
if len(next) == 0 {
|
|
return existing
|
|
}
|
|
seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
|
|
out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
|
|
for _, item := range existing {
|
|
if key := item.identityKey(); key != "" {
|
|
seen[key] = item
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
for _, item := range next {
|
|
key := item.identityKey()
|
|
if key == "" {
|
|
continue
|
|
}
|
|
if existingItem, ok := seen[key]; ok {
|
|
merged := mergeOpenAIImagePointerInfo(existingItem, item)
|
|
if merged != existingItem {
|
|
for i := range out {
|
|
if out[i].identityKey() == key {
|
|
out[i] = merged
|
|
break
|
|
}
|
|
}
|
|
seen[key] = merged
|
|
}
|
|
continue
|
|
}
|
|
seen[key] = item
|
|
out = append(out, item)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// identityKey returns the key used to de-duplicate image assets, or "" when
// the asset has no usable reference. The first non-blank reference decides
// the key, in this order: asset pointer, download URL, inline base64.
//
// Inline base64 is keyed by a SHA-256 digest of the whole trimmed payload.
// A fixed-length prefix is not enough: encoded images of the same format and
// dimensions share long identical headers (e.g. PNG signature + IHDR chunk,
// JPEG markers). Prefix keys would therefore merge distinct images and drop
// all but one of them.
func (i openAIImagePointerInfo) identityKey() string {
	if pointer := strings.TrimSpace(i.Pointer); pointer != "" {
		return "pointer:" + pointer
	}
	if downloadURL := strings.TrimSpace(i.DownloadURL); downloadURL != "" {
		return "download:" + downloadURL
	}
	if b64 := strings.TrimSpace(i.B64JSON); b64 != "" {
		sum := sha256.Sum256([]byte(b64))
		return "b64:" + hex.EncodeToString(sum[:])
	}
	return ""
}
|
|
|
|
// mergeOpenAIImagePointerInfo returns existing with each blank field
// (whitespace-only counts as blank) filled from next. Non-blank fields of
// existing always win.
func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo {
	fillBlank := func(dst *string, candidate string) {
		if strings.TrimSpace(*dst) == "" {
			*dst = candidate
		}
	}
	out := existing
	fillBlank(&out.Pointer, next.Pointer)
	fillBlank(&out.DownloadURL, next.DownloadURL)
	fillBlank(&out.B64JSON, next.B64JSON)
	fillBlank(&out.MimeType, next.MimeType)
	fillBlank(&out.Prompt, next.Prompt)
	return out
}
|
|
|
|
// resolveOpenAIImageBytes obtains the raw bytes of an image asset, using the
// cheapest available source in this order:
//  1. inline base64 that normalizeOpenAIImageBase64 accepts is decoded locally;
//  2. a direct download URL is fetched;
//  3. an asset pointer is resolved to a download URL, which is then fetched.
//
// Returns an error when the asset has none of these references, or when
// decoding, resolution or download fails.
func resolveOpenAIImageBytes(
	ctx context.Context,
	client *req.Client,
	headers http.Header,
	conversationID string,
	pointer openAIImagePointerInfo,
) ([]byte, error) {
	if b64 := normalizeOpenAIImageBase64(pointer.B64JSON); b64 != "" {
		return base64.StdEncoding.DecodeString(b64)
	}
	downloadURL := strings.TrimSpace(pointer.DownloadURL)
	if downloadURL == "" {
		if strings.TrimSpace(pointer.Pointer) == "" {
			return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data")
		}
		resolved, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
		if err != nil {
			return nil, err
		}
		downloadURL = resolved
	}
	return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
}
|
|
|
|
// normalizeOpenAIImageBase64 turns a raw base64 payload into a canonical,
// correctly padded standard-base64 string. The payload may optionally be
// wrapped in a "data:<mime>;base64," URL. It returns "" when the input is
// blank or does not decode as standard base64.
//
// Padding is recomputed from the length of the payload after any existing
// "=" characters are stripped, so both padded and unpadded inputs are
// accepted. CR/LF line breaks (e.g. from wrapped output) are removed first:
// the decoder ignores them, but they would skew the padding calculation.
func normalizeOpenAIImageBase64(raw string) string {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return ""
	}
	if strings.HasPrefix(strings.ToLower(raw), "data:") {
		// Keep only what follows the first comma of a data URL. A data URL with
		// nothing after the comma is left untouched and fails validation below.
		if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) {
			raw = raw[idx+1:]
		}
	}
	raw = strings.TrimSpace(raw)
	if strings.ContainsAny(raw, "\r\n") {
		raw = strings.NewReplacer("\r", "", "\n", "").Replace(raw)
	}

	// Strip existing padding, then re-pad from the unpadded length. Computing
	// the pad from the padded length would drop valid padding and make
	// correctly padded input fail to decode.
	payload := strings.TrimRight(raw, "=")
	if payload == "" {
		return ""
	}
	padded := payload + strings.Repeat("=", (4-len(payload)%4)%4)
	if _, err := base64.StdEncoding.DecodeString(padded); err != nil {
		return ""
	}
	return padded
}
|
|
|
|
func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo {
|
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
|
return nil
|
|
}
|
|
var decoded any
|
|
if err := json.Unmarshal(body, &decoded); err != nil {
|
|
return nil
|
|
}
|
|
var out []openAIImagePointerInfo
|
|
walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out)
|
|
return out
|
|
}
|
|
|
|
func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) {
|
|
switch value := node.(type) {
|
|
case map[string]any:
|
|
localPrompt := prompt
|
|
for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} {
|
|
if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" {
|
|
localPrompt = strings.TrimSpace(v)
|
|
break
|
|
}
|
|
}
|
|
item := openAIImagePointerInfo{
|
|
Prompt: localPrompt,
|
|
Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]),
|
|
DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]),
|
|
B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]),
|
|
MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]),
|
|
}
|
|
switch {
|
|
case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"),
|
|
strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"),
|
|
isLikelyOpenAIImageDownloadURL(item.DownloadURL),
|
|
normalizeOpenAIImageBase64(item.B64JSON) != "":
|
|
*out = append(*out, item)
|
|
}
|
|
for _, child := range value {
|
|
walkOpenAIImageInlineAssets(child, localPrompt, out)
|
|
}
|
|
case []any:
|
|
for _, child := range value {
|
|
walkOpenAIImageInlineAssets(child, prompt, out)
|
|
}
|
|
}
|
|
}
|
|
|
|
// firstNonEmptyString returns the first argument that is a string with
// non-whitespace content, trimmed of surrounding whitespace. Non-string
// arguments (including nil) are skipped. It returns "" if nothing qualifies.
func firstNonEmptyString(values ...any) string {
	for i := range values {
		s, isString := values[i].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
// isLikelyOpenAIImageDownloadURL reports whether raw (trimmed, compared
// case-insensitively) looks like an image location. It accepts:
//   - any "data:image/" URL, or
//   - an http(s) URL whose text contains "/download" or one of the extensions
//     .png, .jpg, .jpeg or .webp.
//
// This is a substring heuristic: the marker may appear anywhere in the URL,
// including its query string.
func isLikelyOpenAIImageDownloadURL(raw string) bool {
	lower := strings.ToLower(strings.TrimSpace(raw))
	switch {
	case lower == "":
		return false
	case strings.HasPrefix(lower, "data:image/"):
		return true
	case !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://"):
		return false
	}
	for _, marker := range [...]string{"/download", ".png", ".jpg", ".jpeg", ".webp"} {
		if strings.Contains(lower, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
func fetchOpenAIImageDownloadURL(
|
|
ctx context.Context,
|
|
client *req.Client,
|
|
headers http.Header,
|
|
conversationID string,
|
|
pointer string,
|
|
) (string, error) {
|
|
url := ""
|
|
allowConversationRetry := false
|
|
switch {
|
|
case strings.HasPrefix(pointer, "file-service://"):
|
|
fileID := strings.TrimPrefix(pointer, "file-service://")
|
|
url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID)
|
|
case strings.HasPrefix(pointer, "sediment://"):
|
|
attachmentID := strings.TrimPrefix(pointer, "sediment://")
|
|
url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID)
|
|
allowConversationRetry = true
|
|
default:
|
|
return "", fmt.Errorf("unsupported image pointer: %s", pointer)
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < 8; attempt++ {
|
|
var result struct {
|
|
DownloadURL string `json:"download_url"`
|
|
}
|
|
resp, err := client.R().
|
|
SetContext(ctx).
|
|
SetHeaders(headerToMap(headers)).
|
|
SetSuccessResult(&result).
|
|
Get(url)
|
|
if err != nil {
|
|
lastErr = err
|
|
} else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" {
|
|
return strings.TrimSpace(result.DownloadURL), nil
|
|
} else {
|
|
statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed")
|
|
if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) {
|
|
return "", statusErr
|
|
}
|
|
lastErr = statusErr
|
|
}
|
|
if attempt == 7 {
|
|
break
|
|
}
|
|
timer := time.NewTimer(750 * time.Millisecond)
|
|
select {
|
|
case <-ctx.Done():
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
return "", ctx.Err()
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
if lastErr == nil {
|
|
lastErr = fmt.Errorf("fetch image download url failed")
|
|
}
|
|
return "", lastErr
|
|
}
|
|
|
|
func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) {
|
|
request := client.R().
|
|
SetContext(ctx).
|
|
DisableAutoReadResponse()
|
|
|
|
if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) {
|
|
downloadHeaders := cloneHTTPHeader(headers)
|
|
downloadHeaders.Set("Accept", "image/*,*/*;q=0.8")
|
|
downloadHeaders.Del("Content-Type")
|
|
request.SetHeaders(headerToMap(downloadHeaders))
|
|
} else {
|
|
userAgent := strings.TrimSpace(headers.Get("User-Agent"))
|
|
if userAgent == "" {
|
|
userAgent = openAIImageBackendUserAgent
|
|
}
|
|
request.SetHeader("User-Agent", userAgent)
|
|
}
|
|
|
|
resp, err := request.Get(downloadURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
}()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, newOpenAIImageStatusError(resp, "download image bytes failed")
|
|
}
|
|
return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes))
|
|
}
|
|
|
|
// openAIImageStatusError describes a failed request to the image backend and
// keeps the response details so callers can inspect or relay them.
type openAIImageStatusError struct {
	// StatusCode is the HTTP status of the response; Error treats values <= 0
	// as unknown.
	StatusCode int
	// Message is the human-readable text returned by Error when non-empty.
	Message string
	// ResponseBody holds the raw response body (possibly a truncated prefix).
	ResponseBody []byte
	// ResponseHeaders is a copy of the response headers.
	ResponseHeaders http.Header
	// RequestID is the upstream x-request-id header value, if any.
	RequestID string
	// URL is the request URL that produced the error.
	URL string
}

// Error implements the error interface. It prefers the explicit Message and
// falls back to a generic text, adding the status code when one is known. It
// is safe to call on a nil receiver.
func (e *openAIImageStatusError) Error() string {
	const generic = "openai image backend request failed"
	switch {
	case e == nil:
		return generic
	case e.Message != "":
		return e.Message
	case e.StatusCode > 0:
		return fmt.Sprintf("%s: status %d", generic, e.StatusCode)
	default:
		return generic
	}
}
|
|
|
|
func newOpenAIImageStatusError(resp *req.Response, fallback string) error {
|
|
if resp == nil {
|
|
if strings.TrimSpace(fallback) == "" {
|
|
fallback = "openai image backend request failed"
|
|
}
|
|
return fmt.Errorf("%s", fallback)
|
|
}
|
|
|
|
statusCode := resp.StatusCode
|
|
headers := http.Header(nil)
|
|
requestID := ""
|
|
requestURL := ""
|
|
body := []byte(nil)
|
|
|
|
if resp.Response != nil {
|
|
headers = resp.Header.Clone()
|
|
requestID = strings.TrimSpace(resp.Header.Get("x-request-id"))
|
|
if resp.Request != nil && resp.Request.URL != nil {
|
|
requestURL = resp.Request.URL.String()
|
|
}
|
|
if resp.Body != nil {
|
|
body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
_ = resp.Body.Close()
|
|
}
|
|
}
|
|
|
|
message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body))
|
|
if message == "" {
|
|
prefix := strings.TrimSpace(fallback)
|
|
if prefix == "" {
|
|
prefix = "openai image backend request failed"
|
|
}
|
|
message = fmt.Sprintf("%s: status %d", prefix, statusCode)
|
|
}
|
|
|
|
return &openAIImageStatusError{
|
|
StatusCode: statusCode,
|
|
Message: message,
|
|
ResponseBody: body,
|
|
ResponseHeaders: headers,
|
|
RequestID: requestID,
|
|
URL: requestURL,
|
|
}
|
|
}
|
|
|
|
func isOpenAIImageTransientConversationNotFoundError(err error) bool {
|
|
statusErr, ok := err.(*openAIImageStatusError)
|
|
if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound {
|
|
return false
|
|
}
|
|
msg := strings.ToLower(strings.TrimSpace(statusErr.Message))
|
|
if strings.Contains(msg, "conversation_not_found") {
|
|
return true
|
|
}
|
|
if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") {
|
|
return true
|
|
}
|
|
bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody)))
|
|
if strings.Contains(bodyMsg, "conversation_not_found") {
|
|
return true
|
|
}
|
|
return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found")
|
|
}
|
|
|
|
// cloneHTTPHeader returns a deep copy of src: every value slice is copied, so
// mutating the result never affects src. Unlike http.Header.Clone, the result
// is never nil (a nil src yields an empty, writable header), and keys with
// empty value lists map to empty non-nil slices.
func cloneHTTPHeader(src http.Header) http.Header {
	copied := make(http.Header, len(src))
	for name, vals := range src {
		copied[name] = append(make([]string, 0, len(vals)), vals...)
	}
	return copied
}
|
|
|
|
// headerToMap flattens header into a single-valued map, keeping only the first
// value for each key. Keys whose value list is empty are omitted. An empty
// header yields nil.
func headerToMap(header http.Header) map[string]string {
	if len(header) == 0 {
		return nil
	}
	flat := make(map[string]string, len(header))
	for name, vals := range header {
		if len(vals) > 0 {
			flat[name] = vals[0]
		}
	}
	return flat
}
|
|
|
|
// dedupeStrings returns values with duplicates removed. The first occurrence
// of each string is kept, and relative order is preserved. Empty input
// (nil or zero-length) yields nil.
func dedupeStrings(values []string) []string {
	if len(values) == 0 {
		return nil
	}
	kept := make([]string, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for _, v := range values {
		if _, dup := seen[v]; !dup {
			seen[v] = struct{}{}
			kept = append(kept, v)
		}
	}
	return kept
}
|