Files
sub2api-ht/backend/internal/service/openai_image_generation_controls_test.go

216 lines
7.6 KiB
Go

package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents checks
// that Forward refuses requests expressing an image-generation intent when the
// test context is built with image generation disallowed for the API key's group.
//
// Three request shapes are exercised: an image model name, an
// "image_generation" entry in tools, and an "image_generation" tool_choice.
// For each one Forward must return an error with a nil result, the recorder must
// hold a 403 whose error.type is "permission_error", and the upstream stub must
// not have recorded any request.
func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cases := []struct {
		label   string
		payload string
	}{
		{label: "image model", payload: `{"model":"gpt-image-2","input":"draw"}`},
		{label: "image tool", payload: `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`},
		{label: "image tool choice", payload: `{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			// A response-less stub: any upstream call would be visible via lastReq.
			stub := &httpUpstreamRecorder{}
			gateway := newOpenAIImageGenerationControlTestService(stub)
			ginCtx, rec := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
			acct := newOpenAIImageGenerationControlTestAccount()

			res, forwardErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(tc.payload))

			require.Error(t, forwardErr)
			require.Nil(t, res)
			require.Equal(t, http.StatusForbidden, rec.Code)
			errType := gjson.GetBytes(rec.Body.Bytes(), "error.type").String()
			require.Equal(t, "permission_error", errType)
			require.Nil(t, stub.lastReq, "disabled image request must not reach upstream")
		})
	}
}
// TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses checks
// that a request without any image-generation intent is still forwarded when the
// test context is built with image generation disallowed.
//
// The upstream stub answers with a JSON body reporting 3 input and 2 output
// tokens; the test expects a 200, those usage figures on the result, an image
// count of zero, and a recorded upstream request.
func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		upstreamJSON = `{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`
		requestJSON  = `{"model":"gpt-5.4","input":"write code","stream":false}`
	)

	respHeader := http.Header{}
	respHeader.Set("Content-Type", "application/json")
	stub := &httpUpstreamRecorder{
		resp: &http.Response{
			StatusCode: http.StatusOK,
			Header:     respHeader,
			Body:       io.NopCloser(strings.NewReader(upstreamJSON)),
		},
	}

	gateway := newOpenAIImageGenerationControlTestService(stub)
	ginCtx, rec := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
	acct := newOpenAIImageGenerationControlTestAccount()

	res, forwardErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(requestJSON))

	require.NoError(t, forwardErr)
	require.NotNil(t, res)
	require.Equal(t, http.StatusOK, rec.Code)
	require.Equal(t, 3, res.Usage.InputTokens)
	require.Equal(t, 2, res.Usage.OutputTokens)
	require.Equal(t, 0, res.ImageCount)
	require.NotNil(t, stub.lastReq)
}
// TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability
// sends the same text-only request with a "codex_cli_rs/0.98.0" User-Agent,
// once with image generation disallowed and once with it allowed, and inspects
// the body recorded by the upstream stub.
//
// The forwarded body must contain an "image_generation" tool, and an
// instructions string mentioning "image_generation", exactly when image
// generation is allowed.
func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		upstreamJSON = `{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`
		requestJSON  = `{"model":"gpt-5.4","input":"write code","stream":false}`
	)

	cases := []struct {
		label         string
		imagesAllowed bool
		expectInject  bool
	}{
		{label: "disabled group skips injection", imagesAllowed: false, expectInject: false},
		{label: "enabled group injects image tool", imagesAllowed: true, expectInject: true},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.label, func(t *testing.T) {
			// The response body is a one-shot reader, so build a fresh stub per case.
			respHeader := http.Header{}
			respHeader.Set("Content-Type", "application/json")
			stub := &httpUpstreamRecorder{
				resp: &http.Response{
					StatusCode: http.StatusOK,
					Header:     respHeader,
					Body:       io.NopCloser(strings.NewReader(upstreamJSON)),
				},
			}

			gateway := newOpenAIImageGenerationControlTestService(stub)
			ginCtx, _ := newOpenAIImageGenerationControlTestContext(tc.imagesAllowed, "codex_cli_rs/0.98.0")
			acct := newOpenAIImageGenerationControlTestAccount()

			res, forwardErr := gateway.Forward(context.Background(), ginCtx, acct, []byte(requestJSON))

			require.NoError(t, forwardErr)
			require.NotNil(t, res)
			require.NotNil(t, stub.lastReq)

			toolPresent := gjson.GetBytes(stub.lastBody, `tools.#(type=="image_generation")`).Exists()
			require.Equal(t, tc.expectInject, toolPresent)

			forwardedInstructions := gjson.GetBytes(stub.lastBody, "instructions").String()
			mentionsImageTool := strings.Contains(forwardedInstructions, "image_generation")
			require.Equal(t, tc.expectInject, mentionsImageTool)
		})
	}
}
// TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming feeds
// handleNonStreamingResponse a JSON response containing one
// "image_generation_call" output item plus usage figures, and expects the
// result to report one image, 7 input tokens, 3 output tokens and 2 image
// output tokens.
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Fixture bytes are kept exactly as-is, hence the unindented raw string.
	const upstreamJSON = `{
"id":"resp_image_json",
"model":"gpt-5.4",
"output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
}`

	gateway := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
	ginCtx, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"application/json"}},
		Body:       io.NopCloser(strings.NewReader(upstreamJSON)),
	}
	acct := &Account{ID: 1, Type: AccountTypeAPIKey}

	res, handleErr := gateway.handleNonStreamingResponse(context.Background(), upstreamResp, ginCtx, acct, "gpt-5.4", "gpt-5.4")

	require.NoError(t, handleErr)
	require.NotNil(t, res)
	require.Equal(t, 1, res.imageCount)
	require.NotNil(t, res.usage)
	require.Equal(t, 7, res.usage.InputTokens)
	require.Equal(t, 3, res.usage.OutputTokens)
	require.Equal(t, 2, res.usage.ImageOutputTokens)
}
// TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming feeds
// handleStreamingResponse an event stream made of a
// "response.output_item.done" event for an image_generation_call item followed
// by a "response.completed" event that repeats the same item id and carries
// usage figures. The result is expected to report one image, 11 input tokens,
// 5 output tokens and 4 image output tokens.
func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		itemDoneEvent  = "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n"
		completedEvent = "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n"
	)

	gateway := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
	ginCtx, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")

	upstreamResp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
		Body:       io.NopCloser(strings.NewReader(itemDoneEvent + completedEvent)),
	}
	acct := &Account{ID: 1}

	res, handleErr := gateway.handleStreamingResponse(context.Background(), upstreamResp, ginCtx, acct, time.Now(), "gpt-5.5", "gpt-5.5")

	require.NoError(t, handleErr)
	require.NotNil(t, res)
	require.Equal(t, 1, res.imageCount)
	require.NotNil(t, res.usage)
	require.Equal(t, 11, res.usage.InputTokens)
	require.Equal(t, 5, res.usage.OutputTokens)
	require.Equal(t, 4, res.usage.ImageOutputTokens)
}
// newOpenAIImageGenerationControlTestService assembles an OpenAIGatewayService
// for these tests: a zero-value config, the supplied upstream recorder as the
// HTTP upstream, a stubGatewayCache, and a WS protocol resolver and Codex tool
// corrector built with their regular constructors.
func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
	cfg := new(config.Config)

	svc := &OpenAIGatewayService{
		cfg:          cfg,
		httpUpstream: upstream,
		cache:        &stubGatewayCache{},
	}
	// The resolver shares the same config instance as the service.
	svc.openaiWSResolver = NewOpenAIWSProtocolResolver(cfg)
	svc.toolCorrector = NewCodexToolCorrector()
	return svc
}
// newOpenAIImageGenerationControlTestContext returns a gin test context and the
// response recorder backing it.
//
// The context carries a body-less POST /openai/v1/responses request with the
// given User-Agent header, and stores under the "api_key" key an APIKey whose
// Group has AllowImageGeneration set to allowImages and both rate multipliers
// set to 1.
func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)

	req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
	ginCtx.Request = req
	req.Header.Set("User-Agent", userAgent)

	var groupID int64 = 4242
	group := &Group{
		ID:                   groupID,
		AllowImageGeneration: allowImages,
		RateMultiplier:       1,
		ImageRateMultiplier:  1,
	}
	key := &APIKey{
		ID:      2424,
		GroupID: &groupID,
		Group:   group,
	}
	ginCtx.Set("api_key", key)

	return ginCtx, rec
}
// newOpenAIImageGenerationControlTestAccount returns a fresh Account fixture on
// the OpenAI platform with API-key type, active status, schedulable, a
// concurrency of 1, and a placeholder "api_key" credential.
func newOpenAIImageGenerationControlTestAccount() *Account {
	acct := &Account{
		ID:          5151,
		Name:        "openai-image-controls",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{"api_key": "sk-test"},
	}
	return acct
}