From 5f41899705cca0a88e6823f18522a974b22870a1 Mon Sep 17 00:00:00 2001 From: gaoren002 Date: Thu, 23 Apr 2026 15:13:57 +0000 Subject: [PATCH] fix: bridge codex image generation over responses --- backend/internal/server/routes/gateway.go | 7 + .../internal/server/routes/gateway_test.go | 7 +- .../service/openai_codex_transform.go | 136 ++++++++++++++++ .../service/openai_codex_transform_test.go | 153 ++++++++++++++++++ .../service/openai_gateway_service.go | 73 ++++++++- .../service/openai_gateway_service_test.go | 28 ++++ backend/internal/web/embed_on.go | 1 + backend/internal/web/embed_test.go | 4 + 8 files changed, 406 insertions(+), 3 deletions(-) diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 5982e1cc..9541cda1 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -140,6 +140,13 @@ func RegisterGatewayRoutes( r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + codexDirect := r.Group("/backend-api/codex") + codexDirect.Use(bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic) + { + codexDirect.POST("/responses", responsesHandler) + codexDirect.POST("/responses/*subpath", responsesHandler) + codexDirect.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) + } // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c 
*gin.Context) { if getGroupPlatform(c) == service.PlatformOpenAI { diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go index 87a77cbc..19ef5686 100644 --- a/backend/internal/server/routes/gateway_test.go +++ b/backend/internal/server/routes/gateway_test.go @@ -45,7 +45,12 @@ func newGatewayRoutesTestRouter() *gin.Engine { func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) { router := newGatewayRoutesTestRouter() - for _, path := range []string{"/v1/responses/compact", "/responses/compact"} { + for _, path := range []string{ + "/v1/responses/compact", + "/responses/compact", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", + } { req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 560db436..14abde9b 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -45,6 +45,11 @@ type codexTransformResult struct { PromptCacheKey string } +const ( + codexImageGenerationBridgeMarker = "<codex_image_generation_bridge>" + codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. 
Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n" +) + func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -300,6 +305,61 @@ func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { return modified } +func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + + tool := map[string]any{ + "type": "image_generation", + "output_format": "png", + } + + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + reqBody["tools"] = []any{tool} + return true + } + + tools, ok := rawTools.([]any) + if !ok { + reqBody["tools"] = []any{tool} + return true + } + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + return false + } + } + + reqBody["tools"] = append(tools, tool) + return true +} + +func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool { + if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) { + return false + } + + existing, _ := reqBody["instructions"].(string) + if strings.Contains(existing, codexImageGenerationBridgeMarker) { + return false + } + + existing = strings.TrimRight(existing, " \t\r\n") + if strings.TrimSpace(existing) == "" { + reqBody["instructions"] = codexImageGenerationBridgeText + return true + } + + reqBody["instructions"] = existing + "\n\n" + codexImageGenerationBridgeText + return true +} + func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { if !hasOpenAIImageGenerationTool(reqBody) { return nil @@ -311,6 +371,82 @@ func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) err return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; 
image-only model %q is not allowed", model) } +func normalizeOpenAIResponsesImageOnlyModel(reqBody map[string]any) bool { + if len(reqBody) == 0 { + return false + } + imageModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"])) + if !isOpenAIImageGenerationModel(imageModel) { + return false + } + + modified := false + tools, _ := reqBody["tools"].([]any) + imageToolIndex := -1 + for i, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + imageToolIndex = i + break + } + } + if imageToolIndex < 0 { + tools = append(tools, map[string]any{ + "type": "image_generation", + "model": imageModel, + }) + imageToolIndex = len(tools) - 1 + reqBody["tools"] = tools + modified = true + } + + if toolMap, ok := tools[imageToolIndex].(map[string]any); ok { + if strings.TrimSpace(firstNonEmptyString(toolMap["model"])) == "" { + toolMap["model"] = imageModel + modified = true + } + for _, key := range []string{ + "size", + "quality", + "background", + "output_format", + "output_compression", + "moderation", + "style", + "partial_images", + } { + if value, exists := reqBody[key]; exists && value != nil { + if _, toolHas := toolMap[key]; !toolHas { + toolMap[key] = value + } + delete(reqBody, key) + modified = true + } + } + } + + if prompt := strings.TrimSpace(firstNonEmptyString(reqBody["prompt"])); prompt != "" { + if _, hasInput := reqBody["input"]; !hasInput { + reqBody["input"] = prompt + } + delete(reqBody, "prompt") + modified = true + } + + if _, ok := reqBody["tool_choice"]; !ok { + reqBody["tool_choice"] = map[string]any{"type": "image_generation"} + modified = true + } + if imageModel != openAIImagesResponsesMainModel { + modified = true + } + reqBody["model"] = openAIImagesResponsesMainModel + return modified +} + func normalizeOpenAIModelForUpstream(account *Account, model string) string { if account == nil || account.Type == 
AccountTypeOAuth { return normalizeCodexModel(model) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index f08e4b15..4fd16fdb 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -243,6 +243,159 @@ func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *te require.False(t, hasCompression) } +func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "input": "draw a cat", + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", tool["type"]) + require.Equal(t, "png", tool["output_format"]) +} + +func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "web_search"}, + }, + } + + modified := ensureOpenAIResponsesImageGenerationTool(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 2) + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "web_search", first["type"]) + second, ok := tools[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", second["type"]) + require.Equal(t, "png", second["output_format"]) +} + +func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.4", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "webp"}, + map[string]any{"type": "web_search"}, + }, + } + + modified := 
ensureOpenAIResponsesImageGenerationTool(reqBody) + require.False(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 2) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "webp", tool["output_format"]) +} + +func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) { + reqBody := map[string]any{ + "instructions": "existing instructions", + "tools": []any{ + map[string]any{"type": "image_generation", "output_format": "png"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.True(t, modified) + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Contains(t, instructions, "existing instructions") + require.Contains(t, instructions, codexImageGenerationBridgeMarker) + require.Contains(t, instructions, "Responses native `image_generation` tool") + + modified = applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) +} + +func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) { + reqBody := map[string]any{ + "instructions": "existing instructions", + "tools": []any{ + map[string]any{"type": "web_search"}, + }, + } + + modified := applyCodexImageGenerationBridgeInstructions(reqBody) + require.False(t, modified) + require.Equal(t, "existing instructions", reqBody["instructions"]) +} + +func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-image-2", + "prompt": "draw a cat", + "size": "1024x1024", + "output_format": "png", + } + + modified := normalizeOpenAIResponsesImageOnlyModel(reqBody) + require.True(t, modified) + require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"]) + require.Equal(t, "draw a cat", reqBody["input"]) + _, hasPrompt := reqBody["prompt"] + require.False(t, hasPrompt) + _, hasTopLevelSize := reqBody["size"] + require.False(t, 
hasTopLevelSize) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", tool["type"]) + require.Equal(t, "gpt-image-2", tool["model"]) + require.Equal(t, "1024x1024", tool["size"]) + require.Equal(t, "png", tool["output_format"]) + + choice, ok := reqBody["tool_choice"].(map[string]any) + require.True(t, ok) + require.Equal(t, "image_generation", choice["type"]) +} + +func TestNormalizeOpenAIResponsesImageOnlyModel_PreservesExistingImageTool(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-image-2", + "input": "draw a cat", + "tools": []any{ + map[string]any{ + "type": "image_generation", + "model": "gpt-image-1.5", + }, + }, + "tool_choice": "auto", + } + + modified := normalizeOpenAIResponsesImageOnlyModel(reqBody) + require.True(t, modified) + require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"]) + require.Equal(t, "auto", reqBody["tool_choice"]) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + tool, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "gpt-image-1.5", tool["model"]) +} + func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) { err := validateOpenAIResponsesImageModel(map[string]any{ "tools": []any{ diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 06fd14af..b4b285c5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1935,11 +1935,22 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } + if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] 
Injected /responses image_generation tool for Codex client") + } + if normalizeOpenAIResponsesImageGenerationTools(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") } + if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") + } // 对所有请求执行模型映射(包含 Codex CLI)。 billingModel := account.GetMappedModel(reqModel) @@ -1950,6 +1961,20 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel + if normalizeOpenAIResponsesImageOnlyModel(reqBody) { + bodyModified = true + disablePatch() + if model, ok := reqBody["model"].(string); ok { + upstreamModel = strings.TrimSpace(model) + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Normalized /responses image-only model request inbound_model=%s image_model=%s upstream_model=%s", + reqModel, + billingModel, + upstreamModel, + ) + } if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil { setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") c.JSON(http.StatusBadRequest, gin.H{ @@ -4118,22 +4143,39 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { // Returns (nil, false) if no content was found in deltas. 
func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { acc := apicompat.NewBufferedResponseAccumulator() + imageOutputs := make([]json.RawMessage, 0, 1) + seenImages := make(map[string]struct{}) lines := strings.Split(bodyText, "\n") for _, line := range lines { data, ok := extractOpenAISSEDataLine(line) if !ok || data == "" || data == "[DONE]" { continue } + if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok { + imageOutputs = append(imageOutputs, imageOutput) + } var event apicompat.ResponsesStreamEvent if err := json.Unmarshal([]byte(data), &event); err != nil { continue } acc.ProcessEvent(&event) } - if !acc.HasContent() { + if !acc.HasContent() && len(imageOutputs) == 0 { return nil, false } - output := acc.BuildOutput() + + var output []json.RawMessage + if acc.HasContent() { + outputJSON, err := json.Marshal(acc.BuildOutput()) + if err != nil { + return nil, false + } + if err := json.Unmarshal(outputJSON, &output); err != nil { + return nil, false + } + } + output = append(output, imageOutputs...) 
+ outputJSON, err := json.Marshal(output) if err != nil { return nil, false @@ -4141,6 +4183,33 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { return outputJSON, true } +func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct{}) (json.RawMessage, bool) { + if len(data) == 0 || !gjson.ValidBytes(data) { + return nil, false + } + if gjson.GetBytes(data, "type").String() != "response.output_item.done" { + return nil, false + } + item := gjson.GetBytes(data, "item") + if !item.Exists() || !item.IsObject() || item.Get("type").String() != "image_generation_call" { + return nil, false + } + if strings.TrimSpace(item.Get("result").String()) == "" { + return nil, false + } + key := strings.TrimSpace(item.Get("id").String()) + if key == "" { + key = strings.TrimSpace(item.Get("output_format").String()) + "|" + strings.TrimSpace(item.Get("result").String()) + } + if key != "" && seen != nil { + if _, exists := seen[key]; exists { + return nil, false + } + seen[key] = struct{}{} + } + return json.RawMessage(item.Raw), true +} + func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index cf2d875f..ed7c78a3 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -18,6 +18,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) // 编译期接口断言 @@ -1880,6 +1881,33 @@ func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { require.NotContains(t, rec.Body.String(), "data:") } +func TestHandleSSEToJSON_ReconstructsImageGenerationOutputItemDone(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + 
c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","result":"aGVsbG8=","revised_prompt":"draw a cat","output_format":"png"}}`, + `data: {"type":"response.completed","response":{"id":"resp_img","model":"gpt-5.4","output":[],"usage":{"input_tokens":7,"output_tokens":9,"output_tokens_details":{"image_tokens":4}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-5.4", "gpt-5.4") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 4, usage.ImageOutputTokens) + require.NotContains(t, rec.Body.String(), "data:") + require.Equal(t, "image_generation_call", gjson.Get(rec.Body.String(), "output.0.type").String()) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "output.0.result").String()) + require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "output.0.revised_prompt").String()) +} + func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 5f3719be..2279d913 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -301,6 +301,7 @@ func shouldBypassEmbeddedFrontend(path string) bool { return strings.HasPrefix(trimmed, "/api/") || strings.HasPrefix(trimmed, "/v1/") || strings.HasPrefix(trimmed, "/v1beta/") || + strings.HasPrefix(trimmed, "/backend-api/") || strings.HasPrefix(trimmed, "/antigravity/") || strings.HasPrefix(trimmed, "/setup/") || trimmed == "/health" || diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index 4127a7a6..583d98a0 100644 --- 
a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -434,6 +434,8 @@ func TestFrontendServer_Middleware(t *testing.T) { "/api/v1/users", "/v1/models", "/v1beta/chat", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", "/antigravity/test", "/setup/init", "/health", @@ -636,6 +638,8 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/api/users", "/v1/models", "/v1beta/chat", + "/backend-api/codex/responses", + "/backend-api/codex/responses/compact", "/antigravity/test", "/setup/init", "/health",