diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 3a5ddcb0..791100f6 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -12,6 +12,7 @@ import ( "os" "path" "path/filepath" + "regexp" "strconv" "strings" "time" @@ -28,6 +29,8 @@ import ( "go.uber.org/zap" ) +var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`) + // SoraGatewayHandler handles Sora chat completions requests type SoraGatewayHandler struct { gatewayService *service.GatewayService @@ -214,6 +217,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 var lastFailoverBody []byte + var lastFailoverHeaders http.Header for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") @@ -226,7 +230,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) return } account := selection.Account @@ -289,11 +293,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = failoverErr.ResponseHeaders lastFailoverBody = failoverErr.ResponseBody - h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverBody, streamStarted) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) return } lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = failoverErr.ResponseHeaders lastFailoverBody = failoverErr.ResponseBody switchCount++ upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) @@ -367,14 +373,19 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseBody []byte, streamStarted bool) { - status, errType, errMsg := h.mapUpstreamError(statusCode, responseBody) +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byte) (int, string, string) { +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { + if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { + baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) + return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) - if upstreamMessage != "" { + if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { switch statusCode { case 401, 403, 404, 500, 502, 503, 504: return http.StatusBadGateway, "upstream_error", upstreamMessage @@ -404,6 +415,71 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseBody []byt } } +func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests { + return false + } + if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") { + return true + } + preview := strings.ToLower(truncateSoraErrorBody(body, 4096)) + if strings.Contains(preview, "window._cf_chl_opt") || + strings.Contains(preview, "just a moment") || + strings.Contains(preview, "enable javascript and cookies to continue") || + strings.Contains(preview, "__cf_chl_") || + strings.Contains(preview, "challenge-platform") { + return true + } + contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type"))) + if strings.Contains(contentType, "text/html") && + (strings.Contains(preview, "= 2 { + return strings.TrimSpace(matches[1]) + } + return "" +} + func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) { trimmed := strings.TrimSpace(string(body)) if trimmed == "" { @@ -439,6 +515,17 @@ func truncateSoraErrorMessage(s string, maxLen int) string { return s[:maxLen] + "...(truncated)" } +func truncateSoraErrorBody(body []byte, maxLen int) string { + if maxLen <= 0 { + maxLen = 512 + } + raw := strings.TrimSpace(string(body)) + if len(raw) <= maxLen { + return raw + } + return raw[:maxLen] + "...(truncated)" +} + func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { if streamStarted { flusher, ok := c.Writer.(http.Flusher) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index d80b959c..edf3ca5e 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -561,7 +561,7 @@ func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) { h := &SoraGatewayHandler{} resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`) - h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true) + h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true) body := w.Body.String() require.True(t, strings.HasPrefix(body, "event: error\n")) @@ -579,3 +579,31 @@ func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) { require.Equal(t, "upstream_error", errorObj["type"]) require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"]) } + +func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d01b0e9ecc35829-SEA") + body := []byte(`Just a moment...`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare challenge") + require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 95fad9a6..c3f2359a 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "net/url" "regexp" "strings" @@ -522,6 +523,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * if resp.StatusCode != http.StatusOK { if isCloudflareChallengeResponse(resp.StatusCode, body) { + s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body)) } return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) @@ -567,6 +569,7 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * } } else { if isCloudflareChallengeResponse(subResp.StatusCode, subBody) { + s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)}) } else { s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) @@ -824,6 +827,75 @@ func extractCloudflareRayID(headers http.Header, body []byte) string { return "" } +func extractSoraEgressIPHint(headers http.Header) string { + if headers == nil { + return "unknown" + } + candidates := []string{ + "x-openai-public-ip", + "x-envoy-external-address", + "cf-connecting-ip", + "x-forwarded-for", + } + for _, key := range candidates { + if value := strings.TrimSpace(headers.Get(key)); value != "" { + return value + } + } + return "unknown" +} + +func sanitizeProxyURLForLog(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + if u.User != nil { + u.User = nil + } + return u.String() +} + +func endpointPathForLog(endpoint string) string { + parsed, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || parsed.Path == "" { + return endpoint + } + return parsed.Path +} + +func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { + accountID := int64(0) + platform := "" + proxyID := "none" + if account != nil { + accountID = account.ID + platform = account.Platform + if account.ProxyID != nil { + proxyID = fmt.Sprintf("%d", *account.ProxyID) + } + } + cfRay := extractCloudflareRayID(headers, body) + if cfRay == "" { + cfRay = "unknown" + } + log.Printf( + "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", + accountID, + platform, + endpoint, + endpointPathForLog(endpoint), + proxyID, + sanitizeProxyURLForLog(proxyURL), + cfRay, + extractSoraEgressIPHint(headers), + ) +} + func truncateSoraErrorBody(body []byte, max int) string { if max <= 0 { max = 512 diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go index 0c09bf18..b5389ea2 100644 --- a/backend/internal/service/account_test_service_sora_test.go +++ b/backend/internal/service/account_test_service_sora_test.go @@ -202,3 +202,22 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChal require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") require.Contains(t, body, `"type":"test_complete","success":true`) } + +func TestSanitizeProxyURLForLog(t *testing.T) { + require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) + require.Equal(t, "", sanitizeProxyURLForLog("")) + require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) +} + +func TestExtractSoraEgressIPHint(t *testing.T) { + h := make(http.Header) + h.Set("x-openai-public-ip", "203.0.113.10") + require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) + + h2 := make(http.Header) + h2.Set("x-envoy-external-address", "198.51.100.9") + require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) + + require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) + require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index cf0f298d..abdb1120 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -376,8 +376,9 @@ type ForwardResult struct { type UpstreamFailoverError struct { StatusCode int ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true - RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 + ResponseHeaders http.Header + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 23d93386..7ca99ad2 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "hash/fnv" "io" "log" "math/rand" @@ -97,6 +98,7 @@ var soraDesktopUserAgents = []string{ var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) var soraRandMu sync.Mutex var soraPerfStart = time.Now() +var soraPowTokenGenerator = soraGetPowToken // SoraClient 定义直连 Sora 的任务操作接口。 type SoraClient interface { @@ -224,9 +226,11 @@ func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, if err != nil { return err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Accept", "application/json") - body, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) + body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) if err != nil { var upstreamErr *SoraUpstreamError if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { @@ -264,6 +268,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) if filename == "" { filename = "image.png" } @@ -290,10 +296,10 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da return "", err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", writer.FormDataContentType()) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) if err != nil { return "", err } @@ -309,6 +315,8 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) operation := "simple_compose" inpaintItems := []map[string]any{} if strings.TrimSpace(req.MediaID) != "" { @@ -329,7 +337,7 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account "n_frames": 1, "inpaint_items": inpaintItems, } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", "application/json") headers.Set("Origin", "https://sora.chatgpt.com") headers.Set("Referer", "https://sora.chatgpt.com/") @@ -338,13 +346,13 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account if err != nil { return "", err } - sentinel, err := c.generateSentinelToken(ctx, account, token) + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) if err != nil { return "", err } headers.Set("openai-sentinel-token", sentinel) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -360,6 +368,8 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) orientation := req.Orientation if orientation == "" { orientation = "landscape" @@ -399,7 +409,7 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account payload["cameo_replacements"] = map[string]any{} } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", "application/json") headers.Set("Origin", "https://sora.chatgpt.com") headers.Set("Referer", "https://sora.chatgpt.com/") @@ -407,13 +417,13 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account if err != nil { return "", err } - sentinel, err := c.generateSentinelToken(ctx, account, token) + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) if err != nil { return "", err } headers.Set("openai-sentinel-token", sentinel) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -429,6 +439,8 @@ func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) if strings.TrimSpace(expansionLevel) == "" { expansionLevel = "medium" } @@ -446,13 +458,13 @@ func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, return "", err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", "application/json") headers.Set("Accept", "application/json") headers.Set("Origin", "https://sora.chatgpt.com") headers.Set("Referer", "https://sora.chatgpt.com/") - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) if err != nil { return "", err } @@ -489,12 +501,14 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac if err != nil { return nil, false, err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) if limit <= 0 { limit = 20 } endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit) - respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false) if err != nil { return nil, false, err } @@ -551,9 +565,11 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t if err != nil { return nil, err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) - respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) if err != nil { return nil, err } @@ -582,7 +598,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t } } - respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) + respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) if err != nil { return nil, err } @@ -653,6 +669,25 @@ func (c *SoraDirectClient) defaultUserAgent() string { return ua } +func (c *SoraDirectClient) taskUserAgent() string { + if c != nil && c.cfg != nil { + if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" { + return ua + } + } + if len(soraDesktopUserAgents) > 0 { + return soraDesktopUserAgents[0] + } + return soraDefaultUserAgent +} + +func (c *SoraDirectClient) resolveProxyURL(account *Account) string { + if account == nil || account.ProxyID == nil || account.Proxy == nil { + return "" + } + return strings.TrimSpace(account.Proxy.URL()) +} + func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") @@ -925,9 +960,26 @@ func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header } func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { + return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry) +} + +func (c *SoraDirectClient) doRequestWithProxy( + ctx context.Context, + account *Account, + proxyURL string, + method, + urlStr string, + headers http.Header, + body io.Reader, + allowRetry bool, +) ([]byte, http.Header, error) { if strings.TrimSpace(urlStr) == "" { return nil, nil, errors.New("empty upstream url") } + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + proxyURL = c.resolveProxyURL(account) + } timeout := 0 if c != nil && c.cfg != nil { timeout = c.cfg.Sora.Client.TimeoutSeconds @@ -968,7 +1020,7 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth attempts, timeout, len(bodyBytes), - account != nil && account.ProxyID != nil && account.Proxy != nil, + proxyURL != "", formatSoraHeaders(headers), ) } @@ -984,10 +1036,6 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth req.Header = headers.Clone() start := time.Now() - proxyURL := "" - if account != nil && account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } resp, err := c.doHTTP(req, proxyURL, account) if err != nil { lastErr = err @@ -1183,10 +1231,13 @@ func soraBaseURLNotFoundHint(requestURL string) string { return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" } -func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { +func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) { reqID := uuid.NewString() - userAgent := soraRandChoice(soraDesktopUserAgents) - powToken := soraGetPowToken(userAgent) + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" { + userAgent = c.taskUserAgent() + } + powToken := soraPowTokenGenerator(userAgent) payload := map[string]any{ "p": powToken, "flow": soraSentinelFlow, @@ -1207,7 +1258,7 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A } urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -1223,16 +1274,6 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A return sentinel, nil } -func soraRandChoice(items []string) string { - if len(items) == 0 { - return "" - } - soraRandMu.Lock() - idx := soraRand.Intn(len(items)) - soraRandMu.Unlock() - return items[idx] -} - func soraGetPowToken(userAgent string) string { configList := soraBuildPowConfig(userAgent) seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) @@ -1248,13 +1289,16 @@ func soraRandFloat() float64 { } func soraBuildPowConfig(userAgent string) []any { - screen := soraRandChoice([]string{ - strconv.Itoa(1920 + 1080), - strconv.Itoa(2560 + 1440), - strconv.Itoa(1920 + 1200), - strconv.Itoa(2560 + 1600), - }) - screenVal, _ := strconv.Atoi(screen) + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" && len(soraDesktopUserAgents) > 0 { + userAgent = soraDesktopUserAgents[0] + } + screenVal := soraStableChoiceInt([]int{ + 1920 + 1080, + 2560 + 1440, + 1920 + 1200, + 2560 + 1600, + }, userAgent+"|screen") perfMs := float64(time.Since(soraPerfStart).Milliseconds()) wallMs := float64(time.Now().UnixNano()) / 1e6 diff := wallMs - perfMs @@ -1264,32 +1308,47 @@ func soraBuildPowConfig(userAgent string) []any { 4294705152, 0, userAgent, - soraRandChoice(soraPowScripts), - soraRandChoice(soraPowDPL), + soraStableChoice(soraPowScripts, userAgent+"|script"), + soraStableChoice(soraPowDPL, userAgent+"|dpl"), "en-US", "en-US,es-US,en,es", 0, - soraRandChoice(soraPowNavigatorKeys), - soraRandChoice(soraPowDocumentKeys), - soraRandChoice(soraPowWindowKeys), + soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"), + soraStableChoice(soraPowDocumentKeys, userAgent+"|document"), + soraStableChoice(soraPowWindowKeys, userAgent+"|window"), perfMs, uuid.NewString(), "", - soraRandChoiceInt(soraPowCores), + soraStableChoiceInt(soraPowCores, userAgent+"|cores"), diff, } } -func soraRandChoiceInt(items []int) int { +func soraStableChoice(items []string, seed string) string { + if len(items) == 0 { + return "" + } + idx := soraStableIndex(seed, len(items)) + return items[idx] +} + +func soraStableChoiceInt(items []int, seed string) int { if len(items) == 0 { return 0 } - soraRandMu.Lock() - idx := soraRand.Intn(len(items)) - soraRandMu.Unlock() + idx := soraStableIndex(seed, len(items)) return items[idx] } +func soraStableIndex(seed string, size int) int { + if size <= 0 { + return 0 + } + h := fnv.New32a() + _, _ = h.Write([]byte(seed)) + return int(h.Sum32() % uint32(size)) +} + func soraPowParseTime() string { loc := time.FixedZone("EST", -5*3600) return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index e566f06b..d50b2d85 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -5,6 +5,8 @@ package service import ( "context" "encoding/json" + "errors" + "io" "net/http" "net/http/httptest" "strings" @@ -365,3 +367,166 @@ func TestShouldAttemptSoraTokenRecover(t *testing.T) { require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) } + +type soraClientRequestCall struct { + Path string + UserAgent string + ProxyURL string +} + +type soraClientRecordingUpstream struct { + calls []soraClientRequestCall +} + +func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) { + u.calls = append(u.calls, soraClientRequestCall{ + Path: req.URL.Path, + UserAgent: req.Header.Get("User-Agent"), + ProxyURL: proxyURL, + }) + switch req.URL.Path { + case "/backend-api/sentinel/req": + return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil + case "/backend/nf/create": + return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil + case "/backend/uploads": + return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil + case "/backend/nf/check": + return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil + default: + return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil + } +} + +func newSoraClientMockResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) { + client := NewSoraDirectClient(&config.Config{}, nil, nil) + require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent()) +} + +func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { + soraPowTokenGenerator = originPowTokenGenerator + }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(9) + account := &Account{ + ID: 21, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"}) + require.NoError(t, err) + require.Equal(t, "task-123", taskID) + require.Len(t, upstream.calls, 2) + + sentinelCall := upstream.calls[0] + createCall := upstream.calls[1] + require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path) + require.Equal(t, "/backend/nf/create", createCall.Path) + require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL) + require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL) + require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent) + require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent) +} + +func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(3) + account := &Account{ + ID: 31, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png") + require.NoError(t, err) + require.Equal(t, "upload-123", uploadID) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/uploads", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(7) + account := &Account{ + ID: 41, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"}) + require.NoError(t, err) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/nf/check", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 8ae89f92..ef47f6d4 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -468,7 +468,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, } if stream { flusher, _ := c.Writer.(http.Flusher) - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) _, _ = fmt.Fprint(c.Writer, errorEvent) _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") if flusher != nil { @@ -494,7 +505,11 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode, ResponseBody: upstreamErr.Body} + return &UpstreamFailoverError{ + StatusCode: upstreamErr.StatusCode, + ResponseBody: upstreamErr.Body, + ResponseHeaders: upstreamErr.Headers, + } } msg := upstreamErr.Message if override := soraProErrorMessage(model, msg); override != "" { diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index f706d052..469a131e 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -4,10 +4,15 @@ package service import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -210,6 +215,33 @@ func TestSoraProErrorMessage(t *testing.T) { require.Empty(t, soraProErrorMessage("sora-basic", "")) } +func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true) + + body := rec.Body.String() + require.Contains(t, body, "event: error\n") + require.Contains(t, body, "data: [DONE]\n\n") + + lines := strings.Split(body, "\n") + require.GreaterOrEqual(t, len(lines), 2) + require.Equal(t, "event: error", lines[0]) + require.True(t, strings.HasPrefix(lines[1], "data: ")) + + data := strings.TrimPrefix(lines[1], "data: ") + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(data), &parsed)) + errObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errObj["type"]) + require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"]) +} + func TestShouldFailoverUpstreamError(t *testing.T) { svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) require.True(t, svc.shouldFailoverUpstreamError(401))