feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理

This commit is contained in:
yangjianbo
2026-02-19 20:04:10 +08:00
parent 440b87094a
commit 40498aac9d
12 changed files with 1994 additions and 202 deletions

View File

@@ -1 +1 @@
0.1.83.1
0.1.83.2

View File

@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)

View File

@@ -12,7 +12,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
@@ -22,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
@@ -29,8 +29,6 @@ import (
"go.uber.org/zap"
)
var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
// SoraGatewayHandler handles Sora chat completions requests
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
@@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
if strings.EqualFold(upstreamCode, "cf_shield_429") {
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
}
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
switch statusCode {
case 401, 403, 404, 500, 502, 503, 504:
@@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
}
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
if strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue") ||
strings.Contains(preview, "__cf_chl_") ||
strings.Contains(preview, "challenge-platform") {
return true
}
contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
@@ -454,76 +436,11 @@ func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
}
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractSoraCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
func extractSoraCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := soraCloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !gjson.Valid(trimmed) {
return "", truncateSoraErrorMessage(trimmed, 256)
}
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
if code == "" {
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
}
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
}
if message == "" {
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
}
return code, truncateSoraErrorMessage(message, 512)
}
func truncateSoraErrorMessage(s string, maxLen int) string {
if maxLen <= 0 {
return ""
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...(truncated)"
}
func truncateSoraErrorBody(body []byte, maxLen int) string {
if maxLen <= 0 {
maxLen = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= maxLen {
return raw
}
return raw[:maxLen] + "...(truncated)"
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
}
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {

View File

@@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
return &service.SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
return nil
}
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
return nil
}
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
return "s_post", nil
}
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
return nil
}
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
return "enhanced prompt", nil
}
@@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T
require.Contains(t, msg, "Cloudflare challenge")
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
}
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
headers := http.Header{}
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
h := &SoraGatewayHandler{}
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "rate_limit_error", errorObj["type"])
msg, _ := errorObj["message"].(string)
require.Contains(t, msg, "Cloudflare shield")
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
}

View File

