Files
sub2api-ht/backend/internal/service/image_generation_intent_test.go

185 lines
6.4 KiB
Go

package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsImageGenerationIntent(t *testing.T) {
tests := []struct {
name string
endpoint string
model string
body []byte
want bool
}{
{
name: "images endpoint",
endpoint: "/v1/images/generations",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image model",
endpoint: "/v1/responses",
model: "gpt-image-2",
body: []byte(`{"model":"gpt-image-2"}`),
want: true,
},
{
name: "image tool",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
want: true,
},
{
name: "image tool choice",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
want: true,
},
{
name: "required tool choice alone is text",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
want: false,
},
{
name: "text only gpt 5.4",
endpoint: "/v1/responses",
model: "gpt-5.4",
body: []byte(`{"model":"gpt-5.4","input":"write code"}`),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body))
})
}
}
// TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel checks the
// case where the image_generation tool entry carries no model of its own.
//
// The resolved image model is expected to be the body's top-level "model"
// value ("mapped-image-model"), not the "requested-model" string passed as the
// second argument, and a 1024x1024 tool size is expected to resolve to "1K".
func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
	payload := []byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`)

	gotModel, gotSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(payload, "requested-model")

	require.NoError(t, err)
	require.Equal(t, "mapped-image-model", gotModel)
	require.Equal(t, "1K", gotSize)
}
// TestResolveOpenAIResponsesImageBillingConfigToolModelWins checks the case
// where the image_generation tool entry names its own model.
//
// The tool's model ("gpt-image-2") is expected to be returned in preference to
// the body's top-level model ("mapped-text-model"), and a 1536x1024 tool size
// is expected to resolve to "2K".
func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
	payload := []byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`)

	gotModel, gotSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(payload, "requested-model")

	require.NoError(t, err)
	require.Equal(t, "gpt-image-2", gotModel)
	require.Equal(t, "2K", gotSize)
}
// TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes
// feeds several request bodies with different image sizes through
// resolveOpenAIResponsesImageBillingConfigFromBody and checks the size tier
// that comes back.
//
// Each case must resolve without error, yield a non-empty image model, and
// produce the expected tier string. The last case places "size" at the top
// level of the body instead of inside the tool entry.
func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
	type sizeCase struct {
		name     string
		body     []byte
		wantTier string
	}

	cases := []sizeCase{
		{name: "official 2k landscape", body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`), wantTier: "2K"},
		{name: "official 4k landscape", body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`), wantTier: "4K"},
		{name: "custom valid 2k", body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`), wantTier: "2K"},
		{name: "default image tool model supports flexible size", body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`), wantTier: "2K"},
		{name: "top level image size is moved into billing", body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`), wantTier: "2K"},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			gotModel, gotTier, err := resolveOpenAIResponsesImageBillingConfigFromBody(tc.body, "requested-model")

			require.NoError(t, err)
			require.NotEmpty(t, gotModel)
			require.Equal(t, tc.wantTier, gotTier)
		})
	}
}
// TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes checks
// that a 2048x1152 size paired with the tool model "gpt-image-1.5" resolves
// without error, returning that tool model and the "2K" tier.
//
// NOTE(review): the test name implies this size is not a recognised size for
// that model; the test itself only shows that no error is returned.
func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
	payload := []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`)

	gotModel, gotSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(payload, "requested-model")

	require.NoError(t, err)
	require.Equal(t, "gpt-image-1.5", gotModel)
	require.Equal(t, "2K", gotSize)
}
// TestOpenAIImageOutputCounterDeduplicatesFinalImages feeds the counter a
// partial-image event, an output_item.done event for image "ig_1", and a
// response.completed event that lists "ig_1" again alongside "ig_2".
//
// The expected count is 2: the partial image adds nothing and the repeated
// "ig_1" is counted once.
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
	// Events are delivered in this exact order.
	events := []string{
		`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`,
		`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`,
		`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`,
	}

	c := newOpenAIImageOutputCounter()
	for _, ev := range events {
		c.AddSSEData([]byte(ev))
	}

	require.Equal(t, 2, c.Count())
}
// TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes exercises two
// independent counters.
//
// The first receives three differently shaped events, each carrying a distinct
// image id (image_generation.completed, response.output_item.done and
// response.completed), and is expected to report 3.
//
// The second receives two "data" array payloads, the later one repeating the
// earlier entries plus one more, and is also expected to report 3.
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
	typedEvents := []string{
		`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`,
		`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`,
		`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`,
	}
	typedCounter := newOpenAIImageOutputCounter()
	for _, ev := range typedEvents {
		typedCounter.AddSSEData([]byte(ev))
	}
	require.Equal(t, 3, typedCounter.Count())

	dataEvents := []string{
		`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`,
		`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`,
	}
	arrayCounter := newOpenAIImageOutputCounter()
	for _, ev := range dataEvents {
		arrayCounter.AddSSEData([]byte(ev))
	}
	require.Equal(t, 3, arrayCounter.Count())
}
// TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload passes AddSSEData
// a single image_generation.completed JSON object that contains a newline
// between its two fields, and expects it to be counted as one image.
func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) {
	// The embedded "\n" sits between JSON members, so the payload spans two lines.
	payload := []byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}")

	c := newOpenAIImageOutputCounter()
	c.AddSSEData(payload)

	require.Equal(t, 1, c.Count())
}
// TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload passes AddSSEBody
// a raw SSE stream in which one image_generation.completed JSON object is split
// across two consecutive "data:" lines of the same event, followed by a
// "[DONE]" event. Exactly one image is expected to be counted.
func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
	stream := "data: {\"type\":\"image_generation.completed\",\n" +
		"data: \"b64_json\":\"final-a\"}\n\n" +
		"data: [DONE]\n\n"

	c := newOpenAIImageOutputCounter()
	c.AddSSEBody(stream)

	require.Equal(t, 1, c.Count())
}
// TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody passes
// AddSSEBody a single SSE event whose two "data:" lines are each a complete
// image_generation.completed JSON object, and expects both to be counted.
//
// NOTE(review): presumably the two lines joined together do not parse as one
// JSON document, so the counter falls back to handling each line separately —
// confirm against the AddSSEBody implementation.
func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
	stream := "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
		"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n"

	c := newOpenAIImageOutputCounter()
	c.AddSSEBody(stream)

	require.Equal(t, 2, c.Count())
}