From 9e5a6351fce07d82f82766cf1d58a488c1a959b1 Mon Sep 17 00:00:00 2001 From: wx-11 <168356742+wx-11@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:09:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=AE=A1=E8=B4=B9=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E4=BB=A5=E5=8F=8A=E6=A8=A1=E5=9E=8B=E5=9B=9E=E6=98=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../openai_gateway_record_usage_test.go | 47 +++ .../service/openai_gateway_service.go | 9 +- .../service/openai_images_responses.go | 271 ++++++++++++++---- .../internal/service/openai_images_test.go | 157 ++++++++-- 4 files changed, 397 insertions(+), 87 deletions(-) diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 95e1bffa..9665c4c8 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. require.NotNil(t, usageRepo.lastLog.BillingMode) require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) } + +func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) { + imagePrice := 0.02 + groupID := int64(12) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_per_request", + Model: "gpt-image-2", + Usage: OpenAIUsage{ + InputTokens: 1110, + OutputTokens: 1756, + ImageOutputTokens: 1756, + }, + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 2008}, + Account: &Account{ID: 3008}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6c661c67..1a462a3b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4625,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( serviceTier string, ) (*CostBreakdown, error) { if result != nil && result.ImageCount > 0 { - if hasOpenAIImageUsageTokens(result) { - cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize) - if err == nil { - return cost, nil - } - } return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil } if s.resolver != nil && apiKey.Group != nil { @@ -4682,7 +4676,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( result *OpenAIForwardResult, multiplier float64, ) *CostBreakdown { - if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && + (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { gid := apiKey.Group.ID cost, err := s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index d3aa2a31..99b5ca6e 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -25,6 +25,7 @@ type openAIResponsesImageResult struct { Size string Background string Quality string + Model string } func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string { @@ -49,6 +50,126 @@ func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult return true } +func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) { + if dst == nil { + return + } + if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" { + dst.OutputFormat = trimmed + } + if trimmed := strings.TrimSpace(src.Size); trimmed != "" { + dst.Size = trimmed + } + if trimmed := strings.TrimSpace(src.Background); trimmed != "" { + dst.Background = trimmed + } + if trimmed := strings.TrimSpace(src.Quality); trimmed != "" { + dst.Quality = trimmed + } + if trimmed := strings.TrimSpace(src.Model); trimmed != "" { + dst.Model = trimmed + } +} + +func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) { + switch gjson.GetBytes(payload, "type").String() { + case "response.created", "response.in_progress", "response.completed": + default: + return openAIResponsesImageResult{}, 0, false + } + + response := gjson.GetBytes(payload, "response") + if !response.Exists() { + return openAIResponsesImageResult{}, 0, false + } + + meta := openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()), + Size: strings.TrimSpace(response.Get("tools.0.size").String()), + Background: strings.TrimSpace(response.Get("tools.0.background").String()), + Quality: strings.TrimSpace(response.Get("tools.0.quality").String()), + Model: strings.TrimSpace(response.Get("tools.0.model").String()), + } + return meta, response.Get("created_at").Int(), true +} + +func buildOpenAIImagesStreamPartialPayload( + eventType string, + b64 string, + partialImageIndex int64, + responseFormat string, + createdAt int64, + meta openAIResponsesImageResult, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex) + payload, _ = sjson.SetBytes(payload, "b64_json", b64) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64) + } + if meta.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", meta.Background) + } + if meta.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat) + } + if meta.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", meta.Quality) + } + if meta.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", meta.Size) + } + if meta.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", meta.Model) + } + return payload +} + +func buildOpenAIImagesStreamCompletedPayload( + eventType string, + img openAIResponsesImageResult, + responseFormat string, + createdAt int64, + usageRaw []byte, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "b64_json", img.Result) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } + if img.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", img.Background) + } + if img.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat) + } + if img.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", img.Quality) + } + if img.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", img.Size) + } + if img.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", img.Model) + } + if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw) + } + return payload +} + func openAIImageOutputMIMEType(outputFormat string) string { if outputFormat == "" { return "image/png" @@ -134,16 +255,12 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st {path: "background", value: parsed.Background}, {path: "output_format", value: parsed.OutputFormat}, {path: "moderation", value: parsed.Moderation}, - {path: "input_fidelity", value: parsed.InputFidelity}, {path: "style", value: parsed.Style}, } { if trimmed := strings.TrimSpace(field.value); trimmed != "" { tool, _ = sjson.SetBytes(tool, field.path, trimmed) } } - if parsed.N > 1 { - return nil, fmt.Errorf("codex /responses image tool currently supports only n=1") - } if parsed.OutputCompression != nil { tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression) } @@ -247,6 +364,7 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe createdAt int64 usageRaw []byte foundFinal bool + responseMeta openAIResponsesImageResult ) for _, line := range bytes.Split(body, []byte("\n")) { @@ -259,18 +377,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe if !gjson.ValidBytes(payload) { continue } + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { + mergeOpenAIResponsesImageMeta(&responseMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } switch gjson.GetBytes(payload, "type").String() { - case "response.created": - if createdAt <= 0 { - createdAt = gjson.GetBytes(payload, "response.created_at").Int() - } case "response.output_item.done": result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) if err != nil { return nil, 0, nil, openAIResponsesImageResult{}, false, err } if ok { + mergeOpenAIResponsesImageMeta(&result, responseMeta) appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result) } case "response.completed": @@ -286,16 +407,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe usageRaw = completedUsageRaw } if len(results) > 0 { + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) return results, createdAt, usageRaw, firstMeta, true, nil } if len(fallbackResults) > 0 { - return fallbackResults, createdAt, usageRaw, fallbackResults[0], true, nil + firstMeta = fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, true, nil } } } if len(fallbackResults) > 0 { - return fallbackResults, createdAt, usageRaw, fallbackResults[0], foundFinal, nil + firstMeta := fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil } return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil } @@ -341,6 +467,9 @@ func buildOpenAIImagesAPIResponse( if firstMeta.Size != "" { out, _ = sjson.SetBytes(out, "size", firstMeta.Size) } + if firstMeta.Model != "" { + out, _ = sjson.SetBytes(out, "model", firstMeta.Model) + } if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { out, _ = sjson.SetRawBytes(out, "usage", usageRaw) } @@ -380,6 +509,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( resp *http.Response, c *gin.Context, responseFormat string, + fallbackModel string, ) (OpenAIUsage, int, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { @@ -403,6 +533,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( if len(results) == 0 { return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output") } + if strings.TrimSpace(firstMeta.Model) == "" { + firstMeta.Model = strings.TrimSpace(fallbackModel) + } responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) if err != nil { @@ -419,6 +552,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( startTime time.Time, responseFormat string, streamPrefix string, + fallbackModel string, ) (OpenAIUsage, int, *int, error) { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) c.Header("Content-Type", "text/event-stream") @@ -441,6 +575,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( imageCount := 0 var firstTokenMs *int emitted := make(map[string]struct{}) + pendingResults := make([]openAIResponsesImageResult, 0, 1) + pendingSeen := make(map[string]struct{}) + streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} + var createdAt int64 for { line, err := reader.ReadBytes('\n') @@ -455,20 +593,30 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( dataBytes := []byte(data) s.parseSSEUsageBytes(dataBytes, &usage) if gjson.ValidBytes(dataBytes) { + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } switch gjson.GetBytes(dataBytes, "type").String() { case "response.image_generation_call.partial_image": b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) if b64 != "" { eventName := streamPrefix + ".partial_image" - payload := []byte(`{"type":"","partial_image_index":0}`) - payload, _ = sjson.SetBytes(payload, "type", eventName) - payload, _ = sjson.SetBytes(payload, "partial_image_index", gjson.GetBytes(dataBytes, "partial_image_index").Int()) - if format == "url" { - outputFormat := strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()) - payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(outputFormat)+";base64,"+b64) - } else { - payload, _ = sjson.SetBytes(payload, "b64_json", b64) - } + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { return OpenAIUsage{}, imageCount, firstTokenMs, writeErr } @@ -482,59 +630,46 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( if !ok { break } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) key := openAIResponsesImageResultKey(itemID, img) if _, exists := emitted[key]; exists { break } - eventName := streamPrefix + ".completed" - payload := []byte(`{"type":""}`) - payload, _ = sjson.SetBytes(payload, "type", eventName) - if format == "url" { - payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) - } else { - payload, _ = sjson.SetBytes(payload, "b64_json", img.Result) + if _, exists := pendingSeen[key]; exists { + break } - if img.RevisedPrompt != "" { - payload, _ = sjson.SetBytes(payload, "revised_prompt", img.RevisedPrompt) - } - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr - } - emitted[key] = struct{}{} - imageCount = len(emitted) + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) case "response.completed": - results, _, usageRaw, _, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) if extractErr != nil { _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) return OpenAIUsage{}, imageCount, firstTokenMs, extractErr } - if len(results) == 0 { - if imageCount > 0 { - return usage, imageCount, firstTokenMs, nil - } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { err = fmt.Errorf("upstream did not return image output") _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) return OpenAIUsage{}, imageCount, firstTokenMs, err } eventName := streamPrefix + ".completed" - for _, img := range results { + for _, img := range finalResults { key := openAIResponsesImageResultKey("", img) if _, exists := emitted[key]; exists { continue } - payload := []byte(`{"type":""}`) - payload, _ = sjson.SetBytes(payload, "type", eventName) - if format == "url" { - payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) - } else { - payload, _ = sjson.SetBytes(payload, "b64_json", img.Result) - } - if img.RevisedPrompt != "" { - payload, _ = sjson.SetBytes(payload, "revised_prompt", img.RevisedPrompt) - } - if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { - payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw) - } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { return OpenAIUsage{}, imageCount, firstTokenMs, writeErr } @@ -558,6 +693,23 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( if imageCount > 0 { return usage, imageCount, firstTokenMs, nil } + if len(pendingResults) > 0 { + eventName := streamPrefix + ".completed" + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + emitted[key] = struct{}{} + } + imageCount = len(emitted) + return usage, imageCount, firstTokenMs, nil + } streamErr := fmt.Errorf("stream disconnected before image generation completed") _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) @@ -590,6 +742,15 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( account.Type, len(parsed.Uploads), ) + if parsed.N > 1 { + logger.LegacyPrintf( + "service.openai_gateway", + "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s", + parsed.N, + requestModel, + parsed.Endpoint, + ) + } token, _, err := s.GetAccessToken(ctx, account) if err != nil { @@ -664,12 +825,12 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( firstTokenMs *int ) if parsed.Stream { - usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed)) + usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) if err != nil { return nil, err } } else { - usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat) + usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel) if err != nil { return nil, err } diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 4f0ab1f3..200547d4 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -258,9 +258,47 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) } +type openAIImageTestSSEEvent struct { + Name string + Data string +} + +func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent { + chunks := strings.Split(body, "\n\n") + events := make([]openAIImageTestSSEEvent, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + var event openAIImageTestSSEEvent + for _, line := range strings.Split(chunk, "\n") { + switch { + case strings.HasPrefix(line, "event: "): + event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + case strings.HasPrefix(line, "data: "): + event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + } + } + if event.Name != "" || event.Data != "" { + events = append(events, event) + } + } + return events +} + +func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) { + for _, event := range events { + if event.Name == name { + return event, true + } + } + return openAIImageTestSSEEvent{}, false +} + func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { gin.SetMode(gin.TestMode) - body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high"}`) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`) req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -328,6 +366,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) } @@ -354,8 +393,9 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes "X-Request-Id": []string{"req_img_stream"}, }, Body: io.NopCloser(strings.NewReader( - "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" + - "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + "data: [DONE]\n\n", )), }, @@ -377,12 +417,32 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes require.NotNil(t, result) require.True(t, result.Stream) require.Equal(t, 1, result.ImageCount) - require.Contains(t, rec.Body.String(), "event: image_generation.partial_image") - require.Contains(t, rec.Body.String(), "event: image_generation.completed") - require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.partial_image\"") - require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.completed\"") - require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,cGFydGlhbA==\"") - require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,ZmluYWw=\"") + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image") + require.True(t, ok) + require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { @@ -456,7 +516,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t require.Equal(t, 1, result.ImageCount) require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) - require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists()) require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String()) require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,")) require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,")) @@ -493,8 +553,9 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t "Content-Type": []string{"text/event-stream"}, }, Body: io.NopCloser(strings.NewReader( - "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\"}\n\n" + - "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" + + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" + "data: [DONE]\n\n", )), }, @@ -518,15 +579,35 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String()) require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String()) - require.Contains(t, rec.Body.String(), "event: image_edit.partial_image") - require.Contains(t, rec.Body.String(), "event: image_edit.completed") - require.Contains(t, rec.Body.String(), "\"type\":\"image_edit.partial_image\"") - require.Contains(t, rec.Body.String(), "\"type\":\"image_edit.completed\"") - require.Contains(t, rec.Body.String(), "\"url\":\"data:image/webp;base64,cGFydGlhbA==\"") - require.Contains(t, rec.Body.String(), "\"url\":\"data:image/webp;base64,ZWRpdGVk\"") + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image") + require.True(t, ok) + require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed") + require.True(t, ok) + require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } -func TestBuildOpenAIImagesResponsesRequest_RejectsMultipleImages(t *testing.T) { +func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { parsed := &OpenAIImagesRequest{ Endpoint: openAIImagesGenerationsEndpoint, Model: "gpt-image-2", @@ -535,9 +616,29 @@ func TestBuildOpenAIImagesResponsesRequest_RejectsMultipleImages(t *testing.T) { } body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") - require.Error(t, err) - require.Nil(t, body) - require.Contains(t, err.Error(), "only n=1") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) + require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) +} + +func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Model: "gpt-image-2", + Prompt: "replace background", + InputFidelity: "high", + InputImageURLs: []string{ + "https://example.com/source.png", + }, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists()) + require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String()) } func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) { @@ -604,8 +705,14 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa require.NotNil(t, result) require.True(t, result.Stream) require.Equal(t, 1, result.ImageCount) - require.Contains(t, rec.Body.String(), "event: image_generation.completed") - require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.completed\"") - require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,ZmluYWw=\"") + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) require.NotContains(t, rec.Body.String(), "event: error") }