diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index f023e32b..60ffefb3 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -20,6 +20,8 @@ var DefaultModels = []Model{ {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, + {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"}, + {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"}, {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"}, } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index a5559b7d..396a3973 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/rand" + "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -52,8 +53,14 @@ type TestEvent struct { const ( defaultGeminiTextTestPrompt = "hi" defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." + defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." ) +// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2). 
+func isOpenAIImageModel(model string) bool {
+	return strings.HasPrefix(strings.ToLower(model), "gpt-image-")
+}
+
 // AccountTestService handles account testing operations
 type AccountTestService struct {
 	accountRepo AccountRepository
@@ -170,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
 
 	// Route to platform-specific test method
 	if account.IsOpenAI() {
-		return s.testOpenAIAccountConnection(c, account, modelID)
+		return s.testOpenAIAccountConnection(c, account, modelID, prompt)
 	}
 
 	if account.IsGemini() {
@@ -410,7 +417,7 @@
 }
 
 // testOpenAIAccountConnection tests an OpenAI account's connection
-func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
+func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
 	ctx := c.Request.Context()
 
 	// Default to openai.DefaultTestModel for OpenAI testing
@@ -429,6 +436,18 @@
 		}
 	}
 
+	// Route to image generation test if an image model is selected
+	if isOpenAIImageModel(testModelID) {
+		imagePrompt := strings.TrimSpace(prompt)
+		if imagePrompt == "" {
+			imagePrompt = defaultOpenAIImageTestPrompt
+		}
+		if account.Type == "apikey" {
+			return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt)
+		}
+		return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt)
+	}
+
 	// Determine authentication method and API URL
 	var authToken string
 	var apiURL string
@@ -1025,7 +1044,336 @@
 	}
 }
 
