diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 3b474c4a..724f01f2 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -124,9 +124,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) + openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) - rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) + rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) diff --git a/backend/internal/repository/openai_403_counter_cache.go b/backend/internal/repository/openai_403_counter_cache.go new file mode 100644 index 00000000..a68d2518 --- /dev/null +++ b/backend/internal/repository/openai_403_counter_cache.go @@ -0,0 +1,51 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const openAI403CounterPrefix = "openai_403_count:account:" + +var openAI403CounterIncrScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + local count = 
redis.call('INCR', key) + if count == 1 then + redis.call('EXPIRE', key, ttl) + end + + return count +`) + +type openAI403CounterCache struct { + rdb *redis.Client +} + +func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache { + return &openAI403CounterCache{rdb: rdb} +} + +func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + + ttlSeconds := windowMinutes * 60 + if ttlSeconds < 60 { + ttlSeconds = 60 + } + + result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64() + if err != nil { + return 0, fmt.Errorf("increment openai 403 count: %w", err) + } + return result, nil +} + +func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index d3adb4a0..b10175c3 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -96,6 +96,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, NewTimeoutCounterCache, + NewOpenAI403CounterCache, NewInternal500CounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 801eac1b..0fb6e18f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit return false } switch capability { - case OpenAIImagesCapabilityBasic: + case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative: return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey - case OpenAIImagesCapabilityNative: - return a.Type == AccountTypeAPIKey default: return 
true } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 52d53013..e5bc93ca 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/rand" - "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C return nil } -// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API. +// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API. func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { authToken := account.GetOpenAIAccessToken() if authToken == "" { @@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co c.Writer.Flush() s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) - s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"}) + s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"}) - // Build headers (replicating buildOpenAIBackendAPIHeaders logic) - headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo) + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: strings.TrimSpace(modelID), + Prompt: prompt, + } + applyOpenAIImagesDefaults(parsed) + + responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error())) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + 
req.Host = "chatgpt.com" + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", "opencode") + if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { + req.Header.Set("User-Agent", customUA) + } else { + req.Header.Set("User-Agent", codexCLIUserAgent) + } + if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - - client, err := newOpenAIBackendAPIClient(proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error())) - } - - // Bootstrap - if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { - log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr) - } - - // Fetch chat requirements - s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"}) - chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error())) - } - if chatReqs.Arkose.Required { - return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required") - } - - // Initialize and prepare conversation - s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"}) - parentMessageID := uuid.NewString() - proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) - _ = initializeOpenAIImageConversation(ctx, client, headers) - conduitToken, err := 
prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error())) - } - - // Build simplified conversation request (no file uploads) - convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID) - convHeaders := cloneHTTPHeader(headers) - convHeaders.Set("Accept", "text/event-stream") - convHeaders.Set("Content-Type", "application/json") - convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) - if conduitToken != "" { - convHeaders.Set("x-conduit-token", conduitToken) - } - if proofToken != "" { - convHeaders.Set("openai-sentinel-proof-token", proofToken) - } - - s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"}) - - resp, err := client.R(). - SetContext(ctx). - DisableAutoReadResponse(). - SetHeaders(headerToMap(convHeaders)). - SetBodyJsonMarshal(convReq). - Post(openAIChatGPTConversationURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error())) + return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error())) } defer func() { if resp != nil && resp.Body != nil { @@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co } }() if resp.StatusCode >= 400 { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode)) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + message := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if message == "" { + message = fmt.Sprintf("Responses API returned %d", resp.StatusCode) + } + return s.sendErrorAndEnd(c, message) } - startTime := time.Now() - conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime) + body, err := io.ReadAll(resp.Body) if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: 
%s", err.Error())) + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error())) } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) - if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"}) - polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) - if pollErr != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error())) - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) + results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error())) } - pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) - if len(pointerInfos) == 0 { - return s.sendErrorAndEnd(c, "No images returned from conversation") + if len(results) == 0 { + return s.sendErrorAndEnd(c, "No images returned from responses API") } - s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"}) - - // Download and encode each image - for _, pointer := range pointerInfos { - downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error())) - } - data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error())) - } - b64 := base64.StdEncoding.EncodeToString(data) - mimeType := http.DetectContentType(data) - if pointer.Prompt != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt}) + for _, item := range results { + if item.RevisedPrompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt}) } + mimeType := 
openAIImageOutputMIMEType(item.OutputFormat) s.sendEvent(c, TestEvent{ Type: "image", - ImageURL: "data:" + mimeType + ";base64," + b64, + ImageURL: "data:" + mimeType + ";base64," + item.Result, MimeType: mimeType, }) } @@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co return nil } -// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes. -// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without -// requiring the full gateway service dependency. -func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header { - // Ensure device and session IDs exist - deviceID := account.GetOpenAIDeviceID() - sessionID := account.GetOpenAISessionID() - if deviceID == "" || sessionID == "" { - updates := map[string]any{} - if deviceID == "" { - deviceID = uuid.NewString() - updates["openai_device_id"] = deviceID - } - if sessionID == "" { - sessionID = uuid.NewString() - updates["openai_session_id"] = sessionID - } - if account.Extra == nil { - account.Extra = map[string]any{} - } - for key, value := range updates { - account.Extra[key] = value - } - if repo != nil { - updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - _ = repo.UpdateExtra(updateCtx, account.ID, updates) - } - } - - headers := make(http.Header) - headers.Set("Authorization", "Bearer "+token) - headers.Set("Accept", "application/json") - headers.Set("Origin", "https://chatgpt.com") - headers.Set("Referer", "https://chatgpt.com/") - headers.Set("Sec-Fetch-Dest", "empty") - headers.Set("Sec-Fetch-Mode", "cors") - headers.Set("Sec-Fetch-Site", "same-origin") - headers.Set("User-Agent", openAIImageBackendUserAgent) - if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { - headers.Set("User-Agent", customUA) - } - if chatgptAccountID := 
strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { - headers.Set("chatgpt-account-id", chatgptAccountID) - } - if deviceID != "" { - headers.Set("oai-device-id", deviceID) - headers.Set("Cookie", "oai-did="+deviceID) - } - if sessionID != "" { - headers.Set("oai-session-id", sessionID) - } - return headers -} - -// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request. -func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any { - promptText := strings.TrimSpace(prompt) - if promptText == "" { - promptText = "Generate an image." - } - metadata := map[string]any{ - "developer_mode_connector_ids": []any{}, - "selected_github_repos": []any{}, - "selected_all_github_repos": false, - "system_hints": []string{"picture_v2"}, - "serialization_metadata": map[string]any{ - "custom_symbol_offsets": []any{}, - }, - } - message := map[string]any{ - "id": uuid.NewString(), - "author": map[string]any{"role": "user"}, - "content": map[string]any{ - "content_type": "text", - "parts": []any{promptText}, - }, - "metadata": metadata, - "create_time": float64(time.Now().UnixMilli()) / 1000, - } - return map[string]any{ - "action": "next", - "client_prepare_state": "sent", - "parent_message_id": parentMessageID, - "messages": []any{message}, - "model": "auto", - "timezone_offset_min": openAITimezoneOffsetMinutes(), - "timezone": openAITimezoneName(), - "conversation_mode": map[string]any{"kind": "primary_assistant"}, - "system_hints": []string{"picture_v2"}, - "supports_buffering": true, - "supported_encodings": []string{"v1"}, - "client_contextual_info": map[string]any{"app_name": "chatgpt.com"}, - "force_nulligen": false, - "force_paragen": false, - "force_paragen_model_slug": "", - "force_rate_limit": false, - "websocket_request_id": uuid.NewString(), - } -} - func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { eventJSON, _ := json.Marshal(event) if _, err := 
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go new file mode 100644 index 00000000..80a2fc31 --- /dev/null +++ b/backend/internal/service/account_test_service_openai_image_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 53, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat") + require.NoError(t, err) + require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool") + require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") + require.Contains(t, rec.Body.String(), "\"success\":true") 
+} diff --git a/backend/internal/service/openai_403_counter.go b/backend/internal/service/openai_403_counter.go new file mode 100644 index 00000000..5ba3e195 --- /dev/null +++ b/backend/internal/service/openai_403_counter.go @@ -0,0 +1,11 @@ +package service + +import "context" + +// OpenAI403CounterCache 追踪 OpenAI 账号连续 403 失败次数。 +type OpenAI403CounterCache interface { + // IncrementOpenAI403Count 原子递增 403 计数并返回当前值。 + IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) + // ResetOpenAI403Count 成功后清零计数器。 + ResetOpenAI403Count(ctx context.Context, accountID int64) error +} diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go new file mode 100644 index 00000000..c6805464 --- /dev/null +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type openAI403CounterResetStub struct { + resetCalls []int64 +} + +func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) { + return 0, nil +} + +func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { + counter := &openAI403CounterResetStub{} + rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) + rateLimitSvc.SetOpenAI403CounterCache(counter) + + svc := &OpenAIGatewayService{ + rateLimitService: rateLimitSvc, + } + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{}, + Account: &Account{ID: 777, Platform: PlatformOpenAI}, + }) + + require.NoError(t, err) + require.Equal(t, []int64{777}, counter.resetCalls) +} diff --git 
a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 95e1bffa..9665c4c8 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. require.NotNil(t, usageRepo.lastLog.BillingMode) require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) } + +func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) { + imagePrice := 0.02 + groupID := int64(12) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_per_request", + Model: "gpt-image-2", + Usage: OpenAIUsage{ + InputTokens: 1110, + OutputTokens: 1756, + ImageOutputTokens: 1756, + }, + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 2008}, + Account: &Account{ID: 3008}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) + require.InDelta(t, 0.0, 
usageRepo.lastLog.ImageOutputCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 534ffeee..1a462a3b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result + if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { + s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) + } // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && @@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( serviceTier string, ) (*CostBreakdown, error) { if result != nil && result.ImageCount > 0 { - if hasOpenAIImageUsageTokens(result) { - cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize) - if err == nil { - return cost, nil - } - } return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil } if s.resolver != nil && apiKey.Group != nil { @@ -4679,7 +4676,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( result *OpenAIForwardResult, multiplier float64, ) *CostBreakdown { - if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && + (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { gid := apiKey.Group.ID cost, err := s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 
7935376b..3922b730 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -50,6 +50,7 @@ const ( openAIImageLifecycleTimeout = 2 * time.Minute openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part + openAIImagesResponsesMainModel = "gpt-5.4-mini" ) type OpenAIImagesCapability string @@ -81,10 +82,21 @@ type OpenAIImagesRequest struct { ExplicitSize bool SizeTier string ResponseFormat string + Quality string + Background string + OutputFormat string + Moderation string + InputFidelity string + Style string + OutputCompression *int + PartialImages *int HasMask bool HasNativeOptions bool RequiredCapability OpenAIImagesCapability + InputImageURLs []string + MaskImageURL string Uploads []OpenAIImagesUpload + MaskUpload *OpenAIImagesUpload Body []byte bodyHash string } @@ -188,7 +200,54 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { req.ExplicitSize = req.Size != "" } req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) + req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String()) + req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String()) + req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String()) + req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String()) + req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String()) + req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String()) req.HasMask = gjson.GetBytes(body, "mask").Exists() + if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() { + if outputCompression.Type != gjson.Number { + return fmt.Errorf("invalid output_compression field type") + } + v := int(outputCompression.Int()) + req.OutputCompression = &v + } + if partialImages := 
gjson.GetBytes(body, "partial_images"); partialImages.Exists() { + if partialImages.Type != gjson.Number { + return fmt.Errorf("invalid partial_images field type") + } + v := int(partialImages.Int()) + req.PartialImages = &v + } + if req.IsEdits() { + images := gjson.GetBytes(body, "images") + if images.Exists() { + if !images.IsArray() { + return fmt.Errorf("invalid images field type") + } + for _, item := range images.Array() { + if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" { + req.InputImageURLs = append(req.InputImageURLs, imageURL) + continue + } + if item.Get("file_id").Exists() { + return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)") + } + } + } + if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" { + req.MaskImageURL = maskImageURL + req.HasMask = true + } + if gjson.GetBytes(body, "mask.file_id").Exists() { + return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)") + } + if len(req.InputImageURLs) == 0 { + return fmt.Errorf("images[].image_url is required") + } + } req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { return gjson.GetBytes(body, path).Exists() }) @@ -231,6 +290,16 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) if name == "mask" && len(data) > 0 { req.HasMask = true + width, height := parseOpenAIImageDimensions(part.Header) + maskUpload := OpenAIImagesUpload{ + FieldName: name, + FileName: fileName, + ContentType: partContentType, + Data: data, + Width: width, + Height: height, + } + req.MaskUpload = &maskUpload } if name == "image" || strings.HasPrefix(name, "image[") { width, height := parseOpenAIImageDimensions(part.Header) @@ -270,6 +339,38 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope return fmt.Errorf("n must be a positive 
integer") } req.N = n + case "quality": + req.Quality = value + req.HasNativeOptions = true + case "background": + req.Background = value + req.HasNativeOptions = true + case "output_format": + req.OutputFormat = value + req.HasNativeOptions = true + case "moderation": + req.Moderation = value + req.HasNativeOptions = true + case "input_fidelity": + req.InputFidelity = value + req.HasNativeOptions = true + case "style": + req.Style = value + req.HasNativeOptions = true + case "output_compression": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid output_compression field value") + } + req.OutputCompression = &n + req.HasNativeOptions = true + case "partial_images": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid partial_images field value") + } + req.PartialImages = &n + req.HasNativeOptions = true default: if isOpenAINativeImageOption(name) && value != "" { req.HasNativeOptions = true @@ -359,6 +460,8 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { "output_format", "output_compression", "moderation", + "input_fidelity", + "partial_images", } { if exists(path) { return true @@ -369,7 +472,7 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { func isOpenAINativeImageOption(name string) bool { switch strings.TrimSpace(strings.ToLower(name)) { - case "background", "quality", "style", "output_format", "output_compression", "moderation": + case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images": return true default: return false @@ -782,156 +885,6 @@ func extractOpenAIImageCountFromJSONBytes(body []byte) int { return 0 } -func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( - ctx context.Context, - c *gin.Context, - account *Account, - parsed *OpenAIImagesRequest, - channelMappedModel string, -) (*OpenAIForwardResult, error) { - startTime := time.Now() - requestModel := 
strings.TrimSpace(parsed.Model) - if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { - requestModel = mapped - } - if err := validateOpenAIImagesModel(requestModel); err != nil { - return nil, err - } - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", - requestModel, - parsed.Endpoint, - account.Type, - len(parsed.Uploads), - ) - - token, _, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account)) - if err != nil { - return nil, err - } - headers, err := s.buildOpenAIBackendAPIHeaders(account, token) - if err != nil { - return nil, err - } - if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { - logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr) - } - - chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - if chatReqs.Arkose.Required { - return nil, s.wrapOpenAIImageBackendError( - ctx, - c, - account, - newOpenAIImageSyntheticStatusError( - http.StatusForbidden, - "chat-requirements requires unsupported challenge (arkose)", - openAIChatGPTChatRequirementsURL, - ), - ) - } - - parentMessageID := uuid.NewString() - proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) - _ = initializeOpenAIImageConversation(ctx, client, headers) - conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads) - if err != nil { - return nil, 
s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - convReq := buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads) - if parsedContent, err := json.Marshal(convReq); err == nil { - setOpsUpstreamRequestBody(c, parsedContent) - } - convHeaders := cloneHTTPHeader(headers) - convHeaders.Set("Accept", "text/event-stream") - convHeaders.Set("Content-Type", "application/json") - convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) - if conduitToken != "" { - convHeaders.Set("x-conduit-token", conduitToken) - } - if proofToken != "" { - convHeaders.Set("openai-sentinel-proof-token", proofToken) - } - - resp, err := client.R(). - SetContext(ctx). - DisableAutoReadResponse(). - SetHeaders(headerToMap(convHeaders)). - SetBodyJsonMarshal(convReq). - Post(openAIChatGPTConversationURL) - if err != nil { - return nil, fmt.Errorf("openai image conversation request failed: %w", err) - } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - if resp.StatusCode >= 400 { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp)) - } - - conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime) - if err != nil { - return nil, err - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d", - conversationID, - len(pointerInfos), - countOpenAIFileServicePointerInfos(pointerInfos), - countOpenAIDirectImageAssets(pointerInfos), - ) - lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout) - defer releaseLifecycleCtx() - if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID) - 
if pollErr != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr) - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) - } - pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) - if len(pointerInfos) == 0 { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID) - return nil, fmt.Errorf("openai image conversation returned no downloadable images") - } - - responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody) - return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: usage, - Model: requestModel, - UpstreamModel: requestModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: parsed.SizeTier, - }, nil -} - func resolveOpenAIProxyURL(account *Account) string { if account != nil && account.ProxyID != nil && account.Proxy != nil { return account.Proxy.URL() diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go new file mode 100644 index 00000000..99b5ca6e --- /dev/null +++ b/backend/internal/service/openai_images_responses.go @@ -0,0 +1,853 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type openAIResponsesImageResult struct { + Result string + RevisedPrompt string + OutputFormat string + Size string + Background string + Quality string + Model string +} + +func 
openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string { + if strings.TrimSpace(result.Result) != "" { + return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result) + } + return "item:" + strings.TrimSpace(itemID) +} + +func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool { + if results == nil { + return false + } + key := openAIResponsesImageResultKey(itemID, result) + if key != "" { + if _, exists := seen[key]; exists { + return false + } + seen[key] = struct{}{} + } + *results = append(*results, result) + return true +} + +func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) { + if dst == nil { + return + } + if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" { + dst.OutputFormat = trimmed + } + if trimmed := strings.TrimSpace(src.Size); trimmed != "" { + dst.Size = trimmed + } + if trimmed := strings.TrimSpace(src.Background); trimmed != "" { + dst.Background = trimmed + } + if trimmed := strings.TrimSpace(src.Quality); trimmed != "" { + dst.Quality = trimmed + } + if trimmed := strings.TrimSpace(src.Model); trimmed != "" { + dst.Model = trimmed + } +} + +func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) { + switch gjson.GetBytes(payload, "type").String() { + case "response.created", "response.in_progress", "response.completed": + default: + return openAIResponsesImageResult{}, 0, false + } + + response := gjson.GetBytes(payload, "response") + if !response.Exists() { + return openAIResponsesImageResult{}, 0, false + } + + meta := openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()), + Size: strings.TrimSpace(response.Get("tools.0.size").String()), + Background: 
strings.TrimSpace(response.Get("tools.0.background").String()), + Quality: strings.TrimSpace(response.Get("tools.0.quality").String()), + Model: strings.TrimSpace(response.Get("tools.0.model").String()), + } + return meta, response.Get("created_at").Int(), true +} + +func buildOpenAIImagesStreamPartialPayload( + eventType string, + b64 string, + partialImageIndex int64, + responseFormat string, + createdAt int64, + meta openAIResponsesImageResult, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex) + payload, _ = sjson.SetBytes(payload, "b64_json", b64) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64) + } + if meta.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", meta.Background) + } + if meta.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat) + } + if meta.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", meta.Quality) + } + if meta.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", meta.Size) + } + if meta.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", meta.Model) + } + return payload +} + +func buildOpenAIImagesStreamCompletedPayload( + eventType string, + img openAIResponsesImageResult, + responseFormat string, + createdAt int64, + usageRaw []byte, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = 
// openAIImageOutputMIMEType maps an image output format to the MIME type used
// when building base64 data URLs.
//
// Surrounding whitespace is ignored. The rules, in order:
//   - empty (or whitespace-only) input yields "image/png";
//   - a value that already contains "/" is treated as a MIME type and is
//     returned trimmed, with its case preserved;
//   - "png", "jpg"/"jpeg" and "webp" (case-insensitive) map to their image/*
//     types;
//   - anything else falls back to "image/png".
func openAIImageOutputMIMEType(outputFormat string) string {
	// Trim once up front so padded values cannot leak whitespace into a
	// data URL via the pass-through branch below.
	format := strings.TrimSpace(outputFormat)
	if format == "" {
		return "image/png"
	}
	if strings.Contains(format, "/") {
		return format
	}
	switch strings.ToLower(format) {
	case "jpg", "jpeg":
		return "image/jpeg"
	case "webp":
		return "image/webp"
	default:
		// Covers "png" as well as unknown formats.
		return "image/png"
	}
}
inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads)) + for _, imageURL := range parsed.InputImageURLs { + if trimmed := strings.TrimSpace(imageURL); trimmed != "" { + inputImages = append(inputImages, trimmed) + } + } + for _, upload := range parsed.Uploads { + dataURL, err := openAIImageUploadToDataURL(upload) + if err != nil { + return nil, err + } + inputImages = append(inputImages, dataURL) + } + if parsed.IsEdits() && len(inputImages) == 0 { + return nil, fmt.Errorf("image input is required") + } + + req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`) + req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel) + + input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`) + input, _ = sjson.SetBytes(input, "0.content.0.text", prompt) + for index, imageURL := range inputImages { + part := []byte(`{"type":"input_image","image_url":""}`) + part, _ = sjson.SetBytes(part, "image_url", imageURL) + input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part) + } + req, _ = sjson.SetRawBytes(req, "input", input) + + action := "generate" + if parsed.IsEdits() { + action = "edit" + } + tool := []byte(`{"type":"image_generation","action":"","model":""}`) + tool, _ = sjson.SetBytes(tool, "action", action) + tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel)) + + for _, field := range []struct { + path string + value string + }{ + {path: "size", value: parsed.Size}, + {path: "quality", value: parsed.Quality}, + {path: "background", value: parsed.Background}, + {path: "output_format", value: parsed.OutputFormat}, + {path: "moderation", value: parsed.Moderation}, + {path: "style", value: parsed.Style}, + } { + if trimmed := strings.TrimSpace(field.value); trimmed != "" { + tool, _ 
= sjson.SetBytes(tool, field.path, trimmed) + } + } + if parsed.OutputCompression != nil { + tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression) + } + if parsed.PartialImages != nil { + tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages) + } + + maskImageURL := strings.TrimSpace(parsed.MaskImageURL) + if parsed.MaskUpload != nil { + dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload) + if err != nil { + return nil, err + } + maskImageURL = dataURL + } + if maskImageURL != "" { + tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL) + } + + req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`)) + req, _ = sjson.SetRawBytes(req, "tools.-1", tool) + return req, nil +} + +func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) { + if gjson.GetBytes(payload, "type").String() != "response.completed" { + return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type") + } + + createdAt := gjson.GetBytes(payload, "response.created_at").Int() + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + var ( + results []openAIResponsesImageResult + firstMeta openAIResponsesImageResult + ) + output := gjson.GetBytes(payload, "response.output") + if output.IsArray() { + for _, item := range output.Array() { + if item.Get("type").String() != "image_generation_call" { + continue + } + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + continue + } + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + if len(results) == 0 { + firstMeta = 
entry + } + results = append(results, entry) + } + } + + var usageRaw []byte + if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() { + usageRaw = []byte(usage.Raw) + } + return results, createdAt, usageRaw, firstMeta, nil +} + +func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) { + if gjson.GetBytes(payload, "type").String() != "response.output_item.done" { + return openAIResponsesImageResult{}, "", false, fmt.Errorf("unexpected event type") + } + + item := gjson.GetBytes(payload, "item") + if !item.Exists() || item.Get("type").String() != "image_generation_call" { + return openAIResponsesImageResult{}, "", false, nil + } + + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + return openAIResponsesImageResult{}, "", false, nil + } + + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + return entry, strings.TrimSpace(item.Get("id").String()), true, nil +} + +func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) { + var ( + fallbackResults []openAIResponsesImageResult + fallbackSeen = make(map[string]struct{}) + createdAt int64 + usageRaw []byte + foundFinal bool + responseMeta openAIResponsesImageResult + ) + + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + payload := []byte(data) + if !gjson.ValidBytes(payload) { + continue + } + if meta, 
eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { + mergeOpenAIResponsesImageMeta(&responseMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + + switch gjson.GetBytes(payload, "type").String() { + case "response.output_item.done": + result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + if ok { + mergeOpenAIResponsesImageMeta(&result, responseMeta) + appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result) + } + case "response.completed": + results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + foundFinal = true + if completedAt > 0 { + createdAt = completedAt + } + if len(completedUsageRaw) > 0 { + usageRaw = completedUsageRaw + } + if len(results) > 0 { + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return results, createdAt, usageRaw, firstMeta, true, nil + } + if len(fallbackResults) > 0 { + firstMeta = fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, true, nil + } + } + } + + if len(fallbackResults) > 0 { + firstMeta := fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil + } + return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil +} + +func buildOpenAIImagesAPIResponse( + results []openAIResponsesImageResult, + createdAt int64, + usageRaw []byte, + firstMeta openAIResponsesImageResult, + responseFormat string, +) ([]byte, error) { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + out := []byte(`{"created":0,"data":[]}`) + out, _ = sjson.SetBytes(out, "created", createdAt) + + format := 
strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + for _, img := range results { + item := []byte(`{}`) + if format == "url" { + item, _ = sjson.SetBytes(item, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } else { + item, _ = sjson.SetBytes(item, "b64_json", img.Result) + } + if img.RevisedPrompt != "" { + item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt) + } + out, _ = sjson.SetRawBytes(out, "data.-1", item) + } + if firstMeta.Background != "" { + out, _ = sjson.SetBytes(out, "background", firstMeta.Background) + } + if firstMeta.OutputFormat != "" { + out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat) + } + if firstMeta.Quality != "" { + out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality) + } + if firstMeta.Size != "" { + out, _ = sjson.SetBytes(out, "size", firstMeta.Size) + } + if firstMeta.Model != "" { + out, _ = sjson.SetBytes(out, "model", firstMeta.Model) + } + if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + out, _ = sjson.SetRawBytes(out, "usage", usageRaw) + } + return out, nil +} + +func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string { + if parsed != nil && parsed.IsEdits() { + return "image_edit" + } + return "image_generation" +} + +func buildOpenAIImagesStreamErrorBody(message string) []byte { + body := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`) + if strings.TrimSpace(message) == "" { + message = "upstream request failed" + } + body, _ = sjson.SetBytes(body, "error.message", message) + return body +} + +func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error { + if strings.TrimSpace(eventName) != "" { + if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { + return err + } + } + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { + return err + } + 
flusher.Flush() + return nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( + resp *http.Response, + c *gin.Context, + responseFormat string, + fallbackModel string, +) (OpenAIUsage, int, error) { + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return OpenAIUsage{}, 0, err + } + + var usage OpenAIUsage + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + } + results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return OpenAIUsage{}, 0, err + } + if len(results) == 0 { + return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output") + } + if strings.TrimSpace(firstMeta.Model) == "" { + firstMeta.Model = strings.TrimSpace(fallbackModel) + } + + responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) + if err != nil { + return OpenAIUsage{}, 0, err + } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody) + return usage, len(results), nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( + resp *http.Response, + c *gin.Context, + startTime time.Time, + responseFormat string, + streamPrefix string, + fallbackModel string, +) (OpenAIUsage, int, *int, error) { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Status(resp.StatusCode) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return OpenAIUsage{}, 0, nil, 
fmt.Errorf("streaming is not supported by response writer") + } + + format := strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + + reader := bufio.NewReader(resp.Body) + usage := OpenAIUsage{} + imageCount := 0 + var firstTokenMs *int + emitted := make(map[string]struct{}) + pendingResults := make([]openAIResponsesImageResult, 0, 1) + pendingSeen := make(map[string]struct{}) + streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} + var createdAt int64 + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + trimmedLine := strings.TrimRight(string(line), "\r\n") + data, ok := extractOpenAISSEDataLine(trimmedLine) + if ok && data != "" && data != "[DONE]" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + if gjson.ValidBytes(dataBytes) { + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + switch gjson.GetBytes(dataBytes, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) + if b64 != "" { + eventName := streamPrefix + ".partial_image" + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return 
OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + } + case "response.output_item.done": + img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + if !ok { + break + } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey(itemID, img) + if _, exists := emitted[key]; exists { + break + } + if _, exists := pendingSeen[key]; exists { + break + } + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) + case "response.completed": + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { + err = fmt.Errorf("upstream did not return image output") + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, err + } + eventName := streamPrefix + ".completed" + for _, img := range finalResults { + key := openAIResponsesImageResultKey("", img) + if _, exists := 
// forwardOpenAIImagesOAuth serves an images request for an OAuth account by
// sending the body built by buildOpenAIImagesResponsesRequest upstream and
// translating the reply back through the streaming or non-streaming handler.
//
// The effective model is channelMappedModel when non-blank, otherwise
// parsed.Model, otherwise "gpt-image-2"; it must pass
// validateOpenAIImagesModel. The upstream request always asks for
// text/event-stream, regardless of parsed.Stream.
//
// Returns the forward result on success. On failure it returns an error; an
// *UpstreamFailoverError is returned when the upstream status (>= 400) is
// classified as a failover case, and other >= 400 responses are delegated to
// handleErrorResponse.
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	parsed *OpenAIImagesRequest,
	channelMappedModel string,
) (*OpenAIForwardResult, error) {
	startTime := time.Now()
	requestModel := strings.TrimSpace(parsed.Model)
	if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
		requestModel = mapped
	}
	if requestModel == "" {
		requestModel = "gpt-image-2"
	}
	if err := validateOpenAIImagesModel(requestModel); err != nil {
		return nil, err
	}
	logger.LegacyPrintf(
		"service.openai_gateway",
		"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
		requestModel,
		parsed.Endpoint,
		account.Type,
		len(parsed.Uploads),
	)
	// n > 1 is only logged here; nothing in this function rejects the request
	// or repeats the upstream call.
	if parsed.N > 1 {
		logger.LegacyPrintf(
			"service.openai_gateway",
			"[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
			parsed.N,
			requestModel,
			parsed.Endpoint,
		)
	}

	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel)
	if err != nil {
		return nil, err
	}
	// Record the translated body for ops diagnostics before it is sent.
	setOpsUpstreamRequestBody(c, responsesBody)

	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
	if err != nil {
		return nil, err
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Accept", "text/event-stream")

	// Use the account's proxy only when both the ID and the loaded proxy are set.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	upstreamStart := time.Now()
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	// Latency is recorded whether or not the call succeeded.
	SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
	if err != nil {
		// Transport-level failure: only the sanitized message is recorded and
		// returned; the original error is not wrapped.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
			Kind:               "request_error",
			Message:            safeErr,
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	if resp.StatusCode >= 400 {
		// Read at most 2 MiB of the error body, close the network body, and
		// swap in an in-memory copy so the handlers below can still read it.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
		upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
		if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
				Kind:               "failover",
				Message:            upstreamMsg,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
	}
	// Success path only: the error path above has already closed the body.
	defer func() { _ = resp.Body.Close() }()

	var (
		usage        OpenAIUsage
		imageCount   = parsed.N
		firstTokenMs *int
	)
	if parsed.Stream {
		usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
		if err != nil {
			return nil, err
		}
	} else {
		usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
		if err != nil {
			return nil, err
		}
	}
	// Fall back to the requested count when the handler reported no images.
	if imageCount <= 0 {
		imageCount = parsed.N
	}
	return &OpenAIForwardResult{
		RequestID:       resp.Header.Get("x-request-id"),
		Usage:           usage,
		Model:           requestModel,
		UpstreamModel:   requestModel,
		Stream:          parsed.Stream,
		ResponseHeaders: resp.Header.Clone(),
		Duration:        time.Since(startTime),
		FirstTokenMs:    firstTokenMs,
		ImageCount:      imageCount,
		ImageSize:       parsed.SizeTier,
	}, nil
}
index 6aa1d5e5..200547d4 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -3,13 +3,17 @@ package service import ( "bytes" "context" + "io" "mime/multipart" "net/http" "net/http/httptest" + "net/textproto" + "strings" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { @@ -70,6 +74,58 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace foreground")) + require.NoError(t, writer.WriteField("output_format", "png")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_compression", "80")) + require.NoError(t, writer.WriteField("partial_images", "2")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("source-image-bytes")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("mask-image-bytes")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := 
httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Len(t, parsed.Uploads, 1) + require.NotNil(t, parsed.MaskUpload) + require.True(t, parsed.HasMask) + require.Equal(t, "png", parsed.OutputFormat) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 80, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"prompt":"draw a cat"}`) @@ -121,6 +177,40 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *te require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + "model":"gpt-image-2", + "prompt":"replace the background", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "input_fidelity":"high", + "output_compression":90, + "partial_images":2, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + 
require.NotNil(t, parsed) + require.Equal(t, []string{"https://example.com/source.png"}, parsed.InputImageURLs) + require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 90, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.True(t, parsed.HasMask) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) { items := collectOpenAIImagePointers([]byte(`{ "revised_prompt": "cat astronaut", @@ -157,3 +247,472 @@ func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("ABC"), data) } + +func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic)) + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) +} + +type openAIImageTestSSEEvent struct { + Name string + Data string +} + +func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent { + chunks := strings.Split(body, "\n\n") + events := make([]openAIImageTestSSEEvent, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + var event openAIImageTestSSEEvent + for _, line := range strings.Split(chunk, "\n") { + switch { + case strings.HasPrefix(line, "event: "): + event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + case strings.HasPrefix(line, "data: "): + event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + } + } + if event.Name != "" || event.Data != "" { + events = append(events, event) + } + } + return events +} + +func 
findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) { + for _, event := range events { + if event.Name == name { + return event, true + } + } + return openAIImageTestSSEEvent{}, false +} + +func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 42}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + "chatgpt_account_id": "acct-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + 
require.Equal(t, "gpt-image-2", result.Model) + require.Equal(t, "gpt-image-2", result.UpstreamModel) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 22, result.Usage.OutputTokens) + require.Equal(t, 7, result.Usage.ImageOutputTokens) + + require.NotNil(t, upstream.lastReq) + require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String()) + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) + require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta")) + + require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String()) + require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String()) + require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists()) + require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func 
TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 2, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": 
"token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image") + require.True(t, ok) + require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(completed.Data, "background").String()) + 
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace background with aurora")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_format", "webp")) + require.NoError(t, writer.WriteField("quality", "high")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("png-image-content")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("png-mask-content")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 100}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": 
[]string{"req_img_edit_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 3, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists()) + require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String()) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,")) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,")) + require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t 
*testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + "model":"gpt-image-2", + "prompt":"replace background with aurora", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "stream":true, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 4, + Name: 
"openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String()) + require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String()) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image") + require.True(t, ok) + require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed") + require.True(t, ok) + require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,ZWRpdGVk", 
gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) +} + +func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: "gpt-image-2", + Prompt: "draw a cat", + N: 2, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) + require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) +} + +func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Model: "gpt-image-2", + Prompt: "replace background", + InputFidelity: "high", + InputImageURLs: []string{ + "https://example.com/source.png", + }, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists()) + require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String()) +} + +func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) { + body := []byte( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" + + "data: 
{\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + ) + + results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body) + require.NoError(t, err) + require.True(t, foundFinal) + require.Equal(t, int64(1710000004), createdAt) + require.Len(t, results, 1) + require.Equal(t, "aGVsbG8=", results[0].Result) + require.Equal(t, "draw a cat", results[0].RevisedPrompt) + require.Equal(t, "png", firstMeta.OutputFormat) + require.JSONEq(t, `{"images":1}`, string(usageRaw)) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_output_item_done"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: 
{\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 5, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.NotContains(t, rec.Body.String(), "event: error") +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 53581574..4730303f 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -1,8 +1,10 @@ package service import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" "net/http" "strconv" @@ -23,6 +25,7 @@ type RateLimitService struct { geminiQuotaService *GeminiQuotaService tempUnschedCache TempUnschedCache timeoutCounterCache TimeoutCounterCache + openAI403CounterCache 
OpenAI403CounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator usageCacheMu sync.RWMutex @@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface { const geminiPrecheckCacheTTL = time.Minute +const ( + openAI403CooldownMinutesDefault = 10 + openAI403DisableThreshold = 3 + openAI403CounterWindowMinutes = 180 +) + // NewRateLimitService 创建RateLimitService实例 func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { return &RateLimitService{ @@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) { s.timeoutCounterCache = cache } +// SetOpenAI403CounterCache 设置 OpenAI 403 连续失败计数器(可选依赖) +func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) { + s.openAI403CounterCache = cache +} + // SetSettingService 设置系统设置服务(可选依赖) func (s *RateLimitService) SetSettingService(settingService *SettingService) { s.settingService = settingService @@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string { + prefix = strings.TrimSpace(prefix) + if prefix != "" && !strings.HasSuffix(prefix, " ") { + prefix += " " + } + + if msg := strings.TrimSpace(upstreamMsg); msg != "" { + return prefix + msg + } + + rawBody := bytes.TrimSpace(responseBody) + if len(rawBody) > 0 { + if json.Valid(rawBody) { + var compact bytes.Buffer + if err := json.Compact(&compact, rawBody); err == nil { + return prefix + truncateForLog(compact.Bytes(), 512) + } + } + return prefix + truncateForLog(rawBody, 512) + } + + return prefix + fallback +} + // handle403 处理 403 Forbidden 错误 // Antigravity 平台区分 validation/violation/generic 
三种类型,均 SetError 永久禁用; // 其他平台保持原有 SetError 行为。 @@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst if account.Platform == PlatformAntigravity { return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) } - // 非 Antigravity 平台:保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg + if account.Platform == PlatformOpenAI { + return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody) } + // 非 Antigravity 平台:保持原有行为 + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } +func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) + + if s.openAI403CounterCache == nil { + s.handleAuthError(ctx, account, msg) + return true + } + + count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes) + if err != nil { + slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + if count >= openAI403DisableThreshold { + msg = fmt.Sprintf("%s | consecutive_403=%d/%d", msg, count, openAI403DisableThreshold) + s.handleAuthError(ctx, account, msg) + return true + } + + until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute) + reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Warn("openai_403_set_temp_unschedulable_failed", 
"account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + slog.Warn( + "openai_403_temp_unschedulable", + "account_id", account.ID, + "until", until, + "count", count, + "threshold", openAI403DisableThreshold, + ) + return true +} + // handleAntigravity403 处理 Antigravity 平台的 403 错误 // validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) // violation(违规封号)→ 永久 SetError(需人工处理) @@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac switch fbType { case forbiddenTypeValidation: // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 - msg := "Validation required (403): account needs Google verification" - if upstreamMsg != "" { - msg = "Validation required (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Validation required (403):", + upstreamMsg, + responseBody, + "account needs Google verification", + ) if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { msg += " | validation_url: " + validationURL } @@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac case forbiddenTypeViolation: // 违规封号: 永久禁用,需人工处理 - msg := "Account violation (403): terms of service violation" - if upstreamMsg != "" { - msg = "Account violation (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Account violation (403):", + upstreamMsg, + responseBody, + "terms of service violation", + ) s.handleAuthError(ctx, account, msg) return true default: // 通用 403: 保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } @@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, 
accountID int64) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } + s.ResetOpenAI403Counter(ctx, accountID) return nil } +func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) { + if s == nil || s.openAI403CounterCache == nil || accountID <= 0 { + return + } + if err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID); err != nil { + slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err) + } +} + // RecoverAccountState 按需恢复账号的可恢复运行时状态。 func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { account, err := s.accountRepo.GetByID(ctx, accountID) @@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in } result.ClearedRateLimit = true } + if result.ClearedError || result.ClearedRateLimit { + s.ResetOpenAI403Counter(ctx, accountID) + } return result, nil } diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 9e5e2b0e..73b7849f 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct { updateCredentialsCalls int lastCredentials map[string]any lastErrorMsg string + lastTempReason string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ + r.lastTempReason = reason return nil } @@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct { err error } +type openAI403CounterCacheStub struct { + counts []int64 + resetCalls []int64 + err 
error +} + +func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) { + if s.err != nil { + return 0, s.err + } + if len(s.counts) == 0 { + return 1, nil + } + count := s.counts[0] + s.counts = s.counts[1:] + return count, nil +} + +func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { r.accounts = append(r.accounts, account) return r.err diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go new file mode 100644 index 00000000..2fd11b71 --- /dev/null +++ b/backend/internal/service/ratelimit_service_403_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{1}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"temporary edge rejection"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Contains(t, repo.lastTempReason, "temporary edge rejection") + require.Contains(t, repo.lastTempReason, "(1/3)") +} + +func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) { + repo := 
&rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{3}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 302, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy") + require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3") +} diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 89c754c8..619bb773 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -7,6 +7,9 @@ import ( "net/http" "testing" "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" ) func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) { @@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) { } } +func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 201, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by 
policy") + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + +func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, `"access_denied"`) + require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`) + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) { // Test when only secondary has data, no window_minutes sUsed := 60.0 diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 9f33c46a..e6c4c074 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -210,11 +210,13 @@ func ProvideRateLimitService( geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache, timeoutCounterCache TimeoutCounterCache, + openAI403CounterCache OpenAI403CounterCache, settingService *SettingService, tokenCacheInvalidator TokenCacheInvalidator, ) *RateLimitService { svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc.SetTimeoutCounterCache(timeoutCounterCache) + svc.SetOpenAI403CounterCache(openAI403CounterCache) svc.SetSettingService(settingService) svc.SetTokenCacheInvalidator(tokenCacheInvalidator) return svc