修复计费问题以及模型回显
This commit is contained in:
@@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
|
|||||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
|
||||||
|
imagePrice := 0.02
|
||||||
|
groupID := int64(12)
|
||||||
|
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_image_per_request",
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 1110,
|
||||||
|
OutputTokens: 1756,
|
||||||
|
ImageOutputTokens: 1756,
|
||||||
|
},
|
||||||
|
ImageCount: 2,
|
||||||
|
ImageSize: "1K",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{
|
||||||
|
ID: 1008,
|
||||||
|
GroupID: i64p(groupID),
|
||||||
|
Group: &Group{
|
||||||
|
ID: groupID,
|
||||||
|
RateMultiplier: 1.0,
|
||||||
|
ImagePrice1K: &imagePrice,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
User: &User{ID: 2008},
|
||||||
|
Account: &Account{ID: 3008},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, 2, usageRepo.lastLog.ImageCount)
|
||||||
|
require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
|
||||||
|
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4625,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
|||||||
serviceTier string,
|
serviceTier string,
|
||||||
) (*CostBreakdown, error) {
|
) (*CostBreakdown, error) {
|
||||||
if result != nil && result.ImageCount > 0 {
|
if result != nil && result.ImageCount > 0 {
|
||||||
if hasOpenAIImageUsageTokens(result) {
|
|
||||||
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
|
|
||||||
if err == nil {
|
|
||||||
return cost, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||||
}
|
}
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
@@ -4682,7 +4676,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
|||||||
result *OpenAIForwardResult,
|
result *OpenAIForwardResult,
|
||||||
multiplier float64,
|
multiplier float64,
|
||||||
) *CostBreakdown {
|
) *CostBreakdown {
|
||||||
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
|
||||||
|
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
|
||||||
gid := apiKey.Group.ID
|
gid := apiKey.Group.ID
|
||||||
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type openAIResponsesImageResult struct {
|
|||||||
Size string
|
Size string
|
||||||
Background string
|
Background string
|
||||||
Quality string
|
Quality string
|
||||||
|
Model string
|
||||||
}
|
}
|
||||||
|
|
||||||
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
|
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
|
||||||
@@ -49,6 +50,126 @@ func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) {
|
||||||
|
if dst == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" {
|
||||||
|
dst.OutputFormat = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Size); trimmed != "" {
|
||||||
|
dst.Size = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Background); trimmed != "" {
|
||||||
|
dst.Background = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Quality); trimmed != "" {
|
||||||
|
dst.Quality = trimmed
|
||||||
|
}
|
||||||
|
if trimmed := strings.TrimSpace(src.Model); trimmed != "" {
|
||||||
|
dst.Model = trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) {
|
||||||
|
switch gjson.GetBytes(payload, "type").String() {
|
||||||
|
case "response.created", "response.in_progress", "response.completed":
|
||||||
|
default:
|
||||||
|
return openAIResponsesImageResult{}, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
response := gjson.GetBytes(payload, "response")
|
||||||
|
if !response.Exists() {
|
||||||
|
return openAIResponsesImageResult{}, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := openAIResponsesImageResult{
|
||||||
|
OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()),
|
||||||
|
Size: strings.TrimSpace(response.Get("tools.0.size").String()),
|
||||||
|
Background: strings.TrimSpace(response.Get("tools.0.background").String()),
|
||||||
|
Quality: strings.TrimSpace(response.Get("tools.0.quality").String()),
|
||||||
|
Model: strings.TrimSpace(response.Get("tools.0.model").String()),
|
||||||
|
}
|
||||||
|
return meta, response.Get("created_at").Int(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesStreamPartialPayload(
|
||||||
|
eventType string,
|
||||||
|
b64 string,
|
||||||
|
partialImageIndex int64,
|
||||||
|
responseFormat string,
|
||||||
|
createdAt int64,
|
||||||
|
meta openAIResponsesImageResult,
|
||||||
|
) []byte {
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "type", eventType)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "b64_json", b64)
|
||||||
|
if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64)
|
||||||
|
}
|
||||||
|
if meta.Background != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "background", meta.Background)
|
||||||
|
}
|
||||||
|
if meta.OutputFormat != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat)
|
||||||
|
}
|
||||||
|
if meta.Quality != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "quality", meta.Quality)
|
||||||
|
}
|
||||||
|
if meta.Size != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "size", meta.Size)
|
||||||
|
}
|
||||||
|
if meta.Model != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "model", meta.Model)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIImagesStreamCompletedPayload(
|
||||||
|
eventType string,
|
||||||
|
img openAIResponsesImageResult,
|
||||||
|
responseFormat string,
|
||||||
|
createdAt int64,
|
||||||
|
usageRaw []byte,
|
||||||
|
) []byte {
|
||||||
|
if createdAt <= 0 {
|
||||||
|
createdAt = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(`{"type":"","created_at":0,"b64_json":""}`)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "type", eventType)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
|
||||||
|
payload, _ = sjson.SetBytes(payload, "b64_json", img.Result)
|
||||||
|
if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
|
||||||
|
}
|
||||||
|
if img.Background != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "background", img.Background)
|
||||||
|
}
|
||||||
|
if img.OutputFormat != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat)
|
||||||
|
}
|
||||||
|
if img.Quality != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "quality", img.Quality)
|
||||||
|
}
|
||||||
|
if img.Size != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "size", img.Size)
|
||||||
|
}
|
||||||
|
if img.Model != "" {
|
||||||
|
payload, _ = sjson.SetBytes(payload, "model", img.Model)
|
||||||
|
}
|
||||||
|
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
||||||
|
payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw)
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
func openAIImageOutputMIMEType(outputFormat string) string {
|
func openAIImageOutputMIMEType(outputFormat string) string {
|
||||||
if outputFormat == "" {
|
if outputFormat == "" {
|
||||||
return "image/png"
|
return "image/png"
|
||||||
@@ -134,16 +255,12 @@ func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel st
|
|||||||
{path: "background", value: parsed.Background},
|
{path: "background", value: parsed.Background},
|
||||||
{path: "output_format", value: parsed.OutputFormat},
|
{path: "output_format", value: parsed.OutputFormat},
|
||||||
{path: "moderation", value: parsed.Moderation},
|
{path: "moderation", value: parsed.Moderation},
|
||||||
{path: "input_fidelity", value: parsed.InputFidelity},
|
|
||||||
{path: "style", value: parsed.Style},
|
{path: "style", value: parsed.Style},
|
||||||
} {
|
} {
|
||||||
if trimmed := strings.TrimSpace(field.value); trimmed != "" {
|
if trimmed := strings.TrimSpace(field.value); trimmed != "" {
|
||||||
tool, _ = sjson.SetBytes(tool, field.path, trimmed)
|
tool, _ = sjson.SetBytes(tool, field.path, trimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if parsed.N > 1 {
|
|
||||||
return nil, fmt.Errorf("codex /responses image tool currently supports only n=1")
|
|
||||||
}
|
|
||||||
if parsed.OutputCompression != nil {
|
if parsed.OutputCompression != nil {
|
||||||
tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression)
|
tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression)
|
||||||
}
|
}
|
||||||
@@ -247,6 +364,7 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
|||||||
createdAt int64
|
createdAt int64
|
||||||
usageRaw []byte
|
usageRaw []byte
|
||||||
foundFinal bool
|
foundFinal bool
|
||||||
|
responseMeta openAIResponsesImageResult
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, line := range bytes.Split(body, []byte("\n")) {
|
for _, line := range bytes.Split(body, []byte("\n")) {
|
||||||
@@ -259,18 +377,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
|||||||
if !gjson.ValidBytes(payload) {
|
if !gjson.ValidBytes(payload) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
|
||||||
|
if eventCreatedAt > 0 {
|
||||||
|
createdAt = eventCreatedAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch gjson.GetBytes(payload, "type").String() {
|
switch gjson.GetBytes(payload, "type").String() {
|
||||||
case "response.created":
|
|
||||||
if createdAt <= 0 {
|
|
||||||
createdAt = gjson.GetBytes(payload, "response.created_at").Int()
|
|
||||||
}
|
|
||||||
case "response.output_item.done":
|
case "response.output_item.done":
|
||||||
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
|
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
return nil, 0, nil, openAIResponsesImageResult{}, false, err
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&result, responseMeta)
|
||||||
appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result)
|
appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result)
|
||||||
}
|
}
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
@@ -286,16 +407,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
|||||||
usageRaw = completedUsageRaw
|
usageRaw = completedUsageRaw
|
||||||
}
|
}
|
||||||
if len(results) > 0 {
|
if len(results) > 0 {
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
return results, createdAt, usageRaw, firstMeta, true, nil
|
return results, createdAt, usageRaw, firstMeta, true, nil
|
||||||
}
|
}
|
||||||
if len(fallbackResults) > 0 {
|
if len(fallbackResults) > 0 {
|
||||||
return fallbackResults, createdAt, usageRaw, fallbackResults[0], true, nil
|
firstMeta = fallbackResults[0]
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
|
return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fallbackResults) > 0 {
|
if len(fallbackResults) > 0 {
|
||||||
return fallbackResults, createdAt, usageRaw, fallbackResults[0], foundFinal, nil
|
firstMeta := fallbackResults[0]
|
||||||
|
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
|
||||||
|
return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil
|
||||||
}
|
}
|
||||||
return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
|
return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
|
||||||
}
|
}
|
||||||
@@ -341,6 +467,9 @@ func buildOpenAIImagesAPIResponse(
|
|||||||
if firstMeta.Size != "" {
|
if firstMeta.Size != "" {
|
||||||
out, _ = sjson.SetBytes(out, "size", firstMeta.Size)
|
out, _ = sjson.SetBytes(out, "size", firstMeta.Size)
|
||||||
}
|
}
|
||||||
|
if firstMeta.Model != "" {
|
||||||
|
out, _ = sjson.SetBytes(out, "model", firstMeta.Model)
|
||||||
|
}
|
||||||
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
||||||
out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
|
out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
|
||||||
}
|
}
|
||||||
@@ -380,6 +509,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
responseFormat string,
|
responseFormat string,
|
||||||
|
fallbackModel string,
|
||||||
) (OpenAIUsage, int, error) {
|
) (OpenAIUsage, int, error) {
|
||||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -403,6 +533,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
|||||||
if len(results) == 0 {
|
if len(results) == 0 {
|
||||||
return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
|
return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(firstMeta.Model) == "" {
|
||||||
|
firstMeta.Model = strings.TrimSpace(fallbackModel)
|
||||||
|
}
|
||||||
|
|
||||||
responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
|
responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -419,6 +552,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
|||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
responseFormat string,
|
responseFormat string,
|
||||||
streamPrefix string,
|
streamPrefix string,
|
||||||
|
fallbackModel string,
|
||||||
) (OpenAIUsage, int, *int, error) {
|
) (OpenAIUsage, int, *int, error) {
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -441,6 +575,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
|||||||
imageCount := 0
|
imageCount := 0
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
emitted := make(map[string]struct{})
|
emitted := make(map[string]struct{})
|
||||||
|
pendingResults := make([]openAIResponsesImageResult, 0, 1)
|
||||||
|
pendingSeen := make(map[string]struct{})
|
||||||
|
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
|
||||||
|
var createdAt int64
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
@@ -455,20 +593,30 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
|||||||
dataBytes := []byte(data)
|
dataBytes := []byte(data)
|
||||||
s.parseSSEUsageBytes(dataBytes, &usage)
|
s.parseSSEUsageBytes(dataBytes, &usage)
|
||||||
if gjson.ValidBytes(dataBytes) {
|
if gjson.ValidBytes(dataBytes) {
|
||||||
|
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
|
||||||
|
mergeOpenAIResponsesImageMeta(&streamMeta, meta)
|
||||||
|
if eventCreatedAt > 0 {
|
||||||
|
createdAt = eventCreatedAt
|
||||||
|
}
|
||||||
|
}
|
||||||
switch gjson.GetBytes(dataBytes, "type").String() {
|
switch gjson.GetBytes(dataBytes, "type").String() {
|
||||||
case "response.image_generation_call.partial_image":
|
case "response.image_generation_call.partial_image":
|
||||||
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
|
||||||
if b64 != "" {
|
if b64 != "" {
|
||||||
eventName := streamPrefix + ".partial_image"
|
eventName := streamPrefix + ".partial_image"
|
||||||
payload := []byte(`{"type":"","partial_image_index":0}`)
|
partialMeta := streamMeta
|
||||||
payload, _ = sjson.SetBytes(payload, "type", eventName)
|
mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
|
||||||
payload, _ = sjson.SetBytes(payload, "partial_image_index", gjson.GetBytes(dataBytes, "partial_image_index").Int())
|
OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
|
||||||
if format == "url" {
|
Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
|
||||||
outputFormat := strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String())
|
})
|
||||||
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(outputFormat)+";base64,"+b64)
|
payload := buildOpenAIImagesStreamPartialPayload(
|
||||||
} else {
|
eventName,
|
||||||
payload, _ = sjson.SetBytes(payload, "b64_json", b64)
|
b64,
|
||||||
}
|
gjson.GetBytes(dataBytes, "partial_image_index").Int(),
|
||||||
|
format,
|
||||||
|
createdAt,
|
||||||
|
partialMeta,
|
||||||
|
)
|
||||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
}
|
}
|
||||||
@@ -482,59 +630,46 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
|||||||
if !ok {
|
if !ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
mergeOpenAIResponsesImageMeta(&streamMeta, img)
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
key := openAIResponsesImageResultKey(itemID, img)
|
key := openAIResponsesImageResultKey(itemID, img)
|
||||||
if _, exists := emitted[key]; exists {
|
if _, exists := emitted[key]; exists {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
eventName := streamPrefix + ".completed"
|
if _, exists := pendingSeen[key]; exists {
|
||||||
payload := []byte(`{"type":""}`)
|
break
|
||||||
payload, _ = sjson.SetBytes(payload, "type", eventName)
|
|
||||||
if format == "url" {
|
|
||||||
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
|
|
||||||
} else {
|
|
||||||
payload, _ = sjson.SetBytes(payload, "b64_json", img.Result)
|
|
||||||
}
|
}
|
||||||
if img.RevisedPrompt != "" {
|
pendingSeen[key] = struct{}{}
|
||||||
payload, _ = sjson.SetBytes(payload, "revised_prompt", img.RevisedPrompt)
|
pendingResults = append(pendingResults, img)
|
||||||
}
|
|
||||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
|
||||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
|
||||||
}
|
|
||||||
emitted[key] = struct{}{}
|
|
||||||
imageCount = len(emitted)
|
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
results, _, usageRaw, _, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
|
||||||
if extractErr != nil {
|
if extractErr != nil {
|
||||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
|
||||||
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
|
||||||
}
|
}
|
||||||
if len(results) == 0 {
|
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
|
||||||
if imageCount > 0 {
|
finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
|
||||||
return usage, imageCount, firstTokenMs, nil
|
finalSeen := make(map[string]struct{})
|
||||||
}
|
for _, img := range results {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||||
|
}
|
||||||
|
for _, img := range pendingResults {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
|
||||||
|
}
|
||||||
|
if len(finalResults) == 0 {
|
||||||
err = fmt.Errorf("upstream did not return image output")
|
err = fmt.Errorf("upstream did not return image output")
|
||||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
|
||||||
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
return OpenAIUsage{}, imageCount, firstTokenMs, err
|
||||||
}
|
}
|
||||||
eventName := streamPrefix + ".completed"
|
eventName := streamPrefix + ".completed"
|
||||||
for _, img := range results {
|
for _, img := range finalResults {
|
||||||
key := openAIResponsesImageResultKey("", img)
|
key := openAIResponsesImageResultKey("", img)
|
||||||
if _, exists := emitted[key]; exists {
|
if _, exists := emitted[key]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
payload := []byte(`{"type":""}`)
|
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
|
||||||
payload, _ = sjson.SetBytes(payload, "type", eventName)
|
|
||||||
if format == "url" {
|
|
||||||
payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
|
|
||||||
} else {
|
|
||||||
payload, _ = sjson.SetBytes(payload, "b64_json", img.Result)
|
|
||||||
}
|
|
||||||
if img.RevisedPrompt != "" {
|
|
||||||
payload, _ = sjson.SetBytes(payload, "revised_prompt", img.RevisedPrompt)
|
|
||||||
}
|
|
||||||
if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
|
|
||||||
payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw)
|
|
||||||
}
|
|
||||||
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
}
|
}
|
||||||
@@ -558,6 +693,23 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
|||||||
if imageCount > 0 {
|
if imageCount > 0 {
|
||||||
return usage, imageCount, firstTokenMs, nil
|
return usage, imageCount, firstTokenMs, nil
|
||||||
}
|
}
|
||||||
|
if len(pendingResults) > 0 {
|
||||||
|
eventName := streamPrefix + ".completed"
|
||||||
|
for _, img := range pendingResults {
|
||||||
|
mergeOpenAIResponsesImageMeta(&img, streamMeta)
|
||||||
|
key := openAIResponsesImageResultKey("", img)
|
||||||
|
if _, exists := emitted[key]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
|
||||||
|
if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
|
||||||
|
return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
|
||||||
|
}
|
||||||
|
emitted[key] = struct{}{}
|
||||||
|
}
|
||||||
|
imageCount = len(emitted)
|
||||||
|
return usage, imageCount, firstTokenMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
streamErr := fmt.Errorf("stream disconnected before image generation completed")
|
||||||
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
|
||||||
@@ -590,6 +742,15 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
account.Type,
|
account.Type,
|
||||||
len(parsed.Uploads),
|
len(parsed.Uploads),
|
||||||
)
|
)
|
||||||
|
if parsed.N > 1 {
|
||||||
|
logger.LegacyPrintf(
|
||||||
|
"service.openai_gateway",
|
||||||
|
"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
|
||||||
|
parsed.N,
|
||||||
|
requestModel,
|
||||||
|
parsed.Endpoint,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
token, _, err := s.GetAccessToken(ctx, account)
|
token, _, err := s.GetAccessToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -664,12 +825,12 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
|||||||
firstTokenMs *int
|
firstTokenMs *int
|
||||||
)
|
)
|
||||||
if parsed.Stream {
|
if parsed.Stream {
|
||||||
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed))
|
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat)
|
usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -258,9 +258,47 @@ func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T)
|
|||||||
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIImageTestSSEEvent struct {
|
||||||
|
Name string
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent {
|
||||||
|
chunks := strings.Split(body, "\n\n")
|
||||||
|
events := make([]openAIImageTestSSEEvent, 0, len(chunks))
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
chunk = strings.TrimSpace(chunk)
|
||||||
|
if chunk == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var event openAIImageTestSSEEvent
|
||||||
|
for _, line := range strings.Split(chunk, "\n") {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(line, "event: "):
|
||||||
|
event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||||
|
case strings.HasPrefix(line, "data: "):
|
||||||
|
event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if event.Name != "" || event.Data != "" {
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return events
|
||||||
|
}
|
||||||
|
|
||||||
|
func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) {
|
||||||
|
for _, event := range events {
|
||||||
|
if event.Name == name {
|
||||||
|
return event, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return openAIImageTestSSEEvent{}, false
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high"}`)
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -328,6 +366,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
|
|||||||
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
|
||||||
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
|
||||||
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
|
||||||
}
|
}
|
||||||
@@ -354,8 +393,9 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
|
|||||||
"X-Request-Id": []string{"req_img_stream"},
|
"X-Request-Id": []string{"req_img_stream"},
|
||||||
},
|
},
|
||||||
Body: io.NopCloser(strings.NewReader(
|
Body: io.NopCloser(strings.NewReader(
|
||||||
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
|
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
|
||||||
"data: [DONE]\n\n",
|
"data: [DONE]\n\n",
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
@@ -377,12 +417,32 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Stream)
|
require.True(t, result.Stream)
|
||||||
require.Equal(t, 1, result.ImageCount)
|
require.Equal(t, 1, result.ImageCount)
|
||||||
require.Contains(t, rec.Body.String(), "event: image_generation.partial_image")
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
require.Contains(t, rec.Body.String(), "event: image_generation.completed")
|
partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image")
|
||||||
require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.partial_image\"")
|
require.True(t, ok)
|
||||||
require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.completed\"")
|
require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String())
|
||||||
require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,cGFydGlhbA==\"")
|
require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int())
|
||||||
require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,ZmluYWw=\"")
|
require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
|
||||||
|
require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
|
||||||
|
require.Equal(t, "auto", gjson.Get(partial.Data, "background").String())
|
||||||
|
|
||||||
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
|
||||||
|
require.Equal(t, "auto", gjson.Get(completed.Data, "background").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
|
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
|
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
|
||||||
@@ -456,7 +516,7 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t
|
|||||||
require.Equal(t, 1, result.ImageCount)
|
require.Equal(t, 1, result.ImageCount)
|
||||||
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
|
||||||
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
||||||
require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").String())
|
require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists())
|
||||||
require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String())
|
require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String())
|
||||||
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,"))
|
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,"))
|
||||||
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,"))
|
require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,"))
|
||||||
@@ -493,8 +553,9 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
|
|||||||
"Content-Type": []string{"text/event-stream"},
|
"Content-Type": []string{"text/event-stream"},
|
||||||
},
|
},
|
||||||
Body: io.NopCloser(strings.NewReader(
|
Body: io.NopCloser(strings.NewReader(
|
||||||
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\"}\n\n" +
|
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
|
||||||
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" +
|
"data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" +
|
||||||
|
"data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" +
|
||||||
"data: [DONE]\n\n",
|
"data: [DONE]\n\n",
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
@@ -518,15 +579,35 @@ func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t
|
|||||||
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
|
||||||
require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String())
|
require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String())
|
||||||
require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String())
|
require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String())
|
||||||
require.Contains(t, rec.Body.String(), "event: image_edit.partial_image")
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
require.Contains(t, rec.Body.String(), "event: image_edit.completed")
|
partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image")
|
||||||
require.Contains(t, rec.Body.String(), "\"type\":\"image_edit.partial_image\"")
|
require.True(t, ok)
|
||||||
require.Contains(t, rec.Body.String(), "\"type\":\"image_edit.completed\"")
|
require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String())
|
||||||
require.Contains(t, rec.Body.String(), "\"url\":\"data:image/webp;base64,cGFydGlhbA==\"")
|
require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int())
|
||||||
require.Contains(t, rec.Body.String(), "\"url\":\"data:image/webp;base64,ZWRpdGVk\"")
|
require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
|
||||||
|
require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
|
||||||
|
require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String())
|
||||||
|
|
||||||
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed")
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String())
|
||||||
|
require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
|
||||||
|
require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
|
||||||
|
require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
|
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildOpenAIImagesResponsesRequest_RejectsMultipleImages(t *testing.T) {
|
func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
|
||||||
parsed := &OpenAIImagesRequest{
|
parsed := &OpenAIImagesRequest{
|
||||||
Endpoint: openAIImagesGenerationsEndpoint,
|
Endpoint: openAIImagesGenerationsEndpoint,
|
||||||
Model: "gpt-image-2",
|
Model: "gpt-image-2",
|
||||||
@@ -535,9 +616,29 @@ func TestBuildOpenAIImagesResponsesRequest_RejectsMultipleImages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||||
require.Error(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, body)
|
require.NotNil(t, body)
|
||||||
require.Contains(t, err.Error(), "only n=1")
|
require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
|
||||||
|
require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
|
||||||
|
parsed := &OpenAIImagesRequest{
|
||||||
|
Endpoint: openAIImagesEditsEndpoint,
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
Prompt: "replace background",
|
||||||
|
InputFidelity: "high",
|
||||||
|
InputImageURLs: []string{
|
||||||
|
"https://example.com/source.png",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, body)
|
||||||
|
require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists())
|
||||||
|
require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) {
|
func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) {
|
||||||
@@ -604,8 +705,14 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Stream)
|
require.True(t, result.Stream)
|
||||||
require.Equal(t, 1, result.ImageCount)
|
require.Equal(t, 1, result.ImageCount)
|
||||||
require.Contains(t, rec.Body.String(), "event: image_generation.completed")
|
events := parseOpenAIImageTestSSEEvents(rec.Body.String())
|
||||||
require.Contains(t, rec.Body.String(), "\"type\":\"image_generation.completed\"")
|
completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
|
||||||
require.Contains(t, rec.Body.String(), "\"url\":\"data:image/png;base64,ZmluYWw=\"")
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
|
||||||
|
require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int())
|
||||||
|
require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
|
||||||
|
require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
|
||||||
|
require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
|
||||||
|
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
|
||||||
require.NotContains(t, rec.Body.String(), "event: error")
|
require.NotContains(t, rec.Body.String(), "event: error")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user