@@ -15,11 +15,14 @@ import (
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -28,7 +31,6 @@ import (
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
@@ -45,6 +47,9 @@ type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Status string `json:"status,omitempty"`
Code string `json:"code,omitempty"`
Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
@@ -56,8 +61,13 @@ type AccountTestService struct {
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
soraTestGuardMu sync.Mutex
soraTestLastRun map[int64]time.Time
soraTestCooldown time.Duration
}
const defaultSoraTestCooldown = 10 * time.Second
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -72,6 +82,8 @@ func NewAccountTestService(
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
soraTestLastRun: make(map[int64]time.Time),
soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -473,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
type soraProbeStep struct {
Name string `json:"name"`
Status string `json:"status"`
HTTPStatus int `json:"http_status,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
type soraProbeSummary struct {
Status string `json:"status"`
Steps []soraProbeStep `json:"steps"`
}
type soraProbeRecorder struct {
steps []soraProbeStep
}
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
r.steps = append(r.steps, soraProbeStep{
Name: name,
Status: status,
HTTPStatus: httpStatus,
ErrorCode: strings.TrimSpace(errorCode),
Message: strings.TrimSpace(message),
})
}
func (r *soraProbeRecorder) finalize() soraProbeSummary {
meSuccess := false
partial := false
for _, step := range r.steps {
if step.Name == "me" {
meSuccess = strings.EqualFold(step.Status, "success")
continue
}
if strings.EqualFold(step.Status, "failed") {
partial = true
}
}
status := "success"
if !meSuccess {
status = "failed"
} else if partial {
status = "partial_success"
}
return soraProbeSummary{
Status: status,
Steps: append([]soraProbeStep(nil), r.steps...),
}
}
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
if rec == nil {
return
}
summary := rec.finalize()
code := ""
for _, step := range summary.Steps {
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
code = step.ErrorCode
break
}
}
s.sendEvent(c, TestEvent{
Type: "sora_test_result",
Status: summary.Status,
Code: code,
Data: summary,
})
}
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
if accountID <= 0 {
return 0, true
}
s.soraTestGuardMu.Lock()
defer s.soraTestGuardMu.Unlock()
if s.soraTestLastRun == nil {
s.soraTestLastRun = make(map[int64]time.Time)
}
cooldown := s.soraTestCooldown
if cooldown <= 0 {
cooldown = defaultSoraTestCooldown
}
now := time.Now()
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
elapsed := now.Sub(lastRun)
if elapsed < cooldown {
return cooldown - elapsed, false
}
}
s.soraTestLastRun[accountID] = now
return 0, true
}
func ceilSeconds(d time.Duration) int {
if d <= 0 {
return 1
}
sec := int(d / time.Second)
if d%time.Second != 0 {
sec++
}
if sec < 1 {
sec = 1
}
return sec
}
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
@@ -490,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, msg)
}
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
@@ -515,6 +652,8 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
recorder.addStep("me", "failed", 0, "network_error", err.Error())
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
@@ -522,12 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
if isCloudflareChallengeResponse(resp.StatusCode, body) {
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.emitSoraProbeSummary(c, recorder)
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
switch {
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora token 已失效token_invalidated请重新授权账号")
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用unsupported_country_code请切换到支持地区后重试")
case strings.TrimSpace(upstreamMessage) != "":
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
default:
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
}
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
@@ -557,21 +717,26 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if subErr != nil {
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
} else {
subBody, _ := io.ReadAll(subResp.Body)
_ = subResp.Body.Close()
if subResp.StatusCode == http.StatusOK {
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
}
} else {
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
} else {
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
}
}
@@ -579,8 +744,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
}
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint)
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
@@ -592,6 +758,7 @@ func (s *AccountTestService) testSora2Capabilities(
authToken string,
proxyURL string,
enableTLSFingerprint bool,
recorder *soraProbeRecorder,
) {
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
ctx,
@@ -602,6 +769,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
return
}
@@ -616,6 +786,9 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
if recorder != nil {
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
}
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
ctx,
@@ -626,20 +799,42 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if err != nil {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
return
}
} else if recorder != nil {
code := ""
msg := ""
if bootstrapErr != nil {
code = "network_error"
msg = bootstrapErr.Error()
}
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
}
}
if inviteStatus != http.StatusOK {
if isCloudflareChallengeResponse(inviteStatus, inviteBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)})
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
if recorder != nil {
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
}
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
@@ -656,17 +851,31 @@ func (s *AccountTestService) testSora2Capabilities(
enableTLSFingerprint,
)
if remainingErr != nil {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
return
}
if remainingStatus != http.StatusOK {
if isCloudflareChallengeResponse(remainingStatus, remainingBody) {
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)})
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
}
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
return
}
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
if recorder != nil {
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
}
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
return
}
if recorder != nil {
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
}
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
} else {
@@ -789,42 +998,16 @@ func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
return !s.cfg.Sora.Client.DisableTLSFingerprint
}
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
if statusCode != http.StatusForbidden {
return false
}
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
return strings.Contains(preview, "window._cf_chl_opt") ||
strings.Contains(preview, "just a moment") ||
strings.Contains(preview, "enable javascript and cookies to continue")
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
}
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := extractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
}
func extractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := truncateSoraErrorBody(body, 8192)
matches := cloudflareRayPattern.FindStringSubmatch(preview)
if len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
return soraerror.ExtractCloudflareRayID(headers, body)
}
func extractSoraEgressIPHint(headers http.Header) string {
@@ -897,14 +1080,7 @@ func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyU
}
func truncateSoraErrorBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
return soraerror.TruncateBody(body, max)
}
// testAntigravityAccountConnection tests an Antigravity account's connection

View File

@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
@@ -109,6 +110,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
@@ -141,6 +144,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
@@ -173,6 +178,97 @@ func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *tes
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{

View File

@@ -95,6 +95,16 @@ var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
var soraMobileUserAgents = []string{
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
"Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
"Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
"Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
"Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
"Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
}
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
@@ -106,6 +116,17 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
DeletePost(ctx context.Context, account *Account, postID string) error
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
@@ -128,6 +149,17 @@ type SoraVideoRequest struct {
Size string
MediaID string
RemixTargetID string
CameoIDs []string
}
// SoraStoryboardRequest 分镜视频生成请求参数
type SoraStoryboardRequest struct {
Prompt string
Orientation string
Frames int
Model string
Size string
MediaID string
}
// SoraImageTaskStatus 图片任务状态
@@ -141,11 +173,32 @@ type SoraImageTaskStatus struct {
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
ID string
Status string
ProgressPct int
URLs []string
ErrorMsg string
ID string
Status string
ProgressPct int
URLs []string
GenerationID string
ErrorMsg string
}
// SoraCameoStatus 角色处理中间态
type SoraCameoStatus struct {
Status string
StatusMessage string
DisplayNameHint string
UsernameHint string
ProfileAssetURL string
InstructionSetHint any
InstructionSet any
}
// SoraCharacterFinalizeRequest 角色定稿请求参数
type SoraCharacterFinalizeRequest struct {
CameoID string
Username string
DisplayName string
ProfileAssetPointer string
InstructionSet any
}
// SoraUpstreamError 上游错误
@@ -407,6 +460,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
} else if len(req.CameoIDs) > 0 {
payload["cameo_ids"] = req.CameoIDs
payload["cameo_replacements"] = map[string]any{}
}
headers := c.buildBaseHeaders(token, userAgent)
@@ -434,6 +490,425 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
}
nFrames := req.Frames
if nFrames <= 0 {
nFrames = 450
}
model := req.Model
if model == "" {
model = "sy_8"
}
size := req.Size
if size == "" {
size = "small"
}
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
inpaintItems = append(inpaintItems, map[string]any{
"kind": "upload",
"upload_id": req.MediaID,
})
}
payload := map[string]any{
"kind": "video",
"prompt": req.Prompt,
"title": "Draft your video",
"orientation": orientation,
"size": size,
"n_frames": nFrames,
"storyboard_id": nil,
"inpaint_items": inpaintItems,
"remix_target_id": nil,
"model": model,
"metadata": nil,
"style_id": nil,
"cameo_ids": nil,
"cameo_replacements": nil,
"audio_caption": nil,
"audio_transcript": nil,
"video_caption": nil,
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("storyboard task response missing id")
}
return taskID, nil
}
func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty video data")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
partHeader.Set("Content-Type", "video/mp4")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("timestamps", "0,3"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
if err != nil {
return "", err
}
cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if cameoID == "" {
return "", errors.New("character upload response missing id")
}
return cameoID, nil
}
func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return &SoraCameoStatus{
Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
}, nil
}
func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Accept", "image/*,*/*;q=0.8")
respBody, _, err := c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodGet,
strings.TrimSpace(imageURL),
headers,
nil,
false,
)
if err != nil {
return nil, err
}
return respBody, nil
}
func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
if len(data) == 0 {
return "", errors.New("empty character image")
}
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
var body bytes.Buffer
writer := multipart.NewWriter(&body)
partHeader := make(textproto.MIMEHeader)
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
partHeader.Set("Content-Type", "image/webp")
part, err := writer.CreatePart(partHeader)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := writer.WriteField("use_case", "profile"); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
if err != nil {
return "", err
}
assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
if assetPointer == "" {
return "", errors.New("character image upload response missing asset_pointer")
}
return assetPointer, nil
}
func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{
"cameo_id": req.CameoID,
"username": req.Username,
"display_name": req.DisplayName,
"profile_asset_pointer": req.ProfileAssetPointer,
"instruction_set": nil,
"safety_instruction_set": nil,
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
if err != nil {
return "", err
}
characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
if characterID == "" {
return "", errors.New("character finalize response missing character_id")
}
return characterID, nil
}
func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{"visibility": "public"}
body, err := json.Marshal(payload)
if err != nil {
return err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodPost,
c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
headers,
bytes.NewReader(body),
false,
)
return err
}
func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return "", err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
payload := map[string]any{
"attachments_to_create": []map[string]any{
{
"generation_id": generationID,
"kind": "sora",
},
},
"post_text": "",
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
if postID == "" {
return "", errors.New("watermark-free publish response missing post.id")
}
return postID, nil
}
func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return err
}
userAgent := c.taskUserAgent()
proxyURL := c.resolveProxyURL(account)
headers := c.buildBaseHeaders(token, userAgent)
_, _, err = c.doRequestWithProxy(
ctx,
account,
proxyURL,
http.MethodDelete,
c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
headers,
nil,
false,
)
return err
}
func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
if parseURL == "" {
return "", errors.New("custom parse url is required")
}
if strings.TrimSpace(parseToken) == "" {
return "", errors.New("custom parse token is required")
}
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
payload := map[string]any{
"url": shareURL,
"token": strings.TrimSpace(parseToken),
}
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
proxyURL := c.resolveProxyURL(account)
accountID := int64(0)
accountConcurrency := 0
if account != nil {
accountID = account.ID
accountConcurrency = account.Concurrency
}
var resp *http.Response
if c.httpUpstream != nil {
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
} else {
resp, err = http.DefaultClient.Do(req)
}
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
}
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
if downloadLink == "" {
return "", errors.New("custom parse response missing download_link")
}
return downloadLink, nil
}
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
@@ -607,6 +1082,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if draft.Get("task_id").String() != taskID {
return true
}
generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
@@ -623,15 +1099,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
ID: taskID,
Status: "failed",
GenerationID: generationID,
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
ID: taskID,
Status: "completed",
GenerationID: generationID,
URLs: []string{urlStr},
}
}
return false
@@ -675,8 +1153,11 @@ func (c *SoraDirectClient) taskUserAgent() string {
return ua
}
}
if len(soraMobileUserAgents) > 0 {
return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
}
if len(soraDesktopUserAgents) > 0 {
return soraDesktopUserAgents[0]
return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
}
return soraDefaultUserAgent
}
@@ -1149,10 +1630,7 @@ func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
}
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
enableTLS := true
if c != nil && c.cfg != nil && c.cfg.Sora.Client.DisableTLSFingerprint {
enableTLS = false
}
enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
@@ -1288,6 +1766,15 @@ func soraRandFloat() float64 {
return soraRand.Float64()
}
func soraRandInt(max int) int {
if max <= 1 {
return 0
}
soraRandMu.Lock()
defer soraRandMu.Unlock()
return soraRand.Intn(max)
}
func soraBuildPowConfig(userAgent string) []any {
userAgent = strings.TrimSpace(userAgent)
if userAgent == "" && len(soraDesktopUserAgents) > 0 {

View File

@@ -393,10 +393,22 @@ func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL stri
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
case "/backend/nf/create":
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
case "/backend/nf/create/storyboard":
return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
case "/backend/uploads":
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
case "/backend/nf/check":
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
case "/backend/characters/upload":
return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
case "/backend/project_y/cameos/in_progress/cameo-123":
return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
case "/backend/project_y/file/upload":
return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
case "/backend/characters/finalize":
return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
case "/backend/project_y/post":
return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
default:
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
}
@@ -410,9 +422,13 @@ func newSoraClientMockResponse(statusCode int, body string) *http.Response {
}
}
func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) {
func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
client := NewSoraDirectClient(&config.Config{}, nil, nil)
require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent())
ua := client.taskUserAgent()
require.NotEmpty(t, ua)
allowed := append([]string{}, soraMobileUserAgents...)
allowed = append(allowed, soraDesktopUserAgents...)
require.Contains(t, allowed, ua)
}
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
@@ -460,7 +476,7 @@ func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAn
require.Equal(t, "/backend/nf/create", createCall.Path)
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent)
require.NotEmpty(t, sentinelCall.UserAgent)
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
}
@@ -495,7 +511,7 @@ func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
@@ -528,5 +544,98 @@ func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T)
require.Len(t, upstream.calls, 1)
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
require.NotEmpty(t, upstream.calls[0].UserAgent)
}
func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 51,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
})
require.NoError(t, err)
require.Equal(t, "storyboard-123", taskID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
}
func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/nf/pending/v2":
_, _ = w.Write([]byte(`[]`))
case "/project_y/profile/drafts":
_, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetVideoTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, "gen_1", status.GenerationID)
require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
}
func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
originPowTokenGenerator := soraPowTokenGenerator
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
upstream := &soraClientRecordingUpstream{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: "https://sora.chatgpt.com/backend",
},
},
}
client := NewSoraDirectClient(cfg, upstream, nil)
account := &Account{
ID: 52,
Credentials: map[string]any{
"access_token": "access-token",
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
},
}
postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
require.NoError(t, err)
require.Equal(t, "s_post", postID)
require.Len(t, upstream.calls, 2)
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
}

View File

@@ -8,10 +8,12 @@ import (
"fmt"
"io"
"log"
"math"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
@@ -23,6 +25,9 @@ import (
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
const soraVideoInputMaxBytes = 200 << 20
const soraVideoInputMaxRedirects = 3
const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -61,6 +66,32 @@ type SoraGatewayService struct {
cfg *config.Config
}
type soraWatermarkOptions struct {
Enabled bool
ParseMethod string
ParseURL string
ParseToken string
FallbackOnFailure bool
DeletePost bool
}
type soraCharacterOptions struct {
SetPublic bool
DeleteAfterGenerate bool
}
type soraCharacterFlowResult struct {
CameoID string
CharacterID string
Username string
DisplayName string
}
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
@@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" {
prompt = strings.TrimSpace(prompt)
imageInput = strings.TrimSpace(imageInput)
videoInput = strings.TrimSpace(videoInput)
remixTargetID = strings.TrimSpace(remixTargetID)
if videoInput != "" && modelCfg.Type != "video" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
return nil, errors.New("video input only supports video models")
}
if videoInput != "" && imageInput != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
return nil, errors.New("image input and video input cannot be used together")
}
characterOnly := videoInput != "" && prompt == ""
if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
if strings.TrimSpace(videoInput) != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
return nil, errors.New("video input not supported")
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
@@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}, nil
}
characterOpts := parseSoraCharacterOptions(reqBody)
watermarkOpts := parseSoraWatermarkOptions(reqBody)
var characterResult *soraCharacterFlowResult
if videoInput != "" {
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
if videoErr != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
return nil, videoErr
}
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
if videoErr != nil {
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
}
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
characterID := strings.TrimSpace(characterResult.CharacterID)
defer func() {
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
defer cancelCleanup()
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
}
}()
}
if characterOnly {
content := "角色创建成功"
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
resp := buildSoraNonStreamResponse(content, reqModel)
if characterResult != nil {
resp["character_id"] = characterResult.CharacterID
resp["cameo_id"] = characterResult.CameoID
resp["character_username"] = characterResult.Username
resp["character_display_name"] = characterResult.DisplayName
}
c.JSON(http.StatusOK, resp)
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
}
}
var imageData []byte
imageFilename := ""
if strings.TrimSpace(imageInput) != "" {
if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
@@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
MediaID: mediaID,
})
case "video":
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
})
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
Prompt: formatSoraStoryboardPrompt(prompt),
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
})
} else {
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
})
}
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
@@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
var mediaURLs []string
videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
@@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
if videoStatus != nil {
mediaURLs = videoStatus.URLs
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
}
default:
mediaType = "prompt"
}
watermarkPostID := ""
if modelCfg.Type == "video" && watermarkOpts.Enabled {
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
if watermarkErr != nil {
if !watermarkOpts.FallbackOnFailure {
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
}
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
} else if strings.TrimSpace(watermarkURL) != "" {
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
watermarkPostID = strings.TrimSpace(postID)
}
}
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
@@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
}
}
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
@@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
opts := soraWatermarkOptions{
Enabled: parseBoolWithDefault(body, "watermark_free", false),
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
}
if opts.ParseMethod == "" {
opts.ParseMethod = "third_party"
}
return opts
}
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
return soraCharacterOptions{
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
}
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case bool:
return typed
case int:
return typed != 0
case int32:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
case string:
typed = strings.ToLower(strings.TrimSpace(typed))
if typed == "true" || typed == "1" || typed == "yes" {
return true
}
if typed == "false" || typed == "0" || typed == "no" {
return false
}
}
return def
}
func parseStringWithDefault(body map[string]any, key, def string) string {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
if str, ok := val.(string); ok {
return str
}
return def
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
}
raw, ok := body["cameo_ids"]
if !ok {
return nil
}
switch typed := raw.(type) {
case []string:
out := make([]string, 0, len(typed))
for _, item := range typed {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(typed))
for _, item := range typed {
str, ok := item.(string)
if !ok {
continue
}
str = strings.TrimSpace(str)
if str != "" {
out = append(out, str)
}
}
return out
default:
return nil
}
}
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
if err != nil {
return nil, err
}
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
if err != nil {
return nil, err
}
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
if displayName == "" {
displayName = "Character"
}
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
if profileAssetURL == "" {
return nil, errors.New("profile asset url not found in cameo status")
}
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
if err != nil {
return nil, err
}
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
if err != nil {
return nil, err
}
instructionSet := cameoStatus.InstructionSetHint
if instructionSet == nil {
instructionSet = cameoStatus.InstructionSet
}
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
CameoID: strings.TrimSpace(cameoID),
Username: username,
DisplayName: displayName,
ProfileAssetPointer: assetPointer,
InstructionSet: instructionSet,
})
if err != nil {
return nil, err
}
if opts.SetPublic {
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
return nil, err
}
}
return &soraCharacterFlowResult{
CameoID: strings.TrimSpace(cameoID),
CharacterID: strings.TrimSpace(characterID),
Username: strings.TrimSpace(username),
DisplayName: displayName,
}, nil
}
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
timeout := 10 * time.Minute
interval := 5 * time.Second
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
if maxAttempts < 1 {
maxAttempts = 1
}
var lastErr error
consecutiveErrors := 0
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
if err != nil {
lastErr = err
consecutiveErrors++
if consecutiveErrors >= 3 {
break
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
consecutiveErrors = 0
if status == nil {
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
statusMessage := strings.TrimSpace(status.StatusMessage)
if currentStatus == "failed" {
if statusMessage == "" {
statusMessage = "character creation failed"
}
return nil, errors.New(statusMessage)
}
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
return status, nil
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
}
if lastErr != nil {
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
}
return nil, errors.New("cameo processing timeout")
}
func processSoraCharacterUsername(usernameHint string) string {
usernameHint = strings.TrimSpace(usernameHint)
if usernameHint == "" {
usernameHint = "character"
}
if strings.Contains(usernameHint, ".") {
parts := strings.Split(usernameHint, ".")
usernameHint = strings.TrimSpace(parts[len(parts)-1])
}
if usernameHint == "" {
usernameHint = "character"
}
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
}
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
generationID = strings.TrimSpace(generationID)
if generationID == "" {
return "", "", errors.New("generation id is required for watermark-free mode")
}
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
if err != nil {
return "", "", err
}
postID = strings.TrimSpace(postID)
if postID == "" {
return "", "", errors.New("watermark-free publish returned empty post id")
}
switch opts.ParseMethod {
case "custom":
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
if parseErr != nil {
return "", postID, parseErr
}
return strings.TrimSpace(urlVal), postID, nil
case "", "third_party":
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
default:
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
}
}
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 404, 429, 529:
@@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
return nil, errors.New("sora image generation timeout")
}
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
@@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
return status.URLs, nil
return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
@@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
remixTargetID = v
remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
@@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
prompt = builder.String()
}
}
if remixTargetID == "" {
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
}
prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
@@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
}
}
func isSoraStoryboardPrompt(prompt string) bool {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return false
}
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
}
func formatSoraStoryboardPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
if len(matches) == 0 {
return prompt
}
firstBracketPos := strings.Index(prompt, "[")
instructions := ""
if firstBracketPos > 0 {
instructions = strings.TrimSpace(prompt[:firstBracketPos])
}
shots := make([]string, 0, len(matches))
for i, match := range matches {
if len(match) < 3 {
continue
}
duration := strings.TrimSpace(match[1])
scene := strings.TrimSpace(match[2])
if scene == "" {
continue
}
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
}
if len(shots) == 0 {
return prompt
}
timeline := strings.Join(shots, "\n\n")
if instructions == "" {
return timeline
}
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
}
func extractRemixTargetIDFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
}
func cleanRemixLinkFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return prompt
}
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
cleaned = strings.Join(strings.Fields(cleaned), " ")
return strings.TrimSpace(cleaned)
}
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
@@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
meta := parts[0]
payload := parts[1]
decoded, err := base64.StdEncoding.DecodeString(payload)
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
@@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
decoded, err := base64.StdEncoding.DecodeString(raw)
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, errors.New("empty video input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, errors.New("invalid video data url")
}
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraVideoInput(ctx, raw)
}
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
parsed, err := validateSoraImageURL(rawURL)
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
@@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraImageURLValue(req.URL)
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
@@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
return data, filename, nil
}
func validateSoraImageURL(raw string) (*url.URL, error) {
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, err
}
client := &http.Client{
Timeout: soraVideoInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraVideoInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, errors.New("empty video content")
}
return data, nil
}
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
if maxBytes <= 0 {
return nil, errors.New("invalid max bytes limit")
}
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
limited := io.LimitReader(decoder, maxBytes+1)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
}
return data, nil
}
func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty image url")
return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid image url: %w", err)
return nil, fmt.Errorf("invalid remote url: %w", err)
}
if err := validateSoraImageURLValue(parsed); err != nil {
if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraImageURLValue(parsed *url.URL) error {
func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid image url")
return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https image url is allowed")
return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
return errors.New("image url cannot contain userinfo")
return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("image url missing host")
return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve image url failed: %w", err)
return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
}
return nil

View File

@@ -5,6 +5,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -25,6 +26,11 @@ type stubSoraClientForPoll struct {
videoCalls int
enhanced string
enhanceErr error
storyboard bool
videoReq SoraVideoRequest
parseErr error
postCalls int
deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
@@ -35,8 +41,54 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
s.videoReq = req
return "task-video", nil
}
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
s.storyboard = true
return "task-video", nil
}
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
return "cameo-1", nil
}
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
return &SoraCameoStatus{
Status: "finalized",
StatusMessage: "Completed",
DisplayNameHint: "Character",
UsernameHint: "user.character",
ProfileAssetURL: "https://example.com/avatar.webp",
}, nil
}
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
return []byte("avatar"), nil
}
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
return "asset-pointer", nil
}
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
return "character-1", nil
}
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
return nil
}
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
return nil
}
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
s.postCalls++
return "s_post", nil
}
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
s.deleteCalls++
return nil
}
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
if s.parseErr != nil {
return "", s.parseErr
}
return "https://example.com/no-watermark.mp4", nil
}
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
if s.enhanced != "" {
return s.enhanced, s.enhanceErr
@@ -102,6 +154,109 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
require.Equal(t, "prompt-enhance-short-10s", result.Model)
}
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/v.mp4"},
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.True(t, client.storyboard)
}
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType)
require.Equal(t, 0, client.videoCalls)
}
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
parseErr: errors.New("parse failed"),
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 0, client.deleteCalls)
}
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/original.mp4"},
GenerationID: "gen_1",
},
}
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
svc := NewSoraGatewayService(client, nil, nil, cfg)
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
result, err := svc.Forward(context.Background(), nil, account, body, false)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
require.Equal(t, 1, client.postCalls)
require.Equal(t, 1, client.deleteCalls)
}
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
@@ -119,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
}
service := NewSoraGatewayService(client, nil, nil, cfg)
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
require.Empty(t, urls)
require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
@@ -325,3 +480,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
require.Error(t, err)
require.Nil(t, data)
}
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
body := map[string]any{
"watermark_free": float64(1),
"watermark_fallback_on_failure": float64(0),
}
opts := parseSoraWatermarkOptions(body)
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}

View File

@@ -0,0 +1,170 @@
package soraerror
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
)
var (
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
htmlChallenge = []string{
"window._cf_chl_opt",
"just a moment",
"enable javascript and cookies to continue",
"__cf_chl_",
"challenge-platform",
}
)
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
return false
}
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
return true
}
preview := strings.ToLower(TruncateBody(body, 4096))
for _, marker := range htmlChallenge {
if strings.Contains(preview, marker) {
return true
}
}
contentType := ""
if headers != nil {
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
}
if strings.Contains(contentType, "text/html") &&
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
return true
}
return false
}
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
if headers != nil {
rayID := strings.TrimSpace(headers.Get("cf-ray"))
if rayID != "" {
return rayID
}
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
if rayID != "" {
return rayID
}
}
preview := TruncateBody(body, 8192)
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
return strings.TrimSpace(matches[1])
}
return ""
}
// FormatCloudflareChallengeMessage appends cf-ray info when available.
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
rayID := ExtractCloudflareRayID(headers, body)
if rayID == "" {
return base
}
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
}
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
trimmed := strings.TrimSpace(string(body))
if trimmed == "" {
return "", ""
}
if !json.Valid([]byte(trimmed)) {
return "", truncateMessage(trimmed, 256)
}
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
return "", truncateMessage(trimmed, 256)
}
code := firstNonEmpty(
extractNestedString(payload, "error", "code"),
extractRootString(payload, "code"),
)
message := firstNonEmpty(
extractNestedString(payload, "error", "message"),
extractRootString(payload, "message"),
extractNestedString(payload, "error", "detail"),
extractRootString(payload, "detail"),
)
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
}
// TruncateBody truncates body text for logging/inspection.
func TruncateBody(body []byte, max int) string {
if max <= 0 {
max = 512
}
raw := strings.TrimSpace(string(body))
if len(raw) <= max {
return raw
}
return raw[:max] + "...(truncated)"
}
func truncateMessage(s string, max int) string {
if max <= 0 {
return ""
}
if len(s) <= max {
return s
}
return s[:max] + "...(truncated)"
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}
func extractRootString(m map[string]any, key string) string {
if m == nil {
return ""
}
v, ok := m[key]
if !ok {
return ""
}
s, _ := v.(string)
return s
}
func extractNestedString(m map[string]any, parent, key string) string {
if m == nil {
return ""
}
node, ok := m[parent]
if !ok {
return ""
}
child, ok := node.(map[string]any)
if !ok {
return ""
}
s, _ := child[key].(string)
return s
}

View File

@@ -0,0 +1,47 @@
package soraerror
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsCloudflareChallengeResponse(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-mitigated", "challenge")
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
}
func TestExtractCloudflareRayID(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
}
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
require.Equal(t, "cf_shield_429", code)
require.Equal(t, "rate limited", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
require.Equal(t, "unsupported_country_code", code)
require.Equal(t, "not available", msg)
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
require.Equal(t, "", code)
require.Equal(t, "plain text", msg)
}
func TestFormatCloudflareChallengeMessage(t *testing.T) {
headers := make(http.Header)
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
}