From 40498aac9dc4044d680e969ac5066ba3e3deb281 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 19 Feb 2026 20:04:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(sora):=20=E5=AF=B9=E9=BD=90sora2api?= =?UTF-8?q?=E5=88=86=E9=95=9C=E8=A7=92=E8=89=B2=E5=8E=BB=E6=B0=B4=E5=8D=B0?= =?UTF-8?q?=E4=B8=8E=E6=8C=91=E6=88=98=E9=94=99=E8=AF=AF=E6=B2=BB=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/VERSION | 2 +- backend/cmd/server/wire_gen.go | 2 +- .../internal/handler/sora_gateway_handler.go | 99 +-- .../handler/sora_gateway_handler_test.go | 67 ++ .../internal/service/account_test_service.go | 274 ++++++-- .../service/account_test_service_sora_test.go | 96 +++ backend/internal/service/sora_client.go | 519 ++++++++++++++- backend/internal/service/sora_client_test.go | 119 +++- .../internal/service/sora_gateway_service.go | 626 ++++++++++++++++-- .../service/sora_gateway_service_test.go | 175 ++++- backend/internal/util/soraerror/soraerror.go | 170 +++++ .../internal/util/soraerror/soraerror_test.go | 47 ++ 12 files changed, 1994 insertions(+), 202 deletions(-) create mode 100644 backend/internal/util/soraerror/soraerror.go create mode 100644 backend/internal/util/soraerror/soraerror_test.go diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index d778f67c..f788a87d 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.83.1 +0.1.83.2 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index be17fb01..a0f8807a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) - soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) + soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 791100f6..219922aa 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -12,7 +12,6 @@ import ( "os" "path" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -22,6 +21,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" @@ -29,8 +29,6 @@ import ( "go.uber.org/zap" ) -var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`) - // SoraGatewayHandler handles Sora chat completions requests type SoraGatewayHandler struct { gatewayService *service.GatewayService @@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht } upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if strings.EqualFold(upstreamCode, "cf_shield_429") { + baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." + return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { switch statusCode { case 401, 403, 404, 500, 502, 503, 504: @@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht } func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { - if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests { - return false - } - if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") { - return true - } - preview := strings.ToLower(truncateSoraErrorBody(body, 4096)) - if strings.Contains(preview, "window._cf_chl_opt") || - strings.Contains(preview, "just a moment") || - strings.Contains(preview, "enable javascript and cookies to continue") || - strings.Contains(preview, "__cf_chl_") || - strings.Contains(preview, "challenge-platform") { - return true - } - contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type"))) - if strings.Contains(contentType, "text/html") && - (strings.Contains(preview, "= 2 { - return strings.TrimSpace(matches[1]) - } - return "" + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) } func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) { - trimmed := strings.TrimSpace(string(body)) - if trimmed == "" { - return "", "" - } - if !gjson.Valid(trimmed) { - return "", truncateSoraErrorMessage(trimmed, 256) - } - code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String()) - if code == "" { - code = strings.TrimSpace(gjson.Get(trimmed, "code").String()) - } - message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String()) - if message == "" { - message = strings.TrimSpace(gjson.Get(trimmed, "message").String()) - } - if message == "" { - message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String()) - } - if message == "" { - message = strings.TrimSpace(gjson.Get(trimmed, "detail").String()) - } - return code, truncateSoraErrorMessage(message, 512) -} - -func truncateSoraErrorMessage(s string, maxLen int) string { - if maxLen <= 0 { - return "" - } - if len(s) <= maxLen { - return s - } - return s[:maxLen] + "...(truncated)" -} - -func truncateSoraErrorBody(body []byte, maxLen int) string { - if maxLen <= 0 { - maxLen = 512 - } - raw := strings.TrimSpace(string(body)) - if len(raw) <= maxLen { - return raw - } - return raw[:maxLen] + "...(truncated)" + return soraerror.ExtractUpstreamErrorCodeAndMessage(body) } func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index edf3ca5e..52ff0a96 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) { return "task-video", nil } +func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) { + return &service.SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error { + return nil +} +func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error { + return nil +} +func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) { + return "s_post", nil +} +func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error { + return nil +} +func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) { + return "https://example.com/no-watermark.mp4", nil +} func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) { return "enhanced prompt", nil } @@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T require.Contains(t, msg, "Cloudflare challenge") require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") } + +func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare shield") + require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index c3f2359a..a507efb4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -15,11 +15,14 @@ import ( "net/url" "regexp" "strings" + "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -28,7 +31,6 @@ import ( // sseDataPrefix matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var sseDataPrefix = regexp.MustCompile(`^data:\s*`) -var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" @@ -45,6 +47,9 @@ type TestEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Data any `json:"data,omitempty"` Success bool `json:"success,omitempty"` Error string `json:"error,omitempty"` } @@ -56,8 +61,13 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + soraTestGuardMu sync.Mutex + soraTestLastRun map[int64]time.Time + soraTestCooldown time.Duration } +const defaultSoraTestCooldown = 10 * time.Second + // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -72,6 +82,8 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, + soraTestLastRun: make(map[int64]time.Time), + soraTestCooldown: defaultSoraTestCooldown, } } @@ -473,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +type soraProbeStep struct { + Name string `json:"name"` + Status string `json:"status"` + HTTPStatus int `json:"http_status,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` +} + +type soraProbeSummary struct { + Status string `json:"status"` + Steps []soraProbeStep `json:"steps"` +} + +type soraProbeRecorder struct { + steps []soraProbeStep +} + +func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { + r.steps = append(r.steps, soraProbeStep{ + Name: name, + Status: status, + HTTPStatus: httpStatus, + ErrorCode: strings.TrimSpace(errorCode), + Message: strings.TrimSpace(message), + }) +} + +func (r *soraProbeRecorder) finalize() soraProbeSummary { + meSuccess := false + partial := false + for _, step := range r.steps { + if step.Name == "me" { + meSuccess = strings.EqualFold(step.Status, "success") + continue + } + if strings.EqualFold(step.Status, "failed") { + partial = true + } + } + + status := "success" + if !meSuccess { + status = "failed" + } else if partial { + status = "partial_success" + } + + return soraProbeSummary{ + Status: status, + Steps: append([]soraProbeStep(nil), r.steps...), + } +} + +func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { + if rec == nil { + return + } + summary := rec.finalize() + code := "" + for _, step := range summary.Steps { + if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { + code = step.ErrorCode + break + } + } + s.sendEvent(c, TestEvent{ + Type: "sora_test_result", + Status: summary.Status, + Code: code, + Data: summary, + }) +} + +func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { + if accountID <= 0 { + return 0, true + } + s.soraTestGuardMu.Lock() + defer s.soraTestGuardMu.Unlock() + + if s.soraTestLastRun == nil { + s.soraTestLastRun = make(map[int64]time.Time) + } + cooldown := s.soraTestCooldown + if cooldown <= 0 { + cooldown = defaultSoraTestCooldown + } + + now := time.Now() + if lastRun, ok := s.soraTestLastRun[accountID]; ok { + elapsed := now.Sub(lastRun) + if elapsed < cooldown { + return cooldown - elapsed, false + } + } + s.soraTestLastRun[accountID] = now + return 0, true +} + +func ceilSeconds(d time.Duration) int { + if d <= 0 { + return 1 + } + sec := int(d / time.Second) + if d%time.Second != 0 { + sec++ + } + if sec < 1 { + sec = 1 + } + return sec +} + // testSoraAccountConnection 测试 Sora 账号的连接 // 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { ctx := c.Request.Context() + recorder := &soraProbeRecorder{} authToken := account.GetCredential("access_token") if authToken == "" { + recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, "No access token available") } @@ -490,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Flush() + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, msg) + } + // Send test_start event s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) if err != nil { + recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, "Failed to create request") } @@ -515,6 +652,8 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if err != nil { + recorder.addStep("me", "failed", 0, "network_error", err.Error()) + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } defer func() { _ = resp.Body.Close() }() @@ -522,12 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * body, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - if isCloudflareChallengeResponse(resp.StatusCode, body) { + if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.emitSoraProbeSummary(c, recorder) s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) - return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body)) + return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) + switch { + case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): + recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") + case strings.EqualFold(upstreamCode, "unsupported_country_code"): + recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") + case strings.TrimSpace(upstreamMessage) != "": + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) + default: + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) } - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) } + recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") // 解析 /me 响应,提取用户信息 var meResp map[string]any @@ -557,21 +717,26 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if subErr != nil { + recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) } else { subBody, _ := io.ReadAll(subResp.Body) _ = subResp.Body.Close() if subResp.StatusCode == http.StatusOK { + recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") if summary := parseSoraSubscriptionSummary(subBody); summary != "" { s.sendEvent(c, TestEvent{Type: "content", Text: summary}) } else { s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) } } else { - if isCloudflareChallengeResponse(subResp.StatusCode, subBody) { + if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { + recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)}) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) } else { + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) + recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage) s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) } } @@ -579,8 +744,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * } // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 - s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint) + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder) + s.emitSoraProbeSummary(c, recorder) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil } @@ -592,6 +758,7 @@ func (s *AccountTestService) testSora2Capabilities( authToken string, proxyURL string, enableTLSFingerprint bool, + recorder *soraProbeRecorder, ) { inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( ctx, @@ -602,6 +769,9 @@ func (s *AccountTestService) testSora2Capabilities( enableTLSFingerprint, ) if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) return } @@ -616,6 +786,9 @@ func (s *AccountTestService) testSora2Capabilities( enableTLSFingerprint, ) if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + if recorder != nil { + recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") + } s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( ctx, @@ -626,20 +799,42 @@ func (s *AccountTestService) testSora2Capabilities( enableTLSFingerprint, ) if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) return } + } else if recorder != nil { + code := "" + msg := "" + if bootstrapErr != nil { + code = "network_error" + msg = bootstrapErr.Error() + } + recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) } } if inviteStatus != http.StatusOK { - if isCloudflareChallengeResponse(inviteStatus, inviteBody) { - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)}) + if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) return } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) + } s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) return } + if recorder != nil { + recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") + } if summary := parseSoraInviteSummary(inviteBody); summary != "" { s.sendEvent(c, TestEvent{Type: "content", Text: summary}) @@ -656,17 +851,31 @@ func (s *AccountTestService) testSora2Capabilities( enableTLSFingerprint, ) if remainingErr != nil { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) + } s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) return } if remainingStatus != http.StatusOK { - if isCloudflareChallengeResponse(remainingStatus, remainingBody) { - s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)}) + if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) return } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) + } s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) return } + if recorder != nil { + recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") + } if summary := parseSoraRemainingSummary(remainingBody); summary != "" { s.sendEvent(c, TestEvent{Type: "content", Text: summary}) } else { @@ -789,42 +998,16 @@ func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { return !s.cfg.Sora.Client.DisableTLSFingerprint } -func isCloudflareChallengeResponse(statusCode int, body []byte) bool { - if statusCode != http.StatusForbidden { - return false - } - preview := strings.ToLower(truncateSoraErrorBody(body, 4096)) - return strings.Contains(preview, "window._cf_chl_opt") || - strings.Contains(preview, "just a moment") || - strings.Contains(preview, "enable javascript and cookies to continue") +func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) } func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { - rayID := extractCloudflareRayID(headers, body) - if rayID == "" { - return base - } - return fmt.Sprintf("%s (cf-ray: %s)", base, rayID) + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) } func extractCloudflareRayID(headers http.Header, body []byte) string { - if headers != nil { - rayID := strings.TrimSpace(headers.Get("cf-ray")) - if rayID != "" { - return rayID - } - rayID = strings.TrimSpace(headers.Get("Cf-Ray")) - if rayID != "" { - return rayID - } - } - - preview := truncateSoraErrorBody(body, 8192) - matches := cloudflareRayPattern.FindStringSubmatch(preview) - if len(matches) >= 2 { - return strings.TrimSpace(matches[1]) - } - return "" + return soraerror.ExtractCloudflareRayID(headers, body) } func extractSoraEgressIPHint(headers http.Header) string { @@ -897,14 +1080,7 @@ func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyU } func truncateSoraErrorBody(body []byte, max int) string { - if max <= 0 { - max = 512 - } - raw := strings.TrimSpace(string(body)) - if len(raw) <= max { - return raw - } - return raw[:max] + "...(truncated)" + return soraerror.TruncateBody(body, max) } // testAntigravityAccountConnection tests an Antigravity account's connection diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go index b5389ea2..3dfac786 100644 --- a/backend/internal/service/account_test_service_sora_test.go +++ b/backend/internal/service/account_test_service_sora_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" @@ -109,6 +110,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"success"`) require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -141,6 +144,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc require.Contains(t, body, "Sora connection OK - User: demo-user") require.Contains(t, body, "Subscription check returned 403") require.Contains(t, body, "Sora2 invite check returned 401") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"partial_success"`) require.Contains(t, body, `"type":"test_complete","success":true`) } @@ -173,6 +178,97 @@ func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *tes require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") } +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "HTTP 429") + body := rec.Body.String() + require.Contains(t, body, "Cloudflare challenge") +} + +func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "token_invalidated") + body := rec.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"failed"`) + require.Contains(t, body, "token_invalidated") + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + soraTestCooldown: time.Hour, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c1, _ := newSoraTestContext() + err := svc.testSoraAccountConnection(c1, account) + require.NoError(t, err) + + c2, rec2 := newSoraTestContext() + err = svc.testSoraAccountConnection(c2, account) + require.Error(t, err) + require.Contains(t, err.Error(), "测试过于频繁") + body := rec2.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"code":"test_rate_limited"`) + require.Contains(t, body, `"status":"failed"`) + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { upstream := &queuedHTTPUpstream{ responses: []*http.Response{ diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 7ca99ad2..e1af5ead 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -95,6 +95,16 @@ var soraDesktopUserAgents = []string{ "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", } +var soraMobileUserAgents = []string{ + "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)", + "Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)", + "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)", + "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)", + "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)", + "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)", + "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)", +} + var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) var soraRandMu sync.Mutex var soraPerfStart = time.Now() @@ -106,6 +116,17 @@ type SoraClient interface { UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) + UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) + GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) + DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) + UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) + FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) + SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error + DeleteCharacter(ctx context.Context, account *Account, characterID string) error + PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) + DeletePost(ctx context.Context, account *Account, postID string) error + GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) @@ -128,6 +149,17 @@ type SoraVideoRequest struct { Size string MediaID string RemixTargetID string + CameoIDs []string +} + +// SoraStoryboardRequest 分镜视频生成请求参数 +type SoraStoryboardRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string } // SoraImageTaskStatus 图片任务状态 @@ -141,11 +173,32 @@ type SoraImageTaskStatus struct { // SoraVideoTaskStatus 视频任务状态 type SoraVideoTaskStatus struct { - ID string - Status string - ProgressPct int - URLs []string - ErrorMsg string + ID string + Status string + ProgressPct int + URLs []string + GenerationID string + ErrorMsg string +} + +// SoraCameoStatus 角色处理中间态 +type SoraCameoStatus struct { + Status string + StatusMessage string + DisplayNameHint string + UsernameHint string + ProfileAssetURL string + InstructionSetHint any + InstructionSet any +} + +// SoraCharacterFinalizeRequest 角色定稿请求参数 +type SoraCharacterFinalizeRequest struct { + CameoID string + Username string + DisplayName string + ProfileAssetPointer string + InstructionSet any } // SoraUpstreamError 上游错误 @@ -407,6 +460,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account payload["remix_target_id"] = req.RemixTargetID payload["cameo_ids"] = []string{} payload["cameo_replacements"] = map[string]any{} + } else if len(req.CameoIDs) > 0 { + payload["cameo_ids"] = req.CameoIDs + payload["cameo_replacements"] = map[string]any{} } headers := c.buildBaseHeaders(token, userAgent) @@ -434,6 +490,425 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account return taskID, nil } +func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "title": "Draft your video", + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "storyboard_id": nil, + "inpaint_items": inpaintItems, + "remix_target_id": nil, + "model": model, + "metadata": nil, + "style_id": nil, + "cameo_ids": nil, + "cameo_replacements": nil, + "audio_caption": nil, + "audio_transcript": nil, + "video_caption": nil, + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if taskID == "" { + return "", errors.New("storyboard task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty video data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`) + partHeader.Set("Content-Type", "video/mp4") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("timestamps", "0,3"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false) + if err != nil { + return "", err + } + cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if cameoID == "" { + return "", errors.New("character upload response missing id") + } + return cameoID, nil +} + +func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return &SoraCameoStatus{ + Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()), + StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()), + DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()), + UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()), + ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()), + InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(), + InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(), + }, nil +} + +func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Accept", "image/*,*/*;q=0.8") + + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + strings.TrimSpace(imageURL), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return respBody, nil +} + +func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty character image") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`) + partHeader.Set("Content-Type", "image/webp") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("use_case", "profile"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false) + if err != nil { + return "", err + } + assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String()) + if assetPointer == "" { + return "", errors.New("character image upload response missing asset_pointer") + } + return assetPointer, nil +} + +func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + payload := map[string]any{ + "cameo_id": req.CameoID, + "username": req.Username, + "display_name": req.DisplayName, + "profile_asset_pointer": req.ProfileAssetPointer, + "instruction_set": nil, + "safety_instruction_set": nil, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String()) + if characterID == "" { + return "", errors.New("character finalize response missing character_id") + } + return characterID, nil +} + +func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + payload := map[string]any{"visibility": "public"} + body, err := json.Marshal(payload) + if err != nil { + return err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodPost, + c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"), + headers, + bytes.NewReader(body), + false, + ) + return err +} + +func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)), + headers, + nil, + false, + ) + return err +} + +func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + payload := map[string]any{ + "attachments_to_create": []map[string]any{ + { + "generation_id": generationID, + "kind": "sora", + }, + }, + "post_text": "", + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String()) + if postID == "" { + return "", errors.New("watermark-free publish response missing post.id") + } + return postID, nil +} + +func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + c.buildURL("/project_y/post/"+strings.TrimSpace(postID)), + headers, + nil, + false, + ) + return err +} + +func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/") + if parseURL == "" { + return "", errors.New("custom parse url is required") + } + if strings.TrimSpace(parseToken) == "" { + return "", errors.New("custom parse token is required") + } + shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID) + payload := map[string]any{ + "url": shareURL, + "token": strings.TrimSpace(parseToken), + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256)) + } + downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String()) + if downloadLink == "" { + return "", errors.New("custom parse response missing download_link") + } + return downloadLink, nil +} + func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { token, err := c.getAccessToken(ctx, account) if err != nil { @@ -607,6 +1082,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t if draft.Get("task_id").String() != taskID { return true } + generationID := strings.TrimSpace(draft.Get("id").String()) kind := strings.TrimSpace(draft.Get("kind").String()) reason := strings.TrimSpace(draft.Get("reason_str").String()) if reason == "" { @@ -623,15 +1099,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t msg = "Content violates guardrails" } draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "failed", - ErrorMsg: msg, + ID: taskID, + Status: "failed", + GenerationID: generationID, + ErrorMsg: msg, } } else { draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "completed", - URLs: []string{urlStr}, + ID: taskID, + Status: "completed", + GenerationID: generationID, + URLs: []string{urlStr}, } } return false @@ -675,8 +1153,11 @@ func (c *SoraDirectClient) taskUserAgent() string { return ua } } + if len(soraMobileUserAgents) > 0 { + return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))] + } if len(soraDesktopUserAgents) > 0 { - return soraDesktopUserAgents[0] + return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))] } return soraDefaultUserAgent } @@ -1149,10 +1630,7 @@ func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { } func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { - enableTLS := true - if c != nil && c.cfg != nil && c.cfg.Sora.Client.DisableTLSFingerprint { - enableTLS = false - } + enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint if c.httpUpstream != nil { accountID := int64(0) accountConcurrency := 0 @@ -1288,6 +1766,15 @@ func soraRandFloat() float64 { return soraRand.Float64() } +func soraRandInt(max int) int { + if max <= 1 { + return 0 + } + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Intn(max) +} + func soraBuildPowConfig(userAgent string) []any { userAgent = strings.TrimSpace(userAgent) if userAgent == "" && len(soraDesktopUserAgents) > 0 { diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index d50b2d85..9e528f97 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -393,10 +393,22 @@ func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL stri return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil case "/backend/nf/create": return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil + case "/backend/nf/create/storyboard": + return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil case "/backend/uploads": return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil case "/backend/nf/check": return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil + case "/backend/characters/upload": + return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil + case "/backend/project_y/cameos/in_progress/cameo-123": + return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil + case "/backend/project_y/file/upload": + return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil + case "/backend/characters/finalize": + return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil + case "/backend/project_y/post": + return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil default: return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil } @@ -410,9 +422,13 @@ func newSoraClientMockResponse(statusCode int, body string) *http.Response { } } -func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) { +func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) { client := NewSoraDirectClient(&config.Config{}, nil, nil) - require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent()) + ua := client.taskUserAgent() + require.NotEmpty(t, ua) + allowed := append([]string{}, soraMobileUserAgents...) + allowed = append(allowed, soraDesktopUserAgents...) + require.Contains(t, allowed, ua) } func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) { @@ -460,7 +476,7 @@ func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAn require.Equal(t, "/backend/nf/create", createCall.Path) require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL) require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL) - require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent) + require.NotEmpty(t, sentinelCall.UserAgent) require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent) } @@ -495,7 +511,7 @@ func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) { require.Len(t, upstream.calls, 1) require.Equal(t, "/backend/uploads", upstream.calls[0].Path) require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) - require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent) + require.NotEmpty(t, upstream.calls[0].UserAgent) } func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) { @@ -528,5 +544,98 @@ func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) require.Len(t, upstream.calls, 1) require.Equal(t, "/backend/nf/check", upstream.calls[0].Path) require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) - require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent) + require.NotEmpty(t, upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 51, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{ + Prompt: "Shot 1:\nduration: 5sec\nScene: cat", + }) + require.NoError(t, err) + require.Equal(t, "storyboard-123", taskID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path) +} + +func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/nf/pending/v2": + _, _ = w.Write([]byte(`[]`)) + case "/project_y/profile/drafts": + _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{Credentials: map[string]any{"access_token": "token"}} + + status, err := client.GetVideoTask(context.Background(), account, "task-1") + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + require.Equal(t, "gen_1", status.GenerationID) + require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs) +} + +func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 52, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1") + require.NoError(t, err) + require.Equal(t, "s_post", postID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path) } diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index ef47f6d4..054d38e7 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -8,10 +8,12 @@ import ( "fmt" "io" "log" + "math" "mime" "net" "net/http" "net/url" + "regexp" "strconv" "strings" "time" @@ -23,6 +25,9 @@ import ( const soraImageInputMaxBytes = 20 << 20 const soraImageInputMaxRedirects = 3 const soraImageInputTimeout = 20 * time.Second +const soraVideoInputMaxBytes = 200 << 20 +const soraVideoInputMaxRedirects = 3 +const soraVideoInputTimeout = 60 * time.Second var soraImageSizeMap = map[string]string{ "gpt-image": "360", @@ -61,6 +66,32 @@ type SoraGatewayService struct { cfg *config.Config } +type soraWatermarkOptions struct { + Enabled bool + ParseMethod string + ParseURL string + ParseToken string + FallbackOnFailure bool + DeletePost bool +} + +type soraCharacterOptions struct { + SetPublic bool + DeleteAfterGenerate bool +} + +type soraCharacterFlowResult struct { + CameoID string + CharacterID string + Username string + DisplayName string +} + +var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) +var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) +var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) +var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) + type soraPreflightChecker interface { PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error } @@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun return nil, fmt.Errorf("unsupported model: %s", reqModel) } prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) - if strings.TrimSpace(prompt) == "" { + prompt = strings.TrimSpace(prompt) + imageInput = strings.TrimSpace(imageInput) + videoInput = strings.TrimSpace(videoInput) + remixTargetID = strings.TrimSpace(remixTargetID) + + if videoInput != "" && modelCfg.Type != "video" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) + return nil, errors.New("video input only supports video models") + } + if videoInput != "" && imageInput != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) + return nil, errors.New("image input and video input cannot be used together") + } + characterOnly := videoInput != "" && prompt == "" + if modelCfg.Type == "prompt_enhance" && prompt == "" { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) return nil, errors.New("prompt is required") } - if strings.TrimSpace(videoInput) != "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream) - return nil, errors.New("video input not supported") + if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") } reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) if cancel != nil { defer cancel() } - if checker, ok := s.soraClient.(soraPreflightChecker); ok { + if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) } @@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun }, nil } + characterOpts := parseSoraCharacterOptions(reqBody) + watermarkOpts := parseSoraWatermarkOptions(reqBody) + var characterResult *soraCharacterFlowResult + if videoInput != "" { + videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) + if videoErr != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) + return nil, videoErr + } + characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) + if videoErr != nil { + return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) + } + if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { + characterID := strings.TrimSpace(characterResult.CharacterID) + defer func() { + cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) + defer cancelCleanup() + if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { + log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) + } + }() + } + if characterOnly { + content := "角色创建成功" + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + resp := buildSoraNonStreamResponse(content, reqModel) + if characterResult != nil { + resp["character_id"] = characterResult.CharacterID + resp["cameo_id"] = characterResult.CameoID + resp["character_username"] = characterResult.Username + resp["character_display_name"] = characterResult.DisplayName + } + c.JSON(http.StatusOK, resp) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) + } + } + var imageData []byte imageFilename := "" - if strings.TrimSpace(imageInput) != "" { + if imageInput != "" { decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) if err != nil { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) @@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun MediaID: mediaID, }) case "video": - taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ - Prompt: prompt, - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - MediaID: mediaID, - RemixTargetID: remixTargetID, - }) + if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { + taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ + Prompt: formatSoraStoryboardPrompt(prompt), + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + }) + } else { + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + RemixTargetID: remixTargetID, + CameoIDs: extractSoraCameoIDs(reqBody), + }) + } default: err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) } @@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } var mediaURLs []string + videoGenerationID := "" mediaType := modelCfg.Type imageCount := 0 imageSize := "" @@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun imageCount = len(urls) imageSize = soraImageSizeFromModel(reqModel) case "video": - urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream) + videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) if pollErr != nil { return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) } - mediaURLs = urls + if videoStatus != nil { + mediaURLs = videoStatus.URLs + videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) + } default: mediaType = "prompt" } + watermarkPostID := "" + if modelCfg.Type == "video" && watermarkOpts.Enabled { + watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) + if watermarkErr != nil { + if !watermarkOpts.FallbackOnFailure { + return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) + } + log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) + } else if strings.TrimSpace(watermarkURL) != "" { + mediaURLs = []string{strings.TrimSpace(watermarkURL)} + watermarkPostID = strings.TrimSpace(postID) + } + } + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) @@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun finalURLs = s.normalizeSoraMediaURLs(stored) } } + if watermarkPostID != "" && watermarkOpts.DeletePost { + if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { + log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) + } + } content := buildSoraContent(mediaType, finalURLs) var firstTokenMs *int @@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) } +func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { + opts := soraWatermarkOptions{ + Enabled: parseBoolWithDefault(body, "watermark_free", false), + ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), + ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), + ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), + FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), + DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), + } + if opts.ParseMethod == "" { + opts.ParseMethod = "third_party" + } + return opts +} + +func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { + return soraCharacterOptions{ + SetPublic: parseBoolWithDefault(body, "character_set_public", true), + DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), + } +} + +func parseBoolWithDefault(body map[string]any, key string, def bool) bool { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case bool: + return typed + case int: + return typed != 0 + case int32: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + case string: + typed = strings.ToLower(strings.TrimSpace(typed)) + if typed == "true" || typed == "1" || typed == "yes" { + return true + } + if typed == "false" || typed == "0" || typed == "no" { + return false + } + } + return def +} + +func parseStringWithDefault(body map[string]any, key, def string) string { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + if str, ok := val.(string); ok { + return str + } + return def +} + +func extractSoraCameoIDs(body map[string]any) []string { + if body == nil { + return nil + } + raw, ok := body["cameo_ids"] + if !ok { + return nil + } + switch typed := raw.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + str, ok := item.(string) + if !ok { + continue + } + str = strings.TrimSpace(str) + if str != "" { + out = append(out, str) + } + } + return out + default: + return nil + } +} + +func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { + cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) + if err != nil { + return nil, err + } + + cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) + if err != nil { + return nil, err + } + username := processSoraCharacterUsername(cameoStatus.UsernameHint) + displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) + if displayName == "" { + displayName = "Character" + } + profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) + if profileAssetURL == "" { + return nil, errors.New("profile asset url not found in cameo status") + } + + avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) + if err != nil { + return nil, err + } + assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) + if err != nil { + return nil, err + } + instructionSet := cameoStatus.InstructionSetHint + if instructionSet == nil { + instructionSet = cameoStatus.InstructionSet + } + + characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ + CameoID: strings.TrimSpace(cameoID), + Username: username, + DisplayName: displayName, + ProfileAssetPointer: assetPointer, + InstructionSet: instructionSet, + }) + if err != nil { + return nil, err + } + + if opts.SetPublic { + if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { + return nil, err + } + } + + return &soraCharacterFlowResult{ + CameoID: strings.TrimSpace(cameoID), + CharacterID: strings.TrimSpace(characterID), + Username: strings.TrimSpace(username), + DisplayName: displayName, + }, nil +} + +func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + timeout := 10 * time.Minute + interval := 5 * time.Second + maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) + if maxAttempts < 1 { + maxAttempts = 1 + } + + var lastErr error + consecutiveErrors := 0 + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) + if err != nil { + lastErr = err + consecutiveErrors++ + if consecutiveErrors >= 3 { + break + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + consecutiveErrors = 0 + if status == nil { + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) + statusMessage := strings.TrimSpace(status.StatusMessage) + if currentStatus == "failed" { + if statusMessage == "" { + statusMessage = "character creation failed" + } + return nil, errors.New(statusMessage) + } + if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { + return status, nil + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + } + if lastErr != nil { + return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) + } + return nil, errors.New("cameo processing timeout") +} + +func processSoraCharacterUsername(usernameHint string) string { + usernameHint = strings.TrimSpace(usernameHint) + if usernameHint == "" { + usernameHint = "character" + } + if strings.Contains(usernameHint, ".") { + parts := strings.Split(usernameHint, ".") + usernameHint = strings.TrimSpace(parts[len(parts)-1]) + } + if usernameHint == "" { + usernameHint = "character" + } + return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100) +} + +func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { + generationID = strings.TrimSpace(generationID) + if generationID == "" { + return "", "", errors.New("generation id is required for watermark-free mode") + } + postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) + if err != nil { + return "", "", err + } + postID = strings.TrimSpace(postID) + if postID == "" { + return "", "", errors.New("watermark-free publish returned empty post id") + } + + switch opts.ParseMethod { + case "custom": + urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) + if parseErr != nil { + return "", postID, parseErr + } + return strings.TrimSpace(urlVal), postID, nil + case "", "third_party": + return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil + default: + return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) + } +} + func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { case 401, 402, 403, 404, 429, 529: @@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, return nil, errors.New("sora image generation timeout") } -func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { +func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { interval := s.pollInterval() maxAttempts := s.pollMaxAttempts() lastPing := time.Now() @@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, } switch strings.ToLower(status.Status) { case "completed", "succeeded": - return status.URLs, nil + return status, nil case "failed": if status.ErrorMsg != "" { return nil, errors.New(status.ErrorMsg) @@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi return "", "", "", "" } if v, ok := body["remix_target_id"].(string); ok { - remixTargetID = v + remixTargetID = strings.TrimSpace(v) } if v, ok := body["image"].(string); ok { imageInput = v @@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi prompt = builder.String() } } + if remixTargetID == "" { + remixTargetID = extractRemixTargetIDFromPrompt(prompt) + } + prompt = cleanRemixLinkFromPrompt(prompt) return prompt, imageInput, videoInput, remixTargetID } @@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string) } } +func isSoraStoryboardPrompt(prompt string) bool { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return false + } + return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 +} + +func formatSoraStoryboardPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) + if len(matches) == 0 { + return prompt + } + firstBracketPos := strings.Index(prompt, "[") + instructions := "" + if firstBracketPos > 0 { + instructions = strings.TrimSpace(prompt[:firstBracketPos]) + } + shots := make([]string, 0, len(matches)) + for i, match := range matches { + if len(match) < 3 { + continue + } + duration := strings.TrimSpace(match[1]) + scene := strings.TrimSpace(match[2]) + if scene == "" { + continue + } + shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) + } + if len(shots) == 0 { + return prompt + } + timeline := strings.Join(shots, "\n\n") + if instructions == "" { + return timeline + } + return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) +} + +func extractRemixTargetIDFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) +} + +func cleanRemixLinkFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return prompt + } + cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") + cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") + cleaned = strings.Join(strings.Fields(cleaned), " ") + return strings.TrimSpace(cleaned) +} + func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { raw := strings.TrimSpace(input) if raw == "" { @@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er } meta := parts[0] payload := parts[1] - decoded, err := base64.StdEncoding.DecodeString(payload) + decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) if err != nil { return nil, "", err } @@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { return downloadSoraImageInput(ctx, raw) } - decoded, err := base64.StdEncoding.DecodeString(raw) + decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) if err != nil { return nil, "", errors.New("invalid base64 image") } return decoded, "image.png", nil } +func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, errors.New("empty video input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, errors.New("invalid video data url") + } + decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraVideoInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil +} + func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - parsed, err := validateSoraImageURL(rawURL) + parsed, err := validateSoraRemoteURL(rawURL) if err != nil { return nil, "", err } @@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, if len(via) >= soraImageInputMaxRedirects { return errors.New("too many redirects") } - return validateSoraImageURLValue(req.URL) + return validateSoraRemoteURLValue(req.URL) }, } resp, err := client.Do(req) @@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, return data, filename, nil } -func validateSoraImageURL(raw string) (*url.URL, error) { +func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: soraVideoInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraVideoInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, errors.New("empty video content") + } + return data, nil +} + +func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { + if maxBytes <= 0 { + return nil, errors.New("invalid max bytes limit") + } + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + limited := io.LimitReader(decoder, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) + } + return data, nil +} + +func validateSoraRemoteURL(raw string) (*url.URL, error) { if strings.TrimSpace(raw) == "" { - return nil, errors.New("empty image url") + return nil, errors.New("empty remote url") } parsed, err := url.Parse(raw) if err != nil { - return nil, fmt.Errorf("invalid image url: %w", err) + return nil, fmt.Errorf("invalid remote url: %w", err) } - if err := validateSoraImageURLValue(parsed); err != nil { + if err := validateSoraRemoteURLValue(parsed); err != nil { return nil, err } return parsed, nil } -func validateSoraImageURLValue(parsed *url.URL) error { +func validateSoraRemoteURLValue(parsed *url.URL) error { if parsed == nil { - return errors.New("invalid image url") + return errors.New("invalid remote url") } scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) if scheme != "http" && scheme != "https" { - return errors.New("only http/https image url is allowed") + return errors.New("only http/https remote url is allowed") } if parsed.User != nil { - return errors.New("image url cannot contain userinfo") + return errors.New("remote url cannot contain userinfo") } host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) if host == "" { - return errors.New("image url missing host") + return errors.New("remote url missing host") } if _, blocked := soraBlockedHostnames[host]; blocked { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } if ip := net.ParseIP(host); ip != nil { if isSoraBlockedIP(ip) { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } return nil } ips, err := net.LookupIP(host) if err != nil { - return fmt.Errorf("resolve image url failed: %w", err) + return fmt.Errorf("resolve remote url failed: %w", err) } for _, ip := range ips { if isSoraBlockedIP(ip) { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } } return nil diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 469a131e..c965901c 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -5,6 +5,7 @@ package service import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -25,6 +26,11 @@ type stubSoraClientForPoll struct { videoCalls int enhanced string enhanceErr error + storyboard bool + videoReq SoraVideoRequest + parseErr error + postCalls int + deleteCalls int } func (s *stubSoraClientForPoll) Enabled() bool { return true } @@ -35,8 +41,54 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac return "task-image", nil } func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + s.videoReq = req return "task-video", nil } +func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + s.storyboard = true + return "task-video", nil +} +func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + return &SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + return nil +} +func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + return nil +} +func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + s.postCalls++ + return "s_post", nil +} +func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { + s.deleteCalls++ + return nil +} +func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + if s.parseErr != nil { + return "", s.parseErr + } + return "https://example.com/no-watermark.mp4", nil +} func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { if s.enhanced != "" { return s.enhanced, s.enhanceErr @@ -102,6 +154,109 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { require.Equal(t, "prompt-enhance-short-10s", result.Model) } +func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, client.storyboard) +} + +func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { + client := &stubSoraClientForPoll{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, 0, client.videoCalls) +} + +func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + parseErr: errors.New("parse failed"), + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/original.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 0, client.deleteCalls) +} + +func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 1, client.deleteCalls) +} + func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { client := &stubSoraClientForPoll{ videoStatus: &SoraVideoTaskStatus{ @@ -119,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { } service := NewSoraGatewayService(client, nil, nil, cfg) - urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false) + status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) require.Error(t, err) - require.Empty(t, urls) + require.Nil(t, status) require.Contains(t, err.Error(), "reject") require.Equal(t, 1, client.videoCalls) } @@ -325,3 +480,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) { require.NotEmpty(t, data) require.Contains(t, filename, ".png") } + +func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) { + data, err := decodeBase64WithLimit("aGVsbG8=", 3) + require.Error(t, err) + require.Nil(t, data) +} + +func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) { + body := map[string]any{ + "watermark_free": float64(1), + "watermark_fallback_on_failure": float64(0), + } + opts := parseSoraWatermarkOptions(body) + require.True(t, opts.Enabled) + require.False(t, opts.FallbackOnFailure) +} diff --git a/backend/internal/util/soraerror/soraerror.go b/backend/internal/util/soraerror/soraerror.go new file mode 100644 index 00000000..17712c10 --- /dev/null +++ b/backend/internal/util/soraerror/soraerror.go @@ -0,0 +1,170 @@ +package soraerror + +import ( + "encoding/json" + "fmt" + "net/http" + "regexp" + "strings" +) + +var ( + cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`) + cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`) + htmlChallenge = []string{ + "window._cf_chl_opt", + "just a moment", + "enable javascript and cookies to continue", + "__cf_chl_", + "challenge-platform", + } +) + +// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior. +func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests { + return false + } + + if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") { + return true + } + + preview := strings.ToLower(TruncateBody(body, 4096)) + for _, marker := range htmlChallenge { + if strings.Contains(preview, marker) { + return true + } + } + + contentType := "" + if headers != nil { + contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type"))) + } + if strings.Contains(contentType, "text/html") && + (strings.Contains(preview, "= 2 { + return strings.TrimSpace(matches[1]) + } + if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +// FormatCloudflareChallengeMessage appends cf-ray info when available. +func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + rayID := ExtractCloudflareRayID(headers, body) + if rayID == "" { + return base + } + return fmt.Sprintf("%s (cf-ray: %s)", base, rayID) +} + +// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts. +func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) { + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return "", "" + } + if !json.Valid([]byte(trimmed)) { + return "", truncateMessage(trimmed, 256) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return "", truncateMessage(trimmed, 256) + } + + code := firstNonEmpty( + extractNestedString(payload, "error", "code"), + extractRootString(payload, "code"), + ) + message := firstNonEmpty( + extractNestedString(payload, "error", "message"), + extractRootString(payload, "message"), + extractNestedString(payload, "error", "detail"), + extractRootString(payload, "detail"), + ) + return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512) +} + +// TruncateBody truncates body text for logging/inspection. +func TruncateBody(body []byte, max int) string { + if max <= 0 { + max = 512 + } + raw := strings.TrimSpace(string(body)) + if len(raw) <= max { + return raw + } + return raw[:max] + "...(truncated)" +} + +func truncateMessage(s string, max int) string { + if max <= 0 { + return "" + } + if len(s) <= max { + return s + } + return s[:max] + "...(truncated)" +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func extractRootString(m map[string]any, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +func extractNestedString(m map[string]any, parent, key string) string { + if m == nil { + return "" + } + node, ok := m[parent] + if !ok { + return "" + } + child, ok := node.(map[string]any) + if !ok { + return "" + } + s, _ := child[key].(string) + return s +} diff --git a/backend/internal/util/soraerror/soraerror_test.go b/backend/internal/util/soraerror/soraerror_test.go new file mode 100644 index 00000000..4cf11169 --- /dev/null +++ b/backend/internal/util/soraerror/soraerror_test.go @@ -0,0 +1,47 @@ +package soraerror + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsCloudflareChallengeResponse(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-mitigated", "challenge") + require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`))) + + require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`Just a moment...`))) + require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`Just a moment...`))) +} + +func TestExtractCloudflareRayID(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-ray", "9d01b0e9ecc35829-SEA") + require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil)) + + body := []byte(``) + require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body)) +} + +func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) { + code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`)) + require.Equal(t, "cf_shield_429", code) + require.Equal(t, "rate limited", msg) + + code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`)) + require.Equal(t, "unsupported_country_code", code) + require.Equal(t, "not available", msg) + + code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`)) + require.Equal(t, "", code) + require.Equal(t, "plain text", msg) +} + +func TestFormatCloudflareChallengeMessage(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + msg := FormatCloudflareChallengeMessage("blocked", headers, nil) + require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg) +}