-// sendEvent sends a SSE event to the client
+// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account.
+func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+	authToken := account.GetOpenAIApiKey()
+	if authToken == "" {
+		return s.sendErrorAndEnd(c, "No API key available")
+	}
+
+	baseURL := account.GetOpenAIBaseURL()
+	if baseURL == "" {
+		baseURL = "https://api.openai.com"
+	}
+	normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+	}
+	apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+
+	// Set SSE headers
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Writer.Flush()
+
+	s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+
+	// NOTE: this path is only reached for gpt-image-* models (see isOpenAIImageModel),
+	// which reject the "response_format" parameter and always return base64-encoded
+	// images, so it is intentionally omitted from the payload.
+	payload := map[string]any{
+		"model":  modelID,
+		"prompt": prompt,
+		"n":      1,
+	}
+	payloadBytes, _ := json.Marshal(payload)
+
+	req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+	if err != nil {
+		return s.sendErrorAndEnd(c, "Failed to create request")
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+authToken)
+
+	proxyURL := ""
+	if account.ProxyID != nil && account.Proxy != nil {
+		proxyURL = account.Proxy.URL()
+	}
+
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error()))
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+	}
+
+	// Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]}
+	var result struct {
+		Data []struct {
+			B64JSON       string `json:"b64_json"`
+			RevisedPrompt string `json:"revised_prompt"`
+		} `json:"data"`
+	}
+	if err := json.Unmarshal(body, &result); err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
+	}
+
+	if len(result.Data) == 0 {
+		return s.sendErrorAndEnd(c, "No images returned from API")
+	}
+
+	for _, item := range result.Data {
+		if item.RevisedPrompt != "" {
+			s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+		}
+		if item.B64JSON != "" {
+			s.sendEvent(c, TestEvent{
+				Type:     "image",
+				ImageURL: "data:image/png;base64," + item.B64JSON,
+				MimeType: "image/png",
+			})
+		}
+	}
+
+	s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+	return nil
+}
+
+// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API.
+func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+	authToken := account.GetOpenAIAccessToken()
+	if authToken == "" {
+		return s.sendErrorAndEnd(c, "No access token available")
+	}
+
+	// Set SSE headers
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Writer.Flush()
+
+	s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+	s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"})
+
+	// Build headers (replicating buildOpenAIBackendAPIHeaders logic)
+	headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo)
+
+	proxyURL := ""
+	if account.ProxyID != nil && account.Proxy != nil {
+		proxyURL = account.Proxy.URL()
+	}
+
+	client, err := newOpenAIBackendAPIClient(proxyURL)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error()))
+	}
+
+	// Bootstrap
+	if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil {
+		log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr)
+	}
+
+	// Fetch chat requirements
+	s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"})
+	chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error()))
+	}
+	if chatReqs.Arkose.Required {
+		return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required")
+	}
+
+	// Initialize and prepare conversation
+	s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"})
+	parentMessageID := uuid.NewString()
+	proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent"))
+	_ = initializeOpenAIImageConversation(ctx, client, headers)
+	conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error()))
+	}
+
+	// Build simplified conversation request (no file uploads)
+	convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID)
+	convHeaders := cloneHTTPHeader(headers)
+	convHeaders.Set("Accept", "text/event-stream")
+	convHeaders.Set("Content-Type", "application/json")
+	convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token)
+	if conduitToken != "" {
+		convHeaders.Set("x-conduit-token", conduitToken)
+	}
+	if proofToken != "" {
+		convHeaders.Set("openai-sentinel-proof-token", proofToken)
+	}
+
+	s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"})
+
+	resp, err := client.R().
+		SetContext(ctx).
+		DisableAutoReadResponse().
+		SetHeaders(headerToMap(convHeaders)).
+		SetBodyJsonMarshal(convReq).
+		Post(openAIChatGPTConversationURL)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error()))
+	}
+	defer func() {
+		if resp != nil && resp.Body != nil {
+			_ = resp.Body.Close()
+		}
+	}()
+	if resp.StatusCode >= 400 {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode))
+	}
+
+	startTime := time.Now()
+	conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime)
+	if err != nil {
+		return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error()))
+	}
+
+	pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
+	if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
+		s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"})
+		polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
+		if pollErr != nil {
+			return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error()))
+		}
+		pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers)
+	}
+	pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
+	if len(pointerInfos) == 0 {
+		return s.sendErrorAndEnd(c, "No images returned from conversation")
+	}
+
+	s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"})
+
+	// Download and encode each image
+	for _, pointer := range pointerInfos {
+		downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
+		if err != nil {
+			return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error()))
+		}
+		data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
+		if err != nil {
+			return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error()))
+		}
+		b64 := base64.StdEncoding.EncodeToString(data)
+		mimeType := http.DetectContentType(data)
+		if pointer.Prompt != "" {
+			s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt})
+		}
+		s.sendEvent(c, TestEvent{
+			Type:     "image",
+			ImageURL: "data:" + mimeType + ";base64," + b64,
+			MimeType: mimeType,
+		})
+	}
+
+	s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+	return nil
+}
+
+// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes.
+// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without
+// requiring the full gateway service dependency.
+func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header {
+	// Ensure device and session IDs exist
+	deviceID := account.GetOpenAIDeviceID()
+	sessionID := account.GetOpenAISessionID()
+	if deviceID == "" || sessionID == "" {
+		updates := map[string]any{}
+		if deviceID == "" {
+			deviceID = uuid.NewString()
+			updates["openai_device_id"] = deviceID
+		}
+		if sessionID == "" {
+			sessionID = uuid.NewString()
+			updates["openai_session_id"] = sessionID
+		}
+		if account.Extra == nil {
+			account.Extra = map[string]any{}
+		}
+		for key, value := range updates {
+			account.Extra[key] = value
+		}
+		if repo != nil {
+			updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+			defer cancel()
+			_ = repo.UpdateExtra(updateCtx, account.ID, updates)
+		}
+	}
+
+	headers := make(http.Header)
+	headers.Set("Authorization", "Bearer "+token)
+	headers.Set("Accept", "application/json")
+	headers.Set("Origin", "https://chatgpt.com")
+	headers.Set("Referer", "https://chatgpt.com/")
+	headers.Set("Sec-Fetch-Dest", "empty")
+	headers.Set("Sec-Fetch-Mode", "cors")
+	headers.Set("Sec-Fetch-Site", "same-origin")
+	headers.Set("User-Agent", openAIImageBackendUserAgent)
+	if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
+		headers.Set("User-Agent", customUA)
+	}
+	if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
+		headers.Set("chatgpt-account-id", chatgptAccountID)
+	}
+	if deviceID != "" {
+		headers.Set("oai-device-id", deviceID)
+		headers.Set("Cookie", "oai-did="+deviceID)
+	}
+	if sessionID != "" {
+		headers.Set("oai-session-id", sessionID)
+	}
+	return headers
+}
+
+// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request.
+func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any {
+	promptText := strings.TrimSpace(prompt)
+	if promptText == "" {
+		promptText = "Generate an image."
+	}
+	metadata := map[string]any{
+		"developer_mode_connector_ids": []any{},
+		"selected_github_repos":        []any{},
+		"selected_all_github_repos":    false,
+		"system_hints":                 []string{"picture_v2"},
+		"serialization_metadata": map[string]any{
+			"custom_symbol_offsets": []any{},
+		},
+	}
+	message := map[string]any{
+		"id":     uuid.NewString(),
+		"author": map[string]any{"role": "user"},
+		"content": map[string]any{
+			"content_type": "text",
+			"parts":        []any{promptText},
+		},
+		"metadata":    metadata,
+		"create_time": float64(time.Now().UnixMilli()) / 1000,
+	}
+	return map[string]any{
+		"action":                   "next",
+		"client_prepare_state":     "sent",
+		"parent_message_id":        parentMessageID,
+		"messages":                 []any{message},
+		"model":                    "auto",
+		"timezone_offset_min":      openAITimezoneOffsetMinutes(),
+		"timezone":                 openAITimezoneName(),
+		"conversation_mode":        map[string]any{"kind": "primary_assistant"},
+		"system_hints":             []string{"picture_v2"},
+		"supports_buffering":       true,
+		"supported_encodings":      []string{"v1"},
+		"client_contextual_info":   map[string]any{"app_name": "chatgpt.com"},
+		"force_nulligen":           false,
+		"force_paragen":            false,
+		"force_paragen_model_slug": "",
+		"force_rate_limit":         false,
+		"websocket_request_id":     uuid.NewString(),
+	}
+}
+
 func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
 	eventJSON, _ := json.Marshal(event)
 	if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
