diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index d778f67c..f788a87d 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.83.1
+0.1.83.2
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index be17fb01..a0f8807a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
- soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
+ soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index 791100f6..219922aa 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -12,7 +12,6 @@ import (
"os"
"path"
"path/filepath"
- "regexp"
"strconv"
"strings"
"time"
@@ -22,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
@@ -29,8 +29,6 @@ import (
"go.uber.org/zap"
)
-var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
-
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
@@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
+ if strings.EqualFold(upstreamCode, "cf_shield_429") {
+ baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
+ return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
+ }
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
@@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
- if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
- return false
- }
- if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
- return true
- }
- preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
- if strings.Contains(preview, "window._cf_chl_opt") ||
- strings.Contains(preview, "just a moment") ||
- strings.Contains(preview, "enable javascript and cookies to continue") ||
- strings.Contains(preview, "__cf_chl_") ||
- strings.Contains(preview, "challenge-platform") {
- return true
- }
- contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
- if strings.Contains(contentType, "text/html") &&
- (strings.Contains(preview, "= 2 {
- return strings.TrimSpace(matches[1])
- }
- return ""
+ return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
- trimmed := strings.TrimSpace(string(body))
- if trimmed == "" {
- return "", ""
- }
- if !gjson.Valid(trimmed) {
- return "", truncateSoraErrorMessage(trimmed, 256)
- }
- code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
- if code == "" {
- code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
- }
- message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
- if message == "" {
- message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
- }
- if message == "" {
- message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
- }
- if message == "" {
- message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
- }
- return code, truncateSoraErrorMessage(message, 512)
-}
-
-func truncateSoraErrorMessage(s string, maxLen int) string {
- if maxLen <= 0 {
- return ""
- }
- if len(s) <= maxLen {
- return s
- }
- return s[:maxLen] + "...(truncated)"
-}
-
-func truncateSoraErrorBody(body []byte, maxLen int) string {
- if maxLen <= 0 {
- maxLen = 512
- }
- raw := strings.TrimSpace(string(body))
- if len(raw) <= maxLen {
- return raw
- }
- return raw[:maxLen] + "...(truncated)"
+ return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index edf3ca5e..52ff0a96 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
+func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
+ return "task-video", nil
+}
+func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
+ return "cameo-1", nil
+}
+func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
+ return &service.SoraCameoStatus{
+ Status: "finalized",
+ StatusMessage: "Completed",
+ DisplayNameHint: "Character",
+ UsernameHint: "user.character",
+ ProfileAssetURL: "https://example.com/avatar.webp",
+ }, nil
+}
+func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
+ return []byte("avatar"), nil
+}
+func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
+ return "asset-pointer", nil
+}
+func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
+ return "character-1", nil
+}
+func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
+ return nil
+}
+func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
+ return nil
+}
+func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
+ return "s_post", nil
+}
+func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
+ return nil
+}
+func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
+ return "https://example.com/no-watermark.mp4", nil
+}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
@@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
+
+func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ headers := http.Header{}
+ headers.Set("cf-ray", "9d03b68c086027a1-SEA")
+ body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
+
+ h := &SoraGatewayHandler{}
+ h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
+
+ lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
+ require.Len(t, lines, 2)
+ jsonStr := strings.TrimPrefix(lines[1], "data: ")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "rate_limit_error", errorObj["type"])
+ msg, _ := errorObj["message"].(string)
+ require.Contains(t, msg, "Cloudflare shield")
+ require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
+}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index c3f2359a..a507efb4 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -15,11 +15,14 @@ import (
"net/url"
"regexp"
"strings"
+ "sync"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -28,7 +31,6 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
-var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
@@ -45,6 +47,9 @@ type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
+ Status string `json:"status,omitempty"`
+ Code string `json:"code,omitempty"`
+ Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
@@ -56,8 +61,13 @@ type AccountTestService struct {
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
+ soraTestGuardMu sync.Mutex
+ soraTestLastRun map[int64]time.Time
+ soraTestCooldown time.Duration
}
+const defaultSoraTestCooldown = 10 * time.Second
+
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -72,6 +82,8 @@ func NewAccountTestService(
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
+ soraTestLastRun: make(map[int64]time.Time),
+ soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -473,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
+type soraProbeStep struct {
+ Name string `json:"name"`
+ Status string `json:"status"`
+ HTTPStatus int `json:"http_status,omitempty"`
+ ErrorCode string `json:"error_code,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
+type soraProbeSummary struct {
+ Status string `json:"status"`
+ Steps []soraProbeStep `json:"steps"`
+}
+
+type soraProbeRecorder struct {
+ steps []soraProbeStep
+}
+
+func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
+ r.steps = append(r.steps, soraProbeStep{
+ Name: name,
+ Status: status,
+ HTTPStatus: httpStatus,
+ ErrorCode: strings.TrimSpace(errorCode),
+ Message: strings.TrimSpace(message),
+ })
+}
+
+func (r *soraProbeRecorder) finalize() soraProbeSummary {
+ meSuccess := false
+ partial := false
+ for _, step := range r.steps {
+ if step.Name == "me" {
+ meSuccess = strings.EqualFold(step.Status, "success")
+ continue
+ }
+ if strings.EqualFold(step.Status, "failed") {
+ partial = true
+ }
+ }
+
+ status := "success"
+ if !meSuccess {
+ status = "failed"
+ } else if partial {
+ status = "partial_success"
+ }
+
+ return soraProbeSummary{
+ Status: status,
+ Steps: append([]soraProbeStep(nil), r.steps...),
+ }
+}
+
+func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
+ if rec == nil {
+ return
+ }
+ summary := rec.finalize()
+ code := ""
+ for _, step := range summary.Steps {
+ if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
+ code = step.ErrorCode
+ break
+ }
+ }
+ s.sendEvent(c, TestEvent{
+ Type: "sora_test_result",
+ Status: summary.Status,
+ Code: code,
+ Data: summary,
+ })
+}
+
+func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
+ if accountID <= 0 {
+ return 0, true
+ }
+ s.soraTestGuardMu.Lock()
+ defer s.soraTestGuardMu.Unlock()
+
+ if s.soraTestLastRun == nil {
+ s.soraTestLastRun = make(map[int64]time.Time)
+ }
+ cooldown := s.soraTestCooldown
+ if cooldown <= 0 {
+ cooldown = defaultSoraTestCooldown
+ }
+
+ now := time.Now()
+ if lastRun, ok := s.soraTestLastRun[accountID]; ok {
+ elapsed := now.Sub(lastRun)
+ if elapsed < cooldown {
+ return cooldown - elapsed, false
+ }
+ }
+ s.soraTestLastRun[accountID] = now
+ return 0, true
+}
+
+func ceilSeconds(d time.Duration) int {
+ if d <= 0 {
+ return 1
+ }
+ sec := int(d / time.Second)
+ if d%time.Second != 0 {
+ sec++
+ }
+ if sec < 1 {
+ sec = 1
+ }
+ return sec
+}
+
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
+ recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
+ recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
@@ -490,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
+ if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
+ msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
+ recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, msg)
+ }
+
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
+ recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
@@ -515,6 +652,8 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
+ recorder.addStep("me", "failed", 0, "network_error", err.Error())
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
@@ -522,12 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
- if isCloudflareChallengeResponse(resp.StatusCode, body) {
+ if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
+ recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
+ s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
- return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
+ return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
+ }
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
+ switch {
+ case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
+ recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
+ case strings.EqualFold(upstreamCode, "unsupported_country_code"):
+ recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
+ case strings.TrimSpace(upstreamMessage) != "":
+ recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
+ default:
+ recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
+ recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
@@ -557,21 +717,26 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
+ recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
+ recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
- if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
+ if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
+ recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
+ recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
@@ -579,8 +744,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
- s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint)
+ s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
+ s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
@@ -592,6 +758,7 @@ func (s *AccountTestService) testSora2Capabilities(
authToken string,
proxyURL string,
enableTLSFingerprint bool,
+ recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
@@ -602,6 +769,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
@@ -616,6 +786,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
+ if recorder != nil {
+ recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
@@ -626,20 +799,42 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
+ } else if recorder != nil {
+ code := ""
+ msg := ""
+ if bootstrapErr != nil {
+ code = "network_error"
+ msg = bootstrapErr.Error()
+ }
+ recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
- if isCloudflareChallengeResponse(inviteStatus, inviteBody) {
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)})
+ if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
+ }
+ s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
+ }
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
@@ -656,17 +851,31 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if remainingErr != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
- if isCloudflareChallengeResponse(remainingStatus, remainingBody) {
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)})
+ if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
+ }
+ s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
+ }
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
+ }
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
@@ -789,42 +998,16 @@ func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
return !s.cfg.Sora.Client.DisableTLSFingerprint
}
-func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
- if statusCode != http.StatusForbidden {
- return false
- }
- preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
- return strings.Contains(preview, "window._cf_chl_opt") ||
- strings.Contains(preview, "just a moment") ||
- strings.Contains(preview, "enable javascript and cookies to continue")
+func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
- rayID := extractCloudflareRayID(headers, body)
- if rayID == "" {
- return base
- }
- return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
+ return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
- if headers != nil {
- rayID := strings.TrimSpace(headers.Get("cf-ray"))
- if rayID != "" {
- return rayID
- }
- rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
- if rayID != "" {
- return rayID
- }
- }
-
- preview := truncateSoraErrorBody(body, 8192)
- matches := cloudflareRayPattern.FindStringSubmatch(preview)
- if len(matches) >= 2 {
- return strings.TrimSpace(matches[1])
- }
- return ""
+ return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
@@ -897,14 +1080,7 @@ func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyU
}
func truncateSoraErrorBody(body []byte, max int) string {
- if max <= 0 {
- max = 512
- }
- raw := strings.TrimSpace(string(body))
- if len(raw) <= max {
- return raw
- }
- return raw[:max] + "...(truncated)"
+ return soraerror.TruncateBody(body, max)
}
// testAntigravityAccountConnection tests an Antigravity account's connection
diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go
index b5389ea2..3dfac786 100644
--- a/backend/internal/service/account_test_service_sora_test.go
+++ b/backend/internal/service/account_test_service_sora_test.go
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
@@ -109,6 +110,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
@@ -141,6 +144,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
@@ -173,6 +178,97 @@ func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *tes
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
+func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponseWithHeader(http.StatusTooManyRequests, `
Just a moment...`, "cf-mitigated", "challenge"),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Cloudflare challenge")
+ require.Contains(t, err.Error(), "HTTP 429")
+ body := rec.Body.String()
+ require.Contains(t, body, "Cloudflare challenge")
+}
+
+func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "token_invalidated")
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"failed"`)
+ require.Contains(t, body, "token_invalidated")
+ require.NotContains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ soraTestCooldown: time.Hour,
+ }
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c1, _ := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c1, account)
+ require.NoError(t, err)
+
+ c2, rec2 := newSoraTestContext()
+ err = svc.testSoraAccountConnection(c2, account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "测试过于频繁")
+ body := rec2.Body.String()
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"code":"test_rate_limited"`)
+ require.Contains(t, body, `"status":"failed"`)
+ require.NotContains(t, body, `"type":"test_complete","success":true`)
+}
+
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
index 7ca99ad2..e1af5ead 100644
--- a/backend/internal/service/sora_client.go
+++ b/backend/internal/service/sora_client.go
@@ -95,6 +95,16 @@ var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
+var soraMobileUserAgents = []string{
+ "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
+ "Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
+ "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
+ "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
+ "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
+ "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
+ "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
+}
+
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
@@ -106,6 +116,17 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
+ CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
+ UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
+ GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
+ DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
+ UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
+ FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
+ SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
+ DeleteCharacter(ctx context.Context, account *Account, characterID string) error
+ PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
+ DeletePost(ctx context.Context, account *Account, postID string) error
+ GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
@@ -128,6 +149,17 @@ type SoraVideoRequest struct {
Size string
MediaID string
RemixTargetID string
+ CameoIDs []string
+}
+
+// SoraStoryboardRequest 分镜视频生成请求参数
+type SoraStoryboardRequest struct {
+ Prompt string
+ Orientation string
+ Frames int
+ Model string
+ Size string
+ MediaID string
}
// SoraImageTaskStatus 图片任务状态
@@ -141,11 +173,32 @@ type SoraImageTaskStatus struct {
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
- ID string
- Status string
- ProgressPct int
- URLs []string
- ErrorMsg string
+ ID string
+ Status string
+ ProgressPct int
+ URLs []string
+ GenerationID string
+ ErrorMsg string
+}
+
+// SoraCameoStatus 角色处理中间态
+type SoraCameoStatus struct {
+ Status string
+ StatusMessage string
+ DisplayNameHint string
+ UsernameHint string
+ ProfileAssetURL string
+ InstructionSetHint any
+ InstructionSet any
+}
+
+// SoraCharacterFinalizeRequest 角色定稿请求参数
+type SoraCharacterFinalizeRequest struct {
+ CameoID string
+ Username string
+ DisplayName string
+ ProfileAssetPointer string
+ InstructionSet any
}
// SoraUpstreamError 上游错误
@@ -407,6 +460,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
+ } else if len(req.CameoIDs) > 0 {
+ payload["cameo_ids"] = req.CameoIDs
+ payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, userAgent)
@@ -434,6 +490,425 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
+func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ orientation := req.Orientation
+ if orientation == "" {
+ orientation = "landscape"
+ }
+ nFrames := req.Frames
+ if nFrames <= 0 {
+ nFrames = 450
+ }
+ model := req.Model
+ if model == "" {
+ model = "sy_8"
+ }
+ size := req.Size
+ if size == "" {
+ size = "small"
+ }
+
+ inpaintItems := []map[string]any{}
+ if strings.TrimSpace(req.MediaID) != "" {
+ inpaintItems = append(inpaintItems, map[string]any{
+ "kind": "upload",
+ "upload_id": req.MediaID,
+ })
+ }
+ payload := map[string]any{
+ "kind": "video",
+ "prompt": req.Prompt,
+ "title": "Draft your video",
+ "orientation": orientation,
+ "size": size,
+ "n_frames": nFrames,
+ "storyboard_id": nil,
+ "inpaint_items": inpaintItems,
+ "remix_target_id": nil,
+ "model": model,
+ "metadata": nil,
+ "style_id": nil,
+ "cameo_ids": nil,
+ "cameo_replacements": nil,
+ "audio_caption": nil,
+ "audio_transcript": nil,
+ "video_caption": nil,
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
+ if err != nil {
+ return "", err
+ }
+ headers.Set("openai-sentinel-token", sentinel)
+
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
+ if err != nil {
+ return "", err
+ }
+ taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
+ if taskID == "" {
+ return "", errors.New("storyboard task response missing id")
+ }
+ return taskID, nil
+}
+
+func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
+ if len(data) == 0 {
+ return "", errors.New("empty video data")
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ partHeader := make(textproto.MIMEHeader)
+ partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
+ partHeader.Set("Content-Type", "video/mp4")
+ part, err := writer.CreatePart(partHeader)
+ if err != nil {
+ return "", err
+ }
+ if _, err := part.Write(data); err != nil {
+ return "", err
+ }
+ if err := writer.WriteField("timestamps", "0,3"); err != nil {
+ return "", err
+ }
+ if err := writer.Close(); err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", writer.FormDataContentType())
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
+ if err != nil {
+ return "", err
+ }
+ cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
+ if cameoID == "" {
+ return "", errors.New("character upload response missing id")
+ }
+ return cameoID, nil
+}
+
+func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ respBody, _, err := c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodGet,
+ c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
+ headers,
+ nil,
+ false,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return &SoraCameoStatus{
+ Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
+ StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
+ DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
+ UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
+ ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
+ InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
+ InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
+ }, nil
+}
+
+func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Accept", "image/*,*/*;q=0.8")
+
+ respBody, _, err := c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodGet,
+ strings.TrimSpace(imageURL),
+ headers,
+ nil,
+ false,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return respBody, nil
+}
+
+func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
+ if len(data) == 0 {
+ return "", errors.New("empty character image")
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ partHeader := make(textproto.MIMEHeader)
+ partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
+ partHeader.Set("Content-Type", "image/webp")
+ part, err := writer.CreatePart(partHeader)
+ if err != nil {
+ return "", err
+ }
+ if _, err := part.Write(data); err != nil {
+ return "", err
+ }
+ if err := writer.WriteField("use_case", "profile"); err != nil {
+ return "", err
+ }
+ if err := writer.Close(); err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", writer.FormDataContentType())
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
+ if err != nil {
+ return "", err
+ }
+ assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
+ if assetPointer == "" {
+ return "", errors.New("character image upload response missing asset_pointer")
+ }
+ return assetPointer, nil
+}
+
+func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ payload := map[string]any{
+ "cameo_id": req.CameoID,
+ "username": req.Username,
+ "display_name": req.DisplayName,
+ "profile_asset_pointer": req.ProfileAssetPointer,
+ "instruction_set": nil,
+ "safety_instruction_set": nil,
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
+ if err != nil {
+ return "", err
+ }
+ characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
+ if characterID == "" {
+ return "", errors.New("character finalize response missing character_id")
+ }
+ return characterID, nil
+}
+
+func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ payload := map[string]any{"visibility": "public"}
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodPost,
+ c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
+ headers,
+ bytes.NewReader(body),
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodDelete,
+ c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
+ headers,
+ nil,
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ payload := map[string]any{
+ "attachments_to_create": []map[string]any{
+ {
+ "generation_id": generationID,
+ "kind": "sora",
+ },
+ },
+ "post_text": "",
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
+ if err != nil {
+ return "", err
+ }
+ headers.Set("openai-sentinel-token", sentinel)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
+ if err != nil {
+ return "", err
+ }
+ postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
+ if postID == "" {
+ return "", errors.New("watermark-free publish response missing post.id")
+ }
+ return postID, nil
+}
+
+func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodDelete,
+ c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
+ headers,
+ nil,
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
+ parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
+ if parseURL == "" {
+ return "", errors.New("custom parse url is required")
+ }
+ if strings.TrimSpace(parseToken) == "" {
+ return "", errors.New("custom parse token is required")
+ }
+ shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
+ payload := map[string]any{
+ "url": shareURL,
+ "token": strings.TrimSpace(parseToken),
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ proxyURL := c.resolveProxyURL(account)
+ accountID := int64(0)
+ accountConcurrency := 0
+ if account != nil {
+ accountID = account.ID
+ accountConcurrency = account.Concurrency
+ }
+ var resp *http.Response
+ if c.httpUpstream != nil {
+ resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
+ } else {
+ resp, err = http.DefaultClient.Do(req)
+ }
+ if err != nil {
+ return "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
+ if err != nil {
+ return "", err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
+ }
+ downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
+ if downloadLink == "" {
+ return "", errors.New("custom parse response missing download_link")
+ }
+ return downloadLink, nil
+}
+
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
@@ -607,6 +1082,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if draft.Get("task_id").String() != taskID {
return true
}
+ generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
@@ -623,15 +1099,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
- ID: taskID,
- Status: "failed",
- ErrorMsg: msg,
+ ID: taskID,
+ Status: "failed",
+ GenerationID: generationID,
+ ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
- ID: taskID,
- Status: "completed",
- URLs: []string{urlStr},
+ ID: taskID,
+ Status: "completed",
+ GenerationID: generationID,
+ URLs: []string{urlStr},
}
}
return false
@@ -675,8 +1153,11 @@ func (c *SoraDirectClient) taskUserAgent() string {
return ua
}
}
+ if len(soraMobileUserAgents) > 0 {
+ return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
+ }
if len(soraDesktopUserAgents) > 0 {
- return soraDesktopUserAgents[0]
+ return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
}
return soraDefaultUserAgent
}
@@ -1149,10 +1630,7 @@ func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
- enableTLS := true
- if c != nil && c.cfg != nil && c.cfg.Sora.Client.DisableTLSFingerprint {
- enableTLS = false
- }
+ enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
@@ -1288,6 +1766,15 @@ func soraRandFloat() float64 {
return soraRand.Float64()
}
+func soraRandInt(max int) int {
+ if max <= 1 {
+ return 0
+ }
+ soraRandMu.Lock()
+ defer soraRandMu.Unlock()
+ return soraRand.Intn(max)
+}
+
func soraBuildPowConfig(userAgent string) []any {
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go
index d50b2d85..9e528f97 100644
--- a/backend/internal/service/sora_client_test.go
+++ b/backend/internal/service/sora_client_test.go
@@ -393,10 +393,22 @@ func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL stri
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
case "/backend/nf/create":
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
+ case "/backend/nf/create/storyboard":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
case "/backend/uploads":
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
case "/backend/nf/check":
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
+ case "/backend/characters/upload":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
+ case "/backend/project_y/cameos/in_progress/cameo-123":
+ return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
+ case "/backend/project_y/file/upload":
+ return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
+ case "/backend/characters/finalize":
+ return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
+ case "/backend/project_y/post":
+ return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
default:
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
}
@@ -410,9 +422,13 @@ func newSoraClientMockResponse(statusCode int, body string) *http.Response {
}
}
-func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) {
+func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
client := NewSoraDirectClient(&config.Config{}, nil, nil)
- require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent())
+ ua := client.taskUserAgent()
+ require.NotEmpty(t, ua)
+ allowed := append([]string{}, soraMobileUserAgents...)
+ allowed = append(allowed, soraDesktopUserAgents...)
+ require.Contains(t, allowed, ua)
}
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
@@ -460,7 +476,7 @@ func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAn
require.Equal(t, "/backend/nf/create", createCall.Path)
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
- require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent)
+ require.NotEmpty(t, sentinelCall.UserAgent)
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
}
@@ -495,7 +511,7 @@ func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
- require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
+ require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
@@ -528,5 +544,98 @@ func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
- require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
+ require.NotEmpty(t, upstream.calls[0].UserAgent)
+}
+
+func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
+ originPowTokenGenerator := soraPowTokenGenerator
+ soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
+ defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
+
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ account := &Account{
+ ID: 51,
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
+ Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
+ })
+ require.NoError(t, err)
+ require.Equal(t, "storyboard-123", taskID)
+ require.Len(t, upstream.calls, 2)
+ require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
+ require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
+}
+
+func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ switch r.URL.Path {
+ case "/nf/pending/v2":
+ _, _ = w.Write([]byte(`[]`))
+ case "/project_y/profile/drafts":
+ _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: server.URL,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{Credentials: map[string]any{"access_token": "token"}}
+
+ status, err := client.GetVideoTask(context.Background(), account, "task-1")
+ require.NoError(t, err)
+ require.Equal(t, "completed", status.Status)
+ require.Equal(t, "gen_1", status.GenerationID)
+ require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
+}
+
+func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
+ originPowTokenGenerator := soraPowTokenGenerator
+ soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
+ defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
+
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ account := &Account{
+ ID: 52,
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
+ require.NoError(t, err)
+ require.Equal(t, "s_post", postID)
+ require.Len(t, upstream.calls, 2)
+ require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
+ require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
}
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
index ef47f6d4..054d38e7 100644
--- a/backend/internal/service/sora_gateway_service.go
+++ b/backend/internal/service/sora_gateway_service.go
@@ -8,10 +8,12 @@ import (
"fmt"
"io"
"log"
+ "math"
"mime"
"net"
"net/http"
"net/url"
+ "regexp"
"strconv"
"strings"
"time"
@@ -23,6 +25,9 @@ import (
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
+const soraVideoInputMaxBytes = 200 << 20
+const soraVideoInputMaxRedirects = 3
+const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -61,6 +66,32 @@ type SoraGatewayService struct {
cfg *config.Config
}
+type soraWatermarkOptions struct {
+ Enabled bool
+ ParseMethod string
+ ParseURL string
+ ParseToken string
+ FallbackOnFailure bool
+ DeletePost bool
+}
+
+type soraCharacterOptions struct {
+ SetPublic bool
+ DeleteAfterGenerate bool
+}
+
+type soraCharacterFlowResult struct {
+ CameoID string
+ CharacterID string
+ Username string
+ DisplayName string
+}
+
+var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
+var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
+var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
+var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
+
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
@@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
- if strings.TrimSpace(prompt) == "" {
+ prompt = strings.TrimSpace(prompt)
+ imageInput = strings.TrimSpace(imageInput)
+ videoInput = strings.TrimSpace(videoInput)
+ remixTargetID = strings.TrimSpace(remixTargetID)
+
+ if videoInput != "" && modelCfg.Type != "video" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
+ return nil, errors.New("video input only supports video models")
+ }
+ if videoInput != "" && imageInput != "" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
+ return nil, errors.New("image input and video input cannot be used together")
+ }
+ characterOnly := videoInput != "" && prompt == ""
+ if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
- if strings.TrimSpace(videoInput) != "" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
- return nil, errors.New("video input not supported")
+ if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
+ return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
- if checker, ok := s.soraClient.(soraPreflightChecker); ok {
+ if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
@@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}, nil
}
+ characterOpts := parseSoraCharacterOptions(reqBody)
+ watermarkOpts := parseSoraWatermarkOptions(reqBody)
+ var characterResult *soraCharacterFlowResult
+ if videoInput != "" {
+ videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
+ if videoErr != nil {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
+ return nil, videoErr
+ }
+ characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
+ if videoErr != nil {
+ return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
+ }
+ if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
+ characterID := strings.TrimSpace(characterResult.CharacterID)
+ defer func() {
+ cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancelCleanup()
+ if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
+ log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
+ }
+ }()
+ }
+ if characterOnly {
+ content := "角色创建成功"
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ resp := buildSoraNonStreamResponse(content, reqModel)
+ if characterResult != nil {
+ resp["character_id"] = characterResult.CharacterID
+ resp["cameo_id"] = characterResult.CameoID
+ resp["character_username"] = characterResult.Username
+ resp["character_display_name"] = characterResult.DisplayName
+ }
+ c.JSON(http.StatusOK, resp)
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: reqModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
+ }
+ }
+
var imageData []byte
imageFilename := ""
- if strings.TrimSpace(imageInput) != "" {
+ if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
@@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
MediaID: mediaID,
})
case "video":
- taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
- Prompt: prompt,
- Orientation: modelCfg.Orientation,
- Frames: modelCfg.Frames,
- Model: modelCfg.Model,
- Size: modelCfg.Size,
- MediaID: mediaID,
- RemixTargetID: remixTargetID,
- })
+ if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
+ taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
+ Prompt: formatSoraStoryboardPrompt(prompt),
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ MediaID: mediaID,
+ })
+ } else {
+ taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
+ Prompt: prompt,
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ MediaID: mediaID,
+ RemixTargetID: remixTargetID,
+ CameoIDs: extractSoraCameoIDs(reqBody),
+ })
+ }
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
@@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
var mediaURLs []string
+ videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
@@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
- urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
+ videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
- mediaURLs = urls
+ if videoStatus != nil {
+ mediaURLs = videoStatus.URLs
+ videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
+ }
default:
mediaType = "prompt"
}
+ watermarkPostID := ""
+ if modelCfg.Type == "video" && watermarkOpts.Enabled {
+ watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
+ if watermarkErr != nil {
+ if !watermarkOpts.FallbackOnFailure {
+ return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
+ }
+ log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
+ } else if strings.TrimSpace(watermarkURL) != "" {
+ mediaURLs = []string{strings.TrimSpace(watermarkURL)}
+ watermarkPostID = strings.TrimSpace(postID)
+ }
+ }
+
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
@@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
+ if watermarkPostID != "" && watermarkOpts.DeletePost {
+ if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
+ log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
+ }
+ }
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
@@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
+func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
+ opts := soraWatermarkOptions{
+ Enabled: parseBoolWithDefault(body, "watermark_free", false),
+ ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
+ ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
+ ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
+ FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
+ DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
+ }
+ if opts.ParseMethod == "" {
+ opts.ParseMethod = "third_party"
+ }
+ return opts
+}
+
+func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
+ return soraCharacterOptions{
+ SetPublic: parseBoolWithDefault(body, "character_set_public", true),
+ DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
+ }
+}
+
+func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ switch typed := val.(type) {
+ case bool:
+ return typed
+ case int:
+ return typed != 0
+ case int32:
+ return typed != 0
+ case int64:
+ return typed != 0
+ case float64:
+ return typed != 0
+ case string:
+ typed = strings.ToLower(strings.TrimSpace(typed))
+ if typed == "true" || typed == "1" || typed == "yes" {
+ return true
+ }
+ if typed == "false" || typed == "0" || typed == "no" {
+ return false
+ }
+ }
+ return def
+}
+
+func parseStringWithDefault(body map[string]any, key, def string) string {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ if str, ok := val.(string); ok {
+ return str
+ }
+ return def
+}
+
+func extractSoraCameoIDs(body map[string]any) []string {
+ if body == nil {
+ return nil
+ }
+ raw, ok := body["cameo_ids"]
+ if !ok {
+ return nil
+ }
+ switch typed := raw.(type) {
+ case []string:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ item = strings.TrimSpace(item)
+ if item != "" {
+ out = append(out, item)
+ }
+ }
+ return out
+ case []any:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ str, ok := item.(string)
+ if !ok {
+ continue
+ }
+ str = strings.TrimSpace(str)
+ if str != "" {
+ out = append(out, str)
+ }
+ }
+ return out
+ default:
+ return nil
+ }
+}
+
+func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
+ cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
+ if err != nil {
+ return nil, err
+ }
+
+ cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ return nil, err
+ }
+ username := processSoraCharacterUsername(cameoStatus.UsernameHint)
+ displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
+ if displayName == "" {
+ displayName = "Character"
+ }
+ profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
+ if profileAssetURL == "" {
+ return nil, errors.New("profile asset url not found in cameo status")
+ }
+
+ avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
+ if err != nil {
+ return nil, err
+ }
+ assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
+ if err != nil {
+ return nil, err
+ }
+ instructionSet := cameoStatus.InstructionSetHint
+ if instructionSet == nil {
+ instructionSet = cameoStatus.InstructionSet
+ }
+
+ characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
+ CameoID: strings.TrimSpace(cameoID),
+ Username: username,
+ DisplayName: displayName,
+ ProfileAssetPointer: assetPointer,
+ InstructionSet: instructionSet,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.SetPublic {
+ if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
+ return nil, err
+ }
+ }
+
+ return &soraCharacterFlowResult{
+ CameoID: strings.TrimSpace(cameoID),
+ CharacterID: strings.TrimSpace(characterID),
+ Username: strings.TrimSpace(username),
+ DisplayName: displayName,
+ }, nil
+}
+
+func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ timeout := 10 * time.Minute
+ interval := 5 * time.Second
+ maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+
+ var lastErr error
+ consecutiveErrors := 0
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ lastErr = err
+ consecutiveErrors++
+ if consecutiveErrors >= 3 {
+ break
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ consecutiveErrors = 0
+ if status == nil {
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
+ statusMessage := strings.TrimSpace(status.StatusMessage)
+ if currentStatus == "failed" {
+ if statusMessage == "" {
+ statusMessage = "character creation failed"
+ }
+ return nil, errors.New(statusMessage)
+ }
+ if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
+ return status, nil
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ }
+ if lastErr != nil {
+ return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
+ }
+ return nil, errors.New("cameo processing timeout")
+}
+
+func processSoraCharacterUsername(usernameHint string) string {
+ usernameHint = strings.TrimSpace(usernameHint)
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ if strings.Contains(usernameHint, ".") {
+ parts := strings.Split(usernameHint, ".")
+ usernameHint = strings.TrimSpace(parts[len(parts)-1])
+ }
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
+}
+
+func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
+ generationID = strings.TrimSpace(generationID)
+ if generationID == "" {
+ return "", "", errors.New("generation id is required for watermark-free mode")
+ }
+ postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
+ if err != nil {
+ return "", "", err
+ }
+ postID = strings.TrimSpace(postID)
+ if postID == "" {
+ return "", "", errors.New("watermark-free publish returned empty post id")
+ }
+
+ switch opts.ParseMethod {
+ case "custom":
+ urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
+ if parseErr != nil {
+ return "", postID, parseErr
+ }
+ return strings.TrimSpace(urlVal), postID, nil
+ case "", "third_party":
+ return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
+ default:
+ return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
+ }
+}
+
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 404, 429, 529:
@@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
return nil, errors.New("sora image generation timeout")
}
-func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
+func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
@@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
- return status.URLs, nil
+ return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
@@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
- remixTargetID = v
+ remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
@@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
prompt = builder.String()
}
}
+ if remixTargetID == "" {
+ remixTargetID = extractRemixTargetIDFromPrompt(prompt)
+ }
+ prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
@@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
}
}
+func isSoraStoryboardPrompt(prompt string) bool {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return false
+ }
+ return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
+}
+
+func formatSoraStoryboardPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
+ if len(matches) == 0 {
+ return prompt
+ }
+ firstBracketPos := strings.Index(prompt, "[")
+ instructions := ""
+ if firstBracketPos > 0 {
+ instructions = strings.TrimSpace(prompt[:firstBracketPos])
+ }
+ shots := make([]string, 0, len(matches))
+ for i, match := range matches {
+ if len(match) < 3 {
+ continue
+ }
+ duration := strings.TrimSpace(match[1])
+ scene := strings.TrimSpace(match[2])
+ if scene == "" {
+ continue
+ }
+ shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
+ }
+ if len(shots) == 0 {
+ return prompt
+ }
+ timeline := strings.Join(shots, "\n\n")
+ if instructions == "" {
+ return timeline
+ }
+ return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
+}
+
+func extractRemixTargetIDFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
+}
+
+func cleanRemixLinkFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return prompt
+ }
+ cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
+ cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
+ cleaned = strings.Join(strings.Fields(cleaned), " ")
+ return strings.TrimSpace(cleaned)
+}
+
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
@@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
meta := parts[0]
payload := parts[1]
- decoded, err := base64.StdEncoding.DecodeString(payload)
+ decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
@@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
- decoded, err := base64.StdEncoding.DecodeString(raw)
+ decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
+func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
+ raw := strings.TrimSpace(input)
+ if raw == "" {
+ return nil, errors.New("empty video input")
+ }
+ if strings.HasPrefix(raw, "data:") {
+ parts := strings.SplitN(raw, ",", 2)
+ if len(parts) != 2 {
+ return nil, errors.New("invalid video data url")
+ }
+ decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+ }
+ if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+ return downloadSoraVideoInput(ctx, raw)
+ }
+ decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+}
+
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
- parsed, err := validateSoraImageURL(rawURL)
+ parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
@@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
- return validateSoraImageURLValue(req.URL)
+ return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
@@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
return data, filename, nil
}
-func validateSoraImageURL(raw string) (*url.URL, error) {
+func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
+ parsed, err := validateSoraRemoteURL(rawURL)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+ client := &http.Client{
+ Timeout: soraVideoInputTimeout,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= soraVideoInputMaxRedirects {
+ return errors.New("too many redirects")
+ }
+ return validateSoraRemoteURLValue(req.URL)
+ },
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
+ }
+ data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
+ if err != nil {
+ return nil, err
+ }
+ if len(data) == 0 {
+ return nil, errors.New("empty video content")
+ }
+ return data, nil
+}
+
+func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
+ if maxBytes <= 0 {
+ return nil, errors.New("invalid max bytes limit")
+ }
+ decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
+ limited := io.LimitReader(decoder, maxBytes+1)
+ data, err := io.ReadAll(limited)
+ if err != nil {
+ return nil, err
+ }
+ if int64(len(data)) > maxBytes {
+ return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
+ }
+ return data, nil
+}
+
+func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
- return nil, errors.New("empty image url")
+ return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
- return nil, fmt.Errorf("invalid image url: %w", err)
+ return nil, fmt.Errorf("invalid remote url: %w", err)
}
- if err := validateSoraImageURLValue(parsed); err != nil {
+ if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
-func validateSoraImageURLValue(parsed *url.URL) error {
+func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
- return errors.New("invalid image url")
+ return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
- return errors.New("only http/https image url is allowed")
+ return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
- return errors.New("image url cannot contain userinfo")
+ return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
- return errors.New("image url missing host")
+ return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
- return fmt.Errorf("resolve image url failed: %w", err)
+ return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
}
return nil
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
index 469a131e..c965901c 100644
--- a/backend/internal/service/sora_gateway_service_test.go
+++ b/backend/internal/service/sora_gateway_service_test.go
@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"strings"
@@ -25,6 +26,11 @@ type stubSoraClientForPoll struct {
videoCalls int
enhanced string
enhanceErr error
+ storyboard bool
+ videoReq SoraVideoRequest
+ parseErr error
+ postCalls int
+ deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
@@ -35,8 +41,54 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
+ s.videoReq = req
return "task-video", nil
}
+func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
+ s.storyboard = true
+ return "task-video", nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "cameo-1", nil
+}
+func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ return &SoraCameoStatus{
+ Status: "finalized",
+ StatusMessage: "Completed",
+ DisplayNameHint: "Character",
+ UsernameHint: "user.character",
+ ProfileAssetURL: "https://example.com/avatar.webp",
+ }, nil
+}
+func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
+ return []byte("avatar"), nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "asset-pointer", nil
+}
+func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
+ return "character-1", nil
+}
+func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
+ s.postCalls++
+ return "s_post", nil
+}
+func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
+ s.deleteCalls++
+ return nil
+}
+func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
+ if s.parseErr != nil {
+ return "", s.parseErr
+ }
+ return "https://example.com/no-watermark.mp4", nil
+}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
@@ -102,6 +154,109 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
+func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/v.mp4"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, client.storyboard)
+}
+
+func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
+ client := &stubSoraClientForPoll{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, 0, client.videoCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ parseErr: errors.New("parse failed"),
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 0, client.deleteCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 1, client.deleteCalls)
+}
+
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
@@ -119,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
}
service := NewSoraGatewayService(client, nil, nil, cfg)
- urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
+ status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
- require.Empty(t, urls)
+ require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
@@ -325,3 +480,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
+
+func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
+ data, err := decodeBase64WithLimit("aGVsbG8=", 3)
+ require.Error(t, err)
+ require.Nil(t, data)
+}
+
+func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
+ body := map[string]any{
+ "watermark_free": float64(1),
+ "watermark_fallback_on_failure": float64(0),
+ }
+ opts := parseSoraWatermarkOptions(body)
+ require.True(t, opts.Enabled)
+ require.False(t, opts.FallbackOnFailure)
+}
diff --git a/backend/internal/util/soraerror/soraerror.go b/backend/internal/util/soraerror/soraerror.go
new file mode 100644
index 00000000..17712c10
--- /dev/null
+++ b/backend/internal/util/soraerror/soraerror.go
@@ -0,0 +1,170 @@
+package soraerror
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "regexp"
+ "strings"
+)
+
+var (
+ cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
+ cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
+ htmlChallenge = []string{
+ "window._cf_chl_opt",
+ "just a moment",
+ "enable javascript and cookies to continue",
+ "__cf_chl_",
+ "challenge-platform",
+ }
+)
+
+// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
+func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
+ return false
+ }
+
+ if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
+ return true
+ }
+
+ preview := strings.ToLower(TruncateBody(body, 4096))
+ for _, marker := range htmlChallenge {
+ if strings.Contains(preview, marker) {
+ return true
+ }
+ }
+
+ contentType := ""
+ if headers != nil {
+ contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
+ }
+ if strings.Contains(contentType, "text/html") &&
+ (strings.Contains(preview, "= 2 {
+ return strings.TrimSpace(matches[1])
+ }
+ if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
+ return strings.TrimSpace(matches[1])
+ }
+ return ""
+}
+
+// FormatCloudflareChallengeMessage appends cf-ray info when available.
+func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
+ rayID := ExtractCloudflareRayID(headers, body)
+ if rayID == "" {
+ return base
+ }
+ return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
+}
+
+// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
+func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return "", ""
+ }
+ if !json.Valid([]byte(trimmed)) {
+ return "", truncateMessage(trimmed, 256)
+ }
+
+ var payload map[string]any
+ if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
+ return "", truncateMessage(trimmed, 256)
+ }
+
+ code := firstNonEmpty(
+ extractNestedString(payload, "error", "code"),
+ extractRootString(payload, "code"),
+ )
+ message := firstNonEmpty(
+ extractNestedString(payload, "error", "message"),
+ extractRootString(payload, "message"),
+ extractNestedString(payload, "error", "detail"),
+ extractRootString(payload, "detail"),
+ )
+ return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
+}
+
+// TruncateBody truncates body text for logging/inspection.
+func TruncateBody(body []byte, max int) string {
+ if max <= 0 {
+ max = 512
+ }
+ raw := strings.TrimSpace(string(body))
+ if len(raw) <= max {
+ return raw
+ }
+ return raw[:max] + "...(truncated)"
+}
+
+func truncateMessage(s string, max int) string {
+ if max <= 0 {
+ return ""
+ }
+ if len(s) <= max {
+ return s
+ }
+ return s[:max] + "...(truncated)"
+}
+
+func firstNonEmpty(values ...string) string {
+ for _, v := range values {
+ if strings.TrimSpace(v) != "" {
+ return v
+ }
+ }
+ return ""
+}
+
+func extractRootString(m map[string]any, key string) string {
+ if m == nil {
+ return ""
+ }
+ v, ok := m[key]
+ if !ok {
+ return ""
+ }
+ s, _ := v.(string)
+ return s
+}
+
+func extractNestedString(m map[string]any, parent, key string) string {
+ if m == nil {
+ return ""
+ }
+ node, ok := m[parent]
+ if !ok {
+ return ""
+ }
+ child, ok := node.(map[string]any)
+ if !ok {
+ return ""
+ }
+ s, _ := child[key].(string)
+ return s
+}
diff --git a/backend/internal/util/soraerror/soraerror_test.go b/backend/internal/util/soraerror/soraerror_test.go
new file mode 100644
index 00000000..4cf11169
--- /dev/null
+++ b/backend/internal/util/soraerror/soraerror_test.go
@@ -0,0 +1,47 @@
+package soraerror
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsCloudflareChallengeResponse(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-mitigated", "challenge")
+ require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
+
+ require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`Just a moment...`)))
+ require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`Just a moment...`)))
+}
+
+func TestExtractCloudflareRayID(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
+ require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
+
+ body := []byte(``)
+ require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
+}
+
+func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
+ code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
+ require.Equal(t, "cf_shield_429", code)
+ require.Equal(t, "rate limited", msg)
+
+ code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
+ require.Equal(t, "unsupported_country_code", code)
+ require.Equal(t, "not available", msg)
+
+ code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
+ require.Equal(t, "", code)
+ require.Equal(t, "plain text", msg)
+}
+
+func TestFormatCloudflareChallengeMessage(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-ray", "9d03b68c086027a1-SEA")
+ msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
+ require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
+}