Merge pull request #1813 from meteor041/meteor041/fix-openai-image-handling
fix: openai image request handling
This commit is contained in:
@@ -187,9 +187,13 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeCodexModel(model string) string {
|
func normalizeCodexModel(model string) string {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
return "gpt-5.4"
|
return "gpt-5.4"
|
||||||
}
|
}
|
||||||
|
if isOpenAIImageGenerationModel(model) {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
modelID := model
|
modelID := model
|
||||||
if strings.Contains(modelID, "/") {
|
if strings.Contains(modelID, "/") {
|
||||||
@@ -231,6 +235,78 @@ func normalizeCodexModel(model string) string {
|
|||||||
return "gpt-5.4"
|
return "gpt-5.4"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
|
||||||
|
rawTools, ok := reqBody["tools"]
|
||||||
|
if !ok || rawTools == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tools, ok := rawTools.([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, rawTool := range tools {
|
||||||
|
toolMap, ok := rawTool.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
|
||||||
|
rawTools, ok := reqBody["tools"]
|
||||||
|
if !ok || rawTools == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tools, ok := rawTools.([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := false
|
||||||
|
for _, rawTool := range tools {
|
||||||
|
toolMap, ok := rawTool.(map[string]any)
|
||||||
|
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["output_format"]; !ok {
|
||||||
|
if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" {
|
||||||
|
toolMap["output_format"] = value
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["output_compression"]; !ok {
|
||||||
|
if value, exists := toolMap["compression"]; exists && value != nil {
|
||||||
|
toolMap["output_compression"] = value
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["format"]; ok {
|
||||||
|
delete(toolMap, "format")
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
if _, ok := toolMap["compression"]; ok {
|
||||||
|
delete(toolMap, "compression")
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
|
||||||
|
if !hasOpenAIImageGenerationTool(reqBody) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if !isOpenAIImageGenerationModel(model) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
|
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
|
||||||
if account == nil || account.Type == AccountTypeOAuth {
|
if account == nil || account.Type == AccountTypeOAuth {
|
||||||
return normalizeCodexModel(model)
|
return normalizeCodexModel(model)
|
||||||
|
|||||||
@@ -217,6 +217,42 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
|||||||
require.Equal(t, "bash", first["name"])
|
require.Equal(t, "bash", first["name"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) {
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "image_generation",
|
||||||
|
"format": "png",
|
||||||
|
"compression": 60,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modified := normalizeOpenAIResponsesImageGenerationTools(reqBody)
|
||||||
|
require.True(t, modified)
|
||||||
|
|
||||||
|
tools, ok := reqBody["tools"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
first, ok := tools[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "png", first["output_format"])
|
||||||
|
require.Equal(t, 60, first["output_compression"])
|
||||||
|
_, hasFormat := first["format"]
|
||||||
|
require.False(t, hasFormat)
|
||||||
|
_, hasCompression := first["compression"]
|
||||||
|
require.False(t, hasCompression)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) {
|
||||||
|
err := validateOpenAIResponsesImageModel(map[string]any{
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{"type": "image_generation"},
|
||||||
|
},
|
||||||
|
}, "gpt-image-2")
|
||||||
|
|
||||||
|
require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`)
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||||
// 空 input 应保持为空且不触发异常。
|
// 空 input 应保持为空且不触发异常。
|
||||||
|
|
||||||
|
|||||||
@@ -1935,6 +1935,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if normalizeOpenAIResponsesImageGenerationTools(reqBody) {
|
||||||
|
bodyModified = true
|
||||||
|
disablePatch()
|
||||||
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||||
|
}
|
||||||
|
|
||||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
billingModel := account.GetMappedModel(reqModel)
|
billingModel := account.GetMappedModel(reqModel)
|
||||||
if billingModel != reqModel {
|
if billingModel != reqModel {
|
||||||
@@ -1944,6 +1950,26 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
markPatchSet("model", billingModel)
|
markPatchSet("model", billingModel)
|
||||||
}
|
}
|
||||||
upstreamModel := billingModel
|
upstreamModel := billingModel
|
||||||
|
if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"message": err.Error(),
|
||||||
|
"param": "model",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if hasOpenAIImageGenerationTool(reqBody) {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s",
|
||||||
|
reqModel,
|
||||||
|
upstreamModel,
|
||||||
|
account.Type,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
|
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
|
||||||
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
|
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ const (
|
|||||||
|
|
||||||
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||||
openAIImageRequirementsDiff = "0fffff"
|
openAIImageRequirementsDiff = "0fffff"
|
||||||
|
openAIImageLifecycleTimeout = 2 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIImagesCapability string
|
type OpenAIImagesCapability string
|
||||||
@@ -148,6 +149,9 @@ func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []b
|
|||||||
}
|
}
|
||||||
|
|
||||||
applyOpenAIImagesDefaults(req)
|
applyOpenAIImagesDefaults(req)
|
||||||
|
if err := validateOpenAIImagesModel(req.Model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
|
req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
|
||||||
req.RequiredCapability = classifyOpenAIImagesCapability(req)
|
req.RequiredCapability = classifyOpenAIImagesCapability(req)
|
||||||
return req, nil
|
return req, nil
|
||||||
@@ -295,6 +299,21 @@ func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
|
|||||||
req.Model = "gpt-image-2"
|
req.Model = "gpt-image-2"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOpenAIImageGenerationModel(model string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-")
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOpenAIImagesModel(model string) error {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if isOpenAIImageGenerationModel(model) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if model == "" {
|
||||||
|
return fmt.Errorf("images endpoint requires an image model")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("images endpoint requires an image model, got %q", model)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeOpenAIImagesEndpointPath(path string) string {
|
func normalizeOpenAIImagesEndpointPath(path string) string {
|
||||||
trimmed := strings.TrimSpace(path)
|
trimmed := strings.TrimSpace(path)
|
||||||
switch {
|
switch {
|
||||||
@@ -400,7 +419,21 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
|||||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||||
requestModel = mapped
|
requestModel = mapped
|
||||||
}
|
}
|
||||||
|
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
upstreamModel := account.GetMappedModel(requestModel)
|
upstreamModel := account.GetMappedModel(requestModel)
|
||||||
|
if err := validateOpenAIImagesModel(upstreamModel); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s",
|
||||||
|
strings.TrimSpace(parsed.Model),
|
||||||
|
upstreamModel,
|
||||||
|
parsed.Endpoint,
|
||||||
|
account.Type,
|
||||||
|
)
|
||||||
forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
|
forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -759,6 +792,17 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||||
requestModel = mapped
|
requestModel = mapped
|
||||||
}
|
}
|
||||||
|
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
|
||||||
|
requestModel,
|
||||||
|
parsed.Endpoint,
|
||||||
|
account.Type,
|
||||||
|
len(parsed.Uploads),
|
||||||
|
)
|
||||||
|
|
||||||
token, _, err := s.GetAccessToken(ctx, account)
|
token, _, err := s.GetAccessToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -844,8 +888,18 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
|
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d",
|
||||||
|
conversationID,
|
||||||
|
len(pointerInfos),
|
||||||
|
countOpenAIFileServicePointerInfos(pointerInfos),
|
||||||
|
countOpenAIDirectImageAssets(pointerInfos),
|
||||||
|
)
|
||||||
|
lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout)
|
||||||
|
defer releaseLifecycleCtx()
|
||||||
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
|
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
|
||||||
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
|
polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID)
|
||||||
if pollErr != nil {
|
if pollErr != nil {
|
||||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
|
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
|
||||||
}
|
}
|
||||||
@@ -853,10 +907,11 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
}
|
}
|
||||||
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
|
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
|
||||||
if len(pointerInfos) == 0 {
|
if len(pointerInfos) == 0 {
|
||||||
|
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID)
|
||||||
return nil, fmt.Errorf("openai image conversation returned no downloadable images")
|
return nil, fmt.Errorf("openai image conversation returned no downloadable images")
|
||||||
}
|
}
|
||||||
|
|
||||||
responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos)
|
responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
|
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
|
||||||
}
|
}
|
||||||
@@ -1283,8 +1338,11 @@ func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMess
|
|||||||
}
|
}
|
||||||
|
|
||||||
type openAIImagePointerInfo struct {
|
type openAIImagePointerInfo struct {
|
||||||
Pointer string
|
Pointer string
|
||||||
Prompt string
|
DownloadURL string
|
||||||
|
B64JSON string
|
||||||
|
MimeType string
|
||||||
|
Prompt string
|
||||||
}
|
}
|
||||||
|
|
||||||
type openAIImageToolMessage struct {
|
type openAIImageToolMessage struct {
|
||||||
@@ -1336,10 +1394,6 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
|||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
matches := openAIImagePointerMatches(body)
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
prompt := ""
|
prompt := ""
|
||||||
for _, path := range []string{
|
for _, path := range []string{
|
||||||
"message.metadata.dalle.prompt",
|
"message.metadata.dalle.prompt",
|
||||||
@@ -1351,11 +1405,12 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
matches := openAIImagePointerMatches(body)
|
||||||
out := make([]openAIImagePointerInfo, 0, len(matches))
|
out := make([]openAIImagePointerInfo, 0, len(matches))
|
||||||
for _, pointer := range matches {
|
for _, pointer := range matches {
|
||||||
out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
|
out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
|
||||||
}
|
}
|
||||||
return out
|
return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt))
|
||||||
}
|
}
|
||||||
|
|
||||||
func openAIImagePointerMatches(body []byte) []string {
|
func openAIImagePointerMatches(body []byte) []string {
|
||||||
@@ -1394,27 +1449,72 @@ func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []open
|
|||||||
seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
|
seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
|
||||||
out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
|
out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
|
||||||
for _, item := range existing {
|
for _, item := range existing {
|
||||||
seen[item.Pointer] = item
|
if key := item.identityKey(); key != "" {
|
||||||
|
seen[key] = item
|
||||||
|
}
|
||||||
out = append(out, item)
|
out = append(out, item)
|
||||||
}
|
}
|
||||||
for _, item := range next {
|
for _, item := range next {
|
||||||
if existingItem, ok := seen[item.Pointer]; ok {
|
key := item.identityKey()
|
||||||
if existingItem.Prompt == "" && item.Prompt != "" {
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if existingItem, ok := seen[key]; ok {
|
||||||
|
merged := mergeOpenAIImagePointerInfo(existingItem, item)
|
||||||
|
if merged != existingItem {
|
||||||
for i := range out {
|
for i := range out {
|
||||||
if out[i].Pointer == item.Pointer {
|
if out[i].identityKey() == key {
|
||||||
out[i].Prompt = item.Prompt
|
out[i] = merged
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
seen[key] = merged
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[item.Pointer] = item
|
seen[key] = item
|
||||||
out = append(out, item)
|
out = append(out, item)
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i openAIImagePointerInfo) identityKey() string {
|
||||||
|
switch {
|
||||||
|
case strings.TrimSpace(i.Pointer) != "":
|
||||||
|
return "pointer:" + strings.TrimSpace(i.Pointer)
|
||||||
|
case strings.TrimSpace(i.DownloadURL) != "":
|
||||||
|
return "download:" + strings.TrimSpace(i.DownloadURL)
|
||||||
|
case strings.TrimSpace(i.B64JSON) != "":
|
||||||
|
b64 := strings.TrimSpace(i.B64JSON)
|
||||||
|
if len(b64) > 64 {
|
||||||
|
b64 = b64[:64]
|
||||||
|
}
|
||||||
|
return "b64:" + b64
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo {
|
||||||
|
merged := existing
|
||||||
|
if strings.TrimSpace(merged.Pointer) == "" {
|
||||||
|
merged.Pointer = next.Pointer
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(merged.DownloadURL) == "" {
|
||||||
|
merged.DownloadURL = next.DownloadURL
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(merged.B64JSON) == "" {
|
||||||
|
merged.B64JSON = next.B64JSON
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(merged.MimeType) == "" {
|
||||||
|
merged.MimeType = next.MimeType
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(merged.Prompt) == "" {
|
||||||
|
merged.Prompt = next.Prompt
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
if strings.HasPrefix(item.Pointer, "file-service://") {
|
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||||
@@ -1424,6 +1524,26 @@ func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int {
|
||||||
|
count := 0
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int {
|
||||||
|
count := 0
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
|
func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
|
||||||
if !hasOpenAIFileServicePointerInfos(items) {
|
if !hasOpenAIFileServicePointerInfos(items) {
|
||||||
return items
|
return items
|
||||||
@@ -1591,11 +1711,7 @@ func buildOpenAIImageResponse(
|
|||||||
}
|
}
|
||||||
items := make([]responseItem, 0, len(pointers))
|
items := make([]responseItem, 0, len(pointers))
|
||||||
for _, pointer := range pointers {
|
for _, pointer := range pointers {
|
||||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer)
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -1615,6 +1731,136 @@ func buildOpenAIImageResponse(
|
|||||||
return body, len(items), nil
|
return body, len(items), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveOpenAIImageBytes(
|
||||||
|
ctx context.Context,
|
||||||
|
client *req.Client,
|
||||||
|
headers http.Header,
|
||||||
|
conversationID string,
|
||||||
|
pointer openAIImagePointerInfo,
|
||||||
|
) ([]byte, error) {
|
||||||
|
if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" {
|
||||||
|
return base64.StdEncoding.DecodeString(normalized)
|
||||||
|
}
|
||||||
|
if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" {
|
||||||
|
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(pointer.Pointer) == "" {
|
||||||
|
return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data")
|
||||||
|
}
|
||||||
|
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIImageBase64(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||||
|
if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) {
|
||||||
|
raw = raw[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4)
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, err := base64.StdEncoding.DecodeString(raw); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var decoded any
|
||||||
|
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var out []openAIImagePointerInfo
|
||||||
|
walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) {
|
||||||
|
switch value := node.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
localPrompt := prompt
|
||||||
|
for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} {
|
||||||
|
if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" {
|
||||||
|
localPrompt = strings.TrimSpace(v)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
item := openAIImagePointerInfo{
|
||||||
|
Prompt: localPrompt,
|
||||||
|
Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]),
|
||||||
|
DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]),
|
||||||
|
B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]),
|
||||||
|
MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]),
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"),
|
||||||
|
strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"),
|
||||||
|
isLikelyOpenAIImageDownloadURL(item.DownloadURL),
|
||||||
|
normalizeOpenAIImageBase64(item.B64JSON) != "":
|
||||||
|
*out = append(*out, item)
|
||||||
|
}
|
||||||
|
for _, child := range value {
|
||||||
|
walkOpenAIImageInlineAssets(child, localPrompt, out)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, child := range value {
|
||||||
|
walkOpenAIImageInlineAssets(child, prompt, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmptyString(values ...any) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if s, ok := value.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLikelyOpenAIImageDownloadURL(raw string) bool {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(raw), "data:image/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(raw)
|
||||||
|
return strings.Contains(lower, "/download") ||
|
||||||
|
strings.Contains(lower, ".png") ||
|
||||||
|
strings.Contains(lower, ".jpg") ||
|
||||||
|
strings.Contains(lower, ".jpeg") ||
|
||||||
|
strings.Contains(lower, ".webp")
|
||||||
|
}
|
||||||
|
|
||||||
|
func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||||
|
base := context.Background()
|
||||||
|
if ctx != nil {
|
||||||
|
base = context.WithoutCancel(ctx)
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
return base, func() {}
|
||||||
|
}
|
||||||
|
return context.WithTimeout(base, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
func fetchOpenAIImageDownloadURL(
|
func fetchOpenAIImageDownloadURL(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
client *req.Client,
|
client *req.Client,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -103,3 +104,56 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNative
|
|||||||
require.NotNil(t, parsed)
|
require.NotNil(t, parsed)
|
||||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.Nil(t, parsed)
|
||||||
|
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
|
||||||
|
items := collectOpenAIImagePointers([]byte(`{
|
||||||
|
"revised_prompt": "cat astronaut",
|
||||||
|
"parts": [
|
||||||
|
{"b64_json":"QUJD"},
|
||||||
|
{"download_url":"https://files.example.com/image.png?sig=1"},
|
||||||
|
{"asset_pointer":"file-service://file_123"}
|
||||||
|
]
|
||||||
|
}`))
|
||||||
|
|
||||||
|
require.Len(t, items, 3)
|
||||||
|
var sawBase64, sawURL, sawPointer bool
|
||||||
|
for _, item := range items {
|
||||||
|
if item.B64JSON == "QUJD" {
|
||||||
|
sawBase64 = true
|
||||||
|
require.Equal(t, "cat astronaut", item.Prompt)
|
||||||
|
}
|
||||||
|
if item.DownloadURL == "https://files.example.com/image.png?sig=1" {
|
||||||
|
sawURL = true
|
||||||
|
}
|
||||||
|
if item.Pointer == "file-service://file_123" {
|
||||||
|
sawPointer = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, sawBase64)
|
||||||
|
require.True(t, sawURL)
|
||||||
|
require.True(t, sawPointer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
|
||||||
|
data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{
|
||||||
|
B64JSON: "data:image/png;base64,QUJD",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("ABC"), data)
|
||||||
|
}
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ func TestNormalizeCodexModel(t *testing.T) {
|
|||||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||||
"gpt-5.3": "gpt-5.3-codex",
|
"gpt-5.3": "gpt-5.3-codex",
|
||||||
|
"gpt-image-2": "gpt-image-2",
|
||||||
}
|
}
|
||||||
|
|
||||||
for input, expected := range cases {
|
for input, expected := range cases {
|
||||||
|
|||||||
@@ -812,6 +812,16 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
|||||||
return openAIGPT54FallbackPricing
|
return openAIGPT54FallbackPricing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isOpenAIImageGenerationModel(model) {
|
||||||
|
for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} {
|
||||||
|
if pricing, ok := s.pricingData[candidate]; ok {
|
||||||
|
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate)
|
||||||
|
return pricing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 最终回退到 DefaultTestModel
|
// 最终回退到 DefaultTestModel
|
||||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||||
|
|||||||
@@ -128,6 +128,21 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t
|
|||||||
require.Zero(t, got.LongContextInputTokenThreshold)
|
require.Zero(t, got.LongContextInputTokenThreshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) {
|
||||||
|
imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3}
|
||||||
|
textPricing := &LiteLLMModelPricing{InputCostPerToken: 9}
|
||||||
|
|
||||||
|
svc := &PricingService{
|
||||||
|
pricingData: map[string]*LiteLLMModelPricing{
|
||||||
|
"gpt-image-2": imagePricing,
|
||||||
|
"gpt-5.4": textPricing,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := svc.GetModelPricing("gpt-image-3")
|
||||||
|
require.Same(t, imagePricing, got)
|
||||||
|
}
|
||||||
|
|
||||||
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
||||||
raw := map[string]any{
|
raw := map[string]any{
|
||||||
"gpt-5.4": map[string]any{
|
"gpt-5.4": map[string]any{
|
||||||
|
|||||||
Reference in New Issue
Block a user