diff --git 
a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 82606979..82ff0a8b 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -103,7 +103,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.NoError(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"]) @@ -134,7 +134,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing Credentials: map[string]any{"access_token": "test-token"}, } - err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4") + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "") require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index f3533ec4..808f1229 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -917,7 +917,15 @@ func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages( excludedIDs map[int64]struct{}, requiredCapability OpenAIImagesCapability, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { - return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability) + if err == nil && selection 
!= nil && selection.Account != nil { + return selection, decision, nil + } + // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号) + if requiredCapability == OpenAIImagesCapabilityNative { + return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic) + } + return selection, decision, err } func (s *OpenAIGatewayService) selectAccountWithScheduler( diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 396c0381..48bce22b 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -1157,9 +1157,9 @@ func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers htt if err != nil { return nil, err } - if putResp.Response != nil && putResp.Response.Body != nil { - _, _ = io.Copy(io.Discard, putResp.Response.Body) - _ = putResp.Response.Body.Close() + if putResp.Response != nil && putResp.Body != nil { + _, _ = io.Copy(io.Discard, putResp.Body) + _ = putResp.Body.Close() } if putResp.StatusCode < 200 || putResp.StatusCode >= 300 { return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed") @@ -1294,10 +1294,10 @@ type openAIImageToolMessage struct { } func readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) { - if resp == nil || resp.Response == nil || resp.Response.Body == nil { + if resp == nil || resp.Body == nil { return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response") } - reader := bufio.NewReader(resp.Response.Body) + reader := bufio.NewReader(resp.Body) var ( conversationID string firstTokenMs *int @@ -1529,8 +1529,8 @@ func pollOpenAIImageConversation(ctx context.Context, client *req.Client, header lastErr = err } else { if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, readErr := io.ReadAll(resp.Response.Body) - _ = 
resp.Response.Body.Close() + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() if readErr != nil { lastErr = readErr goto waitNextPoll diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 67409a7c..2e3db61b 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -55,12 +55,12 @@ /> -