feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.83.1
|
||||
0.1.83.2
|
||||
|
||||
@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
|
||||
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
|
||||
soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
|
||||
soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
|
||||
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
|
||||
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
|
||||
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -29,8 +29,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var soraCloudflareRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
@@ -385,6 +383,10 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
@@ -416,27 +418,7 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders ht
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||
return true
|
||||
}
|
||||
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
|
||||
if strings.Contains(preview, "window._cf_chl_opt") ||
|
||||
strings.Contains(preview, "just a moment") ||
|
||||
strings.Contains(preview, "enable javascript and cookies to continue") ||
|
||||
strings.Contains(preview, "__cf_chl_") ||
|
||||
strings.Contains(preview, "challenge-platform") {
|
||||
return true
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||
if strings.Contains(contentType, "text/html") &&
|
||||
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
@@ -454,76 +436,11 @@ func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := extractSoraCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
}
|
||||
|
||||
func extractSoraCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
preview := truncateSoraErrorBody(body, 8192)
|
||||
matches := soraCloudflareRayPattern.FindStringSubmatch(preview)
|
||||
if len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if !gjson.Valid(trimmed) {
|
||||
return "", truncateSoraErrorMessage(trimmed, 256)
|
||||
}
|
||||
code := strings.TrimSpace(gjson.Get(trimmed, "error.code").String())
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(gjson.Get(trimmed, "code").String())
|
||||
}
|
||||
message := strings.TrimSpace(gjson.Get(trimmed, "error.message").String())
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(gjson.Get(trimmed, "message").String())
|
||||
}
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(gjson.Get(trimmed, "error.detail").String())
|
||||
}
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(gjson.Get(trimmed, "detail").String())
|
||||
}
|
||||
return code, truncateSoraErrorMessage(message, 512)
|
||||
}
|
||||
|
||||
func truncateSoraErrorMessage(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
func truncateSoraErrorBody(body []byte, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
maxLen = 512
|
||||
}
|
||||
raw := strings.TrimSpace(string(body))
|
||||
if len(raw) <= maxLen {
|
||||
return raw
|
||||
}
|
||||
return raw[:maxLen] + "...(truncated)"
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
|
||||
@@ -43,6 +43,45 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
@@ -607,3 +646,31 @@ func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
@@ -15,11 +15,14 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -28,7 +31,6 @@ import (
|
||||
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
var cloudflareRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
@@ -45,6 +47,9 @@ type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
@@ -56,8 +61,13 @@ type AccountTestService struct {
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
soraTestGuardMu sync.Mutex
|
||||
soraTestLastRun map[int64]time.Time
|
||||
soraTestCooldown time.Duration
|
||||
}
|
||||
|
||||
const defaultSoraTestCooldown = 10 * time.Second
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
@@ -72,6 +82,8 @@ func NewAccountTestService(
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
soraTestLastRun: make(map[int64]time.Time),
|
||||
soraTestCooldown: defaultSoraTestCooldown,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
type soraProbeStep struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
HTTPStatus int `json:"http_status,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type soraProbeSummary struct {
|
||||
Status string `json:"status"`
|
||||
Steps []soraProbeStep `json:"steps"`
|
||||
}
|
||||
|
||||
type soraProbeRecorder struct {
|
||||
steps []soraProbeStep
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
|
||||
r.steps = append(r.steps, soraProbeStep{
|
||||
Name: name,
|
||||
Status: status,
|
||||
HTTPStatus: httpStatus,
|
||||
ErrorCode: strings.TrimSpace(errorCode),
|
||||
Message: strings.TrimSpace(message),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *soraProbeRecorder) finalize() soraProbeSummary {
|
||||
meSuccess := false
|
||||
partial := false
|
||||
for _, step := range r.steps {
|
||||
if step.Name == "me" {
|
||||
meSuccess = strings.EqualFold(step.Status, "success")
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(step.Status, "failed") {
|
||||
partial = true
|
||||
}
|
||||
}
|
||||
|
||||
status := "success"
|
||||
if !meSuccess {
|
||||
status = "failed"
|
||||
} else if partial {
|
||||
status = "partial_success"
|
||||
}
|
||||
|
||||
return soraProbeSummary{
|
||||
Status: status,
|
||||
Steps: append([]soraProbeStep(nil), r.steps...),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
|
||||
if rec == nil {
|
||||
return
|
||||
}
|
||||
summary := rec.finalize()
|
||||
code := ""
|
||||
for _, step := range summary.Steps {
|
||||
if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
|
||||
code = step.ErrorCode
|
||||
break
|
||||
}
|
||||
}
|
||||
s.sendEvent(c, TestEvent{
|
||||
Type: "sora_test_result",
|
||||
Status: summary.Status,
|
||||
Code: code,
|
||||
Data: summary,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
|
||||
if accountID <= 0 {
|
||||
return 0, true
|
||||
}
|
||||
s.soraTestGuardMu.Lock()
|
||||
defer s.soraTestGuardMu.Unlock()
|
||||
|
||||
if s.soraTestLastRun == nil {
|
||||
s.soraTestLastRun = make(map[int64]time.Time)
|
||||
}
|
||||
cooldown := s.soraTestCooldown
|
||||
if cooldown <= 0 {
|
||||
cooldown = defaultSoraTestCooldown
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if lastRun, ok := s.soraTestLastRun[accountID]; ok {
|
||||
elapsed := now.Sub(lastRun)
|
||||
if elapsed < cooldown {
|
||||
return cooldown - elapsed, false
|
||||
}
|
||||
}
|
||||
s.soraTestLastRun[accountID] = now
|
||||
return 0, true
|
||||
}
|
||||
|
||||
func ceilSeconds(d time.Duration) int {
|
||||
if d <= 0 {
|
||||
return 1
|
||||
}
|
||||
sec := int(d / time.Second)
|
||||
if d%time.Second != 0 {
|
||||
sec++
|
||||
}
|
||||
if sec < 1 {
|
||||
sec = 1
|
||||
}
|
||||
return sec
|
||||
}
|
||||
|
||||
// testSoraAccountConnection 测试 Sora 账号的连接
|
||||
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
|
||||
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
|
||||
ctx := c.Request.Context()
|
||||
recorder := &soraProbeRecorder{}
|
||||
|
||||
authToken := account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
@@ -490,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
|
||||
msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
|
||||
recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, msg)
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
@@ -515,6 +652,8 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if err != nil {
|
||||
recorder.addStep("me", "failed", 0, "network_error", err.Error())
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -522,12 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(resp.StatusCode, body) {
|
||||
if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
|
||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage("Sora request blocked by Cloudflare challenge (HTTP 403). Please switch to a clean proxy/network and retry.", resp.Header, body))
|
||||
return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
|
||||
case strings.EqualFold(upstreamCode, "unsupported_country_code"):
|
||||
recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
|
||||
case strings.TrimSpace(upstreamMessage) != "":
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
|
||||
default:
|
||||
recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
|
||||
}
|
||||
recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
|
||||
|
||||
// 解析 /me 响应,提取用户信息
|
||||
var meResp map[string]any
|
||||
@@ -557,21 +717,26 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
|
||||
subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
|
||||
if subErr != nil {
|
||||
recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
|
||||
} else {
|
||||
subBody, _ := io.ReadAll(subResp.Body)
|
||||
_ = subResp.Body.Close()
|
||||
if subResp.StatusCode == http.StatusOK {
|
||||
recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
|
||||
if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
|
||||
}
|
||||
} else {
|
||||
if isCloudflareChallengeResponse(subResp.StatusCode, subBody) {
|
||||
if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Subscription check blocked by Cloudflare challenge (HTTP 403)", subResp.Header, subBody)})
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
|
||||
} else {
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
|
||||
recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
|
||||
}
|
||||
}
|
||||
@@ -579,8 +744,9 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
|
||||
}
|
||||
|
||||
// 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
|
||||
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint)
|
||||
s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
|
||||
|
||||
s.emitSoraProbeSummary(c, recorder)
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
@@ -592,6 +758,7 @@ func (s *AccountTestService) testSora2Capabilities(
|
||||
authToken string,
|
||||
proxyURL string,
|
||||
enableTLSFingerprint bool,
|
||||
recorder *soraProbeRecorder,
|
||||
) {
|
||||
inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
@@ -602,6 +769,9 @@ func (s *AccountTestService) testSora2Capabilities(
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
@@ -616,6 +786,9 @@ func (s *AccountTestService) testSora2Capabilities(
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
|
||||
inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
|
||||
ctx,
|
||||
@@ -626,20 +799,42 @@ func (s *AccountTestService) testSora2Capabilities(
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
|
||||
return
|
||||
}
|
||||
} else if recorder != nil {
|
||||
code := ""
|
||||
msg := ""
|
||||
if bootstrapErr != nil {
|
||||
code = "network_error"
|
||||
msg = bootstrapErr.Error()
|
||||
}
|
||||
recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if inviteStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(inviteStatus, inviteBody) {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 invite check blocked by Cloudflare challenge (HTTP 403)", inviteHeader, inviteBody)})
|
||||
if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
|
||||
}
|
||||
|
||||
if summary := parseSoraInviteSummary(inviteBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
@@ -656,17 +851,31 @@ func (s *AccountTestService) testSora2Capabilities(
|
||||
enableTLSFingerprint,
|
||||
)
|
||||
if remainingErr != nil {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
|
||||
return
|
||||
}
|
||||
if remainingStatus != http.StatusOK {
|
||||
if isCloudflareChallengeResponse(remainingStatus, remainingBody) {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage("Sora2 remaining check blocked by Cloudflare challenge (HTTP 403)", remainingHeader, remainingBody)})
|
||||
if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
|
||||
}
|
||||
s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
|
||||
return
|
||||
}
|
||||
upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
|
||||
}
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
|
||||
return
|
||||
}
|
||||
if recorder != nil {
|
||||
recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
|
||||
}
|
||||
if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: summary})
|
||||
} else {
|
||||
@@ -789,42 +998,16 @@ func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
|
||||
return !s.cfg.Sora.Client.DisableTLSFingerprint
|
||||
}
|
||||
|
||||
func isCloudflareChallengeResponse(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden {
|
||||
return false
|
||||
}
|
||||
preview := strings.ToLower(truncateSoraErrorBody(body, 4096))
|
||||
return strings.Contains(preview, "window._cf_chl_opt") ||
|
||||
strings.Contains(preview, "just a moment") ||
|
||||
strings.Contains(preview, "enable javascript and cookies to continue")
|
||||
func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := extractCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
|
||||
preview := truncateSoraErrorBody(body, 8192)
|
||||
matches := cloudflareRayPattern.FindStringSubmatch(preview)
|
||||
if len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
return soraerror.ExtractCloudflareRayID(headers, body)
|
||||
}
|
||||
|
||||
func extractSoraEgressIPHint(headers http.Header) string {
|
||||
@@ -897,14 +1080,7 @@ func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyU
|
||||
}
|
||||
|
||||
func truncateSoraErrorBody(body []byte, max int) string {
|
||||
if max <= 0 {
|
||||
max = 512
|
||||
}
|
||||
raw := strings.TrimSpace(string(body))
|
||||
if len(raw) <= max {
|
||||
return raw
|
||||
}
|
||||
return raw[:max] + "...(truncated)"
|
||||
return soraerror.TruncateBody(body, max)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -109,6 +110,8 @@ func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testin
|
||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
|
||||
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
@@ -141,6 +144,8 @@ func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuc
|
||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||
require.Contains(t, body, "Subscription check returned 403")
|
||||
require.Contains(t, body, "Sora2 invite check returned 401")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"partial_success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
@@ -173,6 +178,97 @@ func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *tes
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "HTTP 429")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "token_invalidated")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.Contains(t, body, "token_invalidated")
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
soraTestCooldown: time.Hour,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c1, _ := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c1, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
c2, rec2 := newSoraTestContext()
|
||||
err = svc.testSoraAccountConnection(c2, account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "测试过于频繁")
|
||||
body := rec2.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"code":"test_rate_limited"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
|
||||
@@ -95,6 +95,16 @@ var soraDesktopUserAgents = []string{
|
||||
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
|
||||
}
|
||||
|
||||
var soraMobileUserAgents = []string{
|
||||
"Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
|
||||
"Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
|
||||
}
|
||||
|
||||
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
var soraRandMu sync.Mutex
|
||||
var soraPerfStart = time.Now()
|
||||
@@ -106,6 +116,17 @@ type SoraClient interface {
|
||||
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
|
||||
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
|
||||
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
|
||||
CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
|
||||
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
|
||||
GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
|
||||
DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
|
||||
UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
|
||||
FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
|
||||
SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
|
||||
DeleteCharacter(ctx context.Context, account *Account, characterID string) error
|
||||
PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
|
||||
DeletePost(ctx context.Context, account *Account, postID string) error
|
||||
GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
|
||||
EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
|
||||
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
|
||||
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
|
||||
@@ -128,6 +149,17 @@ type SoraVideoRequest struct {
|
||||
Size string
|
||||
MediaID string
|
||||
RemixTargetID string
|
||||
CameoIDs []string
|
||||
}
|
||||
|
||||
// SoraStoryboardRequest 分镜视频生成请求参数
|
||||
type SoraStoryboardRequest struct {
|
||||
Prompt string
|
||||
Orientation string
|
||||
Frames int
|
||||
Model string
|
||||
Size string
|
||||
MediaID string
|
||||
}
|
||||
|
||||
// SoraImageTaskStatus 图片任务状态
|
||||
@@ -141,11 +173,32 @@ type SoraImageTaskStatus struct {
|
||||
|
||||
// SoraVideoTaskStatus 视频任务状态
|
||||
type SoraVideoTaskStatus struct {
|
||||
ID string
|
||||
Status string
|
||||
ProgressPct int
|
||||
URLs []string
|
||||
ErrorMsg string
|
||||
ID string
|
||||
Status string
|
||||
ProgressPct int
|
||||
URLs []string
|
||||
GenerationID string
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
// SoraCameoStatus 角色处理中间态
|
||||
type SoraCameoStatus struct {
|
||||
Status string
|
||||
StatusMessage string
|
||||
DisplayNameHint string
|
||||
UsernameHint string
|
||||
ProfileAssetURL string
|
||||
InstructionSetHint any
|
||||
InstructionSet any
|
||||
}
|
||||
|
||||
// SoraCharacterFinalizeRequest 角色定稿请求参数
|
||||
type SoraCharacterFinalizeRequest struct {
|
||||
CameoID string
|
||||
Username string
|
||||
DisplayName string
|
||||
ProfileAssetPointer string
|
||||
InstructionSet any
|
||||
}
|
||||
|
||||
// SoraUpstreamError 上游错误
|
||||
@@ -407,6 +460,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
||||
payload["remix_target_id"] = req.RemixTargetID
|
||||
payload["cameo_ids"] = []string{}
|
||||
payload["cameo_replacements"] = map[string]any{}
|
||||
} else if len(req.CameoIDs) > 0 {
|
||||
payload["cameo_ids"] = req.CameoIDs
|
||||
payload["cameo_replacements"] = map[string]any{}
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
@@ -434,6 +490,425 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
orientation := req.Orientation
|
||||
if orientation == "" {
|
||||
orientation = "landscape"
|
||||
}
|
||||
nFrames := req.Frames
|
||||
if nFrames <= 0 {
|
||||
nFrames = 450
|
||||
}
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "sy_8"
|
||||
}
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "small"
|
||||
}
|
||||
|
||||
inpaintItems := []map[string]any{}
|
||||
if strings.TrimSpace(req.MediaID) != "" {
|
||||
inpaintItems = append(inpaintItems, map[string]any{
|
||||
"kind": "upload",
|
||||
"upload_id": req.MediaID,
|
||||
})
|
||||
}
|
||||
payload := map[string]any{
|
||||
"kind": "video",
|
||||
"prompt": req.Prompt,
|
||||
"title": "Draft your video",
|
||||
"orientation": orientation,
|
||||
"size": size,
|
||||
"n_frames": nFrames,
|
||||
"storyboard_id": nil,
|
||||
"inpaint_items": inpaintItems,
|
||||
"remix_target_id": nil,
|
||||
"model": model,
|
||||
"metadata": nil,
|
||||
"style_id": nil,
|
||||
"cameo_ids": nil,
|
||||
"cameo_replacements": nil,
|
||||
"audio_caption": nil,
|
||||
"audio_transcript": nil,
|
||||
"video_caption": nil,
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers.Set("openai-sentinel-token", sentinel)
|
||||
|
||||
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||
if taskID == "" {
|
||||
return "", errors.New("storyboard task response missing id")
|
||||
}
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty video data")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
partHeader := make(textproto.MIMEHeader)
|
||||
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
|
||||
partHeader.Set("Content-Type", "video/mp4")
|
||||
part, err := writer.CreatePart(partHeader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.WriteField("timestamps", "0,3"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", writer.FormDataContentType())
|
||||
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||
if cameoID == "" {
|
||||
return "", errors.New("character upload response missing id")
|
||||
}
|
||||
return cameoID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
respBody, _, err := c.doRequestWithProxy(
|
||||
ctx,
|
||||
account,
|
||||
proxyURL,
|
||||
http.MethodGet,
|
||||
c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
|
||||
headers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SoraCameoStatus{
|
||||
Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
|
||||
StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
|
||||
DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
|
||||
UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
|
||||
ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
|
||||
InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
|
||||
InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Accept", "image/*,*/*;q=0.8")
|
||||
|
||||
respBody, _, err := c.doRequestWithProxy(
|
||||
ctx,
|
||||
account,
|
||||
proxyURL,
|
||||
http.MethodGet,
|
||||
strings.TrimSpace(imageURL),
|
||||
headers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
if len(data) == 0 {
|
||||
return "", errors.New("empty character image")
|
||||
}
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
partHeader := make(textproto.MIMEHeader)
|
||||
partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
|
||||
partHeader.Set("Content-Type", "image/webp")
|
||||
part, err := writer.CreatePart(partHeader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.WriteField("use_case", "profile"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", writer.FormDataContentType())
|
||||
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
|
||||
if assetPointer == "" {
|
||||
return "", errors.New("character image upload response missing asset_pointer")
|
||||
}
|
||||
return assetPointer, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
payload := map[string]any{
|
||||
"cameo_id": req.CameoID,
|
||||
"username": req.Username,
|
||||
"display_name": req.DisplayName,
|
||||
"profile_asset_pointer": req.ProfileAssetPointer,
|
||||
"instruction_set": nil,
|
||||
"safety_instruction_set": nil,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
|
||||
if characterID == "" {
|
||||
return "", errors.New("character finalize response missing character_id")
|
||||
}
|
||||
return characterID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
payload := map[string]any{"visibility": "public"}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
_, _, err = c.doRequestWithProxy(
|
||||
ctx,
|
||||
account,
|
||||
proxyURL,
|
||||
http.MethodPost,
|
||||
c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
|
||||
headers,
|
||||
bytes.NewReader(body),
|
||||
false,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
_, _, err = c.doRequestWithProxy(
|
||||
ctx,
|
||||
account,
|
||||
proxyURL,
|
||||
http.MethodDelete,
|
||||
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
|
||||
headers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
payload := map[string]any{
|
||||
"attachments_to_create": []map[string]any{
|
||||
{
|
||||
"generation_id": generationID,
|
||||
"kind": "sora",
|
||||
},
|
||||
},
|
||||
"post_text": "",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("Origin", "https://sora.chatgpt.com")
|
||||
headers.Set("Referer", "https://sora.chatgpt.com/")
|
||||
sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers.Set("openai-sentinel-token", sentinel)
|
||||
respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
|
||||
if postID == "" {
|
||||
return "", errors.New("watermark-free publish response missing post.id")
|
||||
}
|
||||
return postID, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAgent := c.taskUserAgent()
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
headers := c.buildBaseHeaders(token, userAgent)
|
||||
_, _, err = c.doRequestWithProxy(
|
||||
ctx,
|
||||
account,
|
||||
proxyURL,
|
||||
http.MethodDelete,
|
||||
c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
|
||||
headers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
|
||||
if parseURL == "" {
|
||||
return "", errors.New("custom parse url is required")
|
||||
}
|
||||
if strings.TrimSpace(parseToken) == "" {
|
||||
return "", errors.New("custom parse token is required")
|
||||
}
|
||||
shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
|
||||
payload := map[string]any{
|
||||
"url": shareURL,
|
||||
"token": strings.TrimSpace(parseToken),
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
proxyURL := c.resolveProxyURL(account)
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountConcurrency = account.Concurrency
|
||||
}
|
||||
var resp *http.Response
|
||||
if c.httpUpstream != nil {
|
||||
resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
|
||||
}
|
||||
downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
|
||||
if downloadLink == "" {
|
||||
return "", errors.New("custom parse response missing download_link")
|
||||
}
|
||||
return downloadLink, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -607,6 +1082,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
||||
if draft.Get("task_id").String() != taskID {
|
||||
return true
|
||||
}
|
||||
generationID := strings.TrimSpace(draft.Get("id").String())
|
||||
kind := strings.TrimSpace(draft.Get("kind").String())
|
||||
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||||
if reason == "" {
|
||||
@@ -623,15 +1099,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
||||
msg = "Content violates guardrails"
|
||||
}
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
ErrorMsg: msg,
|
||||
ID: taskID,
|
||||
Status: "failed",
|
||||
GenerationID: generationID,
|
||||
ErrorMsg: msg,
|
||||
}
|
||||
} else {
|
||||
draftFound = &SoraVideoTaskStatus{
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
URLs: []string{urlStr},
|
||||
ID: taskID,
|
||||
Status: "completed",
|
||||
GenerationID: generationID,
|
||||
URLs: []string{urlStr},
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -675,8 +1153,11 @@ func (c *SoraDirectClient) taskUserAgent() string {
|
||||
return ua
|
||||
}
|
||||
}
|
||||
if len(soraMobileUserAgents) > 0 {
|
||||
return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
|
||||
}
|
||||
if len(soraDesktopUserAgents) > 0 {
|
||||
return soraDesktopUserAgents[0]
|
||||
return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
|
||||
}
|
||||
return soraDefaultUserAgent
|
||||
}
|
||||
@@ -1149,10 +1630,7 @@ func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
|
||||
enableTLS := true
|
||||
if c != nil && c.cfg != nil && c.cfg.Sora.Client.DisableTLSFingerprint {
|
||||
enableTLS = false
|
||||
}
|
||||
enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
|
||||
if c.httpUpstream != nil {
|
||||
accountID := int64(0)
|
||||
accountConcurrency := 0
|
||||
@@ -1288,6 +1766,15 @@ func soraRandFloat() float64 {
|
||||
return soraRand.Float64()
|
||||
}
|
||||
|
||||
func soraRandInt(max int) int {
|
||||
if max <= 1 {
|
||||
return 0
|
||||
}
|
||||
soraRandMu.Lock()
|
||||
defer soraRandMu.Unlock()
|
||||
return soraRand.Intn(max)
|
||||
}
|
||||
|
||||
func soraBuildPowConfig(userAgent string) []any {
|
||||
userAgent = strings.TrimSpace(userAgent)
|
||||
if userAgent == "" && len(soraDesktopUserAgents) > 0 {
|
||||
|
||||
@@ -393,10 +393,22 @@ func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL stri
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
|
||||
case "/backend/nf/create":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
|
||||
case "/backend/nf/create/storyboard":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
|
||||
case "/backend/uploads":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
|
||||
case "/backend/nf/check":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
|
||||
case "/backend/characters/upload":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
|
||||
case "/backend/project_y/cameos/in_progress/cameo-123":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
|
||||
case "/backend/project_y/file/upload":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
|
||||
case "/backend/characters/finalize":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
|
||||
case "/backend/project_y/post":
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
|
||||
default:
|
||||
return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
|
||||
}
|
||||
@@ -410,9 +422,13 @@ func newSoraClientMockResponse(statusCode int, body string) *http.Response {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_TaskUserAgent_DefaultDesktopFallback(t *testing.T) {
|
||||
func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
|
||||
client := NewSoraDirectClient(&config.Config{}, nil, nil)
|
||||
require.Equal(t, soraDesktopUserAgents[0], client.taskUserAgent())
|
||||
ua := client.taskUserAgent()
|
||||
require.NotEmpty(t, ua)
|
||||
allowed := append([]string{}, soraMobileUserAgents...)
|
||||
allowed = append(allowed, soraDesktopUserAgents...)
|
||||
require.Contains(t, allowed, ua)
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
|
||||
@@ -460,7 +476,7 @@ func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAn
|
||||
require.Equal(t, "/backend/nf/create", createCall.Path)
|
||||
require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
|
||||
require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
|
||||
require.Equal(t, soraDesktopUserAgents[0], sentinelCall.UserAgent)
|
||||
require.NotEmpty(t, sentinelCall.UserAgent)
|
||||
require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
|
||||
}
|
||||
|
||||
@@ -495,7 +511,7 @@ func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
|
||||
require.Len(t, upstream.calls, 1)
|
||||
require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
|
||||
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
|
||||
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
|
||||
require.NotEmpty(t, upstream.calls[0].UserAgent)
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
|
||||
@@ -528,5 +544,98 @@ func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T)
|
||||
require.Len(t, upstream.calls, 1)
|
||||
require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
|
||||
require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
|
||||
require.Equal(t, soraDesktopUserAgents[0], upstream.calls[0].UserAgent)
|
||||
require.NotEmpty(t, upstream.calls[0].UserAgent)
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
|
||||
originPowTokenGenerator := soraPowTokenGenerator
|
||||
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
|
||||
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
|
||||
|
||||
upstream := &soraClientRecordingUpstream{}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, upstream, nil)
|
||||
account := &Account{
|
||||
ID: 51,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
|
||||
Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "storyboard-123", taskID)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
|
||||
require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/nf/pending/v2":
|
||||
_, _ = w.Write([]byte(`[]`))
|
||||
case "/project_y/profile/drafts":
|
||||
_, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: server.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
account := &Account{Credentials: map[string]any{"access_token": "token"}}
|
||||
|
||||
status, err := client.GetVideoTask(context.Background(), account, "task-1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", status.Status)
|
||||
require.Equal(t, "gen_1", status.GenerationID)
|
||||
require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
|
||||
originPowTokenGenerator := soraPowTokenGenerator
|
||||
soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
|
||||
defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
|
||||
|
||||
upstream := &soraClientRecordingUpstream{}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.chatgpt.com/backend",
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, upstream, nil)
|
||||
account := &Account{
|
||||
ID: 52,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "s_post", postID)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
|
||||
require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,6 +25,9 @@ import (
|
||||
const soraImageInputMaxBytes = 20 << 20
|
||||
const soraImageInputMaxRedirects = 3
|
||||
const soraImageInputTimeout = 20 * time.Second
|
||||
const soraVideoInputMaxBytes = 200 << 20
|
||||
const soraVideoInputMaxRedirects = 3
|
||||
const soraVideoInputTimeout = 60 * time.Second
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
@@ -61,6 +66,32 @@ type SoraGatewayService struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
type soraWatermarkOptions struct {
|
||||
Enabled bool
|
||||
ParseMethod string
|
||||
ParseURL string
|
||||
ParseToken string
|
||||
FallbackOnFailure bool
|
||||
DeletePost bool
|
||||
}
|
||||
|
||||
type soraCharacterOptions struct {
|
||||
SetPublic bool
|
||||
DeleteAfterGenerate bool
|
||||
}
|
||||
|
||||
type soraCharacterFlowResult struct {
|
||||
CameoID string
|
||||
CharacterID string
|
||||
Username string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
|
||||
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
|
||||
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
|
||||
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
|
||||
|
||||
type soraPreflightChecker interface {
|
||||
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||
}
|
||||
@@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||
}
|
||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
imageInput = strings.TrimSpace(imageInput)
|
||||
videoInput = strings.TrimSpace(videoInput)
|
||||
remixTargetID = strings.TrimSpace(remixTargetID)
|
||||
|
||||
if videoInput != "" && modelCfg.Type != "video" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
|
||||
return nil, errors.New("video input only supports video models")
|
||||
}
|
||||
if videoInput != "" && imageInput != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
|
||||
return nil, errors.New("image input and video input cannot be used together")
|
||||
}
|
||||
characterOnly := videoInput != "" && prompt == ""
|
||||
if modelCfg.Type == "prompt_enhance" && prompt == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
if strings.TrimSpace(videoInput) != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
||||
return nil, errors.New("video input not supported")
|
||||
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
|
||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
|
||||
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
@@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}, nil
|
||||
}
|
||||
|
||||
characterOpts := parseSoraCharacterOptions(reqBody)
|
||||
watermarkOpts := parseSoraWatermarkOptions(reqBody)
|
||||
var characterResult *soraCharacterFlowResult
|
||||
if videoInput != "" {
|
||||
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
|
||||
if videoErr != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
|
||||
return nil, videoErr
|
||||
}
|
||||
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
|
||||
if videoErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
|
||||
}
|
||||
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
|
||||
characterID := strings.TrimSpace(characterResult.CharacterID)
|
||||
defer func() {
|
||||
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancelCleanup()
|
||||
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
|
||||
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if characterOnly {
|
||||
content := "角色创建成功"
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
|
||||
}
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
resp := buildSoraNonStreamResponse(content, reqModel)
|
||||
if characterResult != nil {
|
||||
resp["character_id"] = characterResult.CharacterID
|
||||
resp["cameo_id"] = characterResult.CameoID
|
||||
resp["character_username"] = characterResult.Username
|
||||
resp["character_display_name"] = characterResult.DisplayName
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var imageData []byte
|
||||
imageFilename := ""
|
||||
if strings.TrimSpace(imageInput) != "" {
|
||||
if imageInput != "" {
|
||||
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||
@@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
MediaID: mediaID,
|
||||
})
|
||||
case "video":
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
})
|
||||
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
|
||||
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
|
||||
Prompt: formatSoraStoryboardPrompt(prompt),
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
})
|
||||
} else {
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||
})
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||
}
|
||||
@@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
|
||||
var mediaURLs []string
|
||||
videoGenerationID := ""
|
||||
mediaType := modelCfg.Type
|
||||
imageCount := 0
|
||||
imageSize := ""
|
||||
@@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
imageCount = len(urls)
|
||||
imageSize = soraImageSizeFromModel(reqModel)
|
||||
case "video":
|
||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
|
||||
if pollErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||
}
|
||||
mediaURLs = urls
|
||||
if videoStatus != nil {
|
||||
mediaURLs = videoStatus.URLs
|
||||
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
|
||||
}
|
||||
default:
|
||||
mediaType = "prompt"
|
||||
}
|
||||
|
||||
watermarkPostID := ""
|
||||
if modelCfg.Type == "video" && watermarkOpts.Enabled {
|
||||
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
|
||||
if watermarkErr != nil {
|
||||
if !watermarkOpts.FallbackOnFailure {
|
||||
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
|
||||
}
|
||||
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
|
||||
} else if strings.TrimSpace(watermarkURL) != "" {
|
||||
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
|
||||
watermarkPostID = strings.TrimSpace(postID)
|
||||
}
|
||||
}
|
||||
|
||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
@@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
}
|
||||
}
|
||||
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||
}
|
||||
}
|
||||
|
||||
content := buildSoraContent(mediaType, finalURLs)
|
||||
var firstTokenMs *int
|
||||
@@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||
}
|
||||
|
||||
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
|
||||
opts := soraWatermarkOptions{
|
||||
Enabled: parseBoolWithDefault(body, "watermark_free", false),
|
||||
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
|
||||
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
|
||||
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
|
||||
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
|
||||
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
|
||||
}
|
||||
if opts.ParseMethod == "" {
|
||||
opts.ParseMethod = "third_party"
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||
return soraCharacterOptions{
|
||||
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
|
||||
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
|
||||
}
|
||||
}
|
||||
|
||||
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
|
||||
if body == nil {
|
||||
return def
|
||||
}
|
||||
val, ok := body[key]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
switch typed := val.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case int:
|
||||
return typed != 0
|
||||
case int32:
|
||||
return typed != 0
|
||||
case int64:
|
||||
return typed != 0
|
||||
case float64:
|
||||
return typed != 0
|
||||
case string:
|
||||
typed = strings.ToLower(strings.TrimSpace(typed))
|
||||
if typed == "true" || typed == "1" || typed == "yes" {
|
||||
return true
|
||||
}
|
||||
if typed == "false" || typed == "0" || typed == "no" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func parseStringWithDefault(body map[string]any, key, def string) string {
|
||||
if body == nil {
|
||||
return def
|
||||
}
|
||||
val, ok := body[key]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func extractSoraCameoIDs(body map[string]any) []string {
|
||||
if body == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := body["cameo_ids"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch typed := raw.(type) {
|
||||
case []string:
|
||||
out := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
item = strings.TrimSpace(item)
|
||||
if item != "" {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
str = strings.TrimSpace(str)
|
||||
if str != "" {
|
||||
out = append(out, str)
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
|
||||
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
|
||||
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
|
||||
if displayName == "" {
|
||||
displayName = "Character"
|
||||
}
|
||||
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
|
||||
if profileAssetURL == "" {
|
||||
return nil, errors.New("profile asset url not found in cameo status")
|
||||
}
|
||||
|
||||
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
instructionSet := cameoStatus.InstructionSetHint
|
||||
if instructionSet == nil {
|
||||
instructionSet = cameoStatus.InstructionSet
|
||||
}
|
||||
|
||||
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
ProfileAssetPointer: assetPointer,
|
||||
InstructionSet: instructionSet,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.SetPublic {
|
||||
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &soraCharacterFlowResult{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
CharacterID: strings.TrimSpace(characterID),
|
||||
Username: strings.TrimSpace(username),
|
||||
DisplayName: displayName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
timeout := 10 * time.Minute
|
||||
interval := 5 * time.Second
|
||||
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
|
||||
if maxAttempts < 1 {
|
||||
maxAttempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
consecutiveErrors := 0
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= 3 {
|
||||
break
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
consecutiveErrors = 0
|
||||
if status == nil {
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
|
||||
statusMessage := strings.TrimSpace(status.StatusMessage)
|
||||
if currentStatus == "failed" {
|
||||
if statusMessage == "" {
|
||||
statusMessage = "character creation failed"
|
||||
}
|
||||
return nil, errors.New(statusMessage)
|
||||
}
|
||||
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
|
||||
return status, nil
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
|
||||
}
|
||||
return nil, errors.New("cameo processing timeout")
|
||||
}
|
||||
|
||||
func processSoraCharacterUsername(usernameHint string) string {
|
||||
usernameHint = strings.TrimSpace(usernameHint)
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
if strings.Contains(usernameHint, ".") {
|
||||
parts := strings.Split(usernameHint, ".")
|
||||
usernameHint = strings.TrimSpace(parts[len(parts)-1])
|
||||
}
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||
generationID = strings.TrimSpace(generationID)
|
||||
if generationID == "" {
|
||||
return "", "", errors.New("generation id is required for watermark-free mode")
|
||||
}
|
||||
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
postID = strings.TrimSpace(postID)
|
||||
if postID == "" {
|
||||
return "", "", errors.New("watermark-free publish returned empty post id")
|
||||
}
|
||||
|
||||
switch opts.ParseMethod {
|
||||
case "custom":
|
||||
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
|
||||
if parseErr != nil {
|
||||
return "", postID, parseErr
|
||||
}
|
||||
return strings.TrimSpace(urlVal), postID, nil
|
||||
case "", "third_party":
|
||||
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
|
||||
default:
|
||||
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 402, 403, 404, 429, 529:
|
||||
@@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
||||
return nil, errors.New("sora image generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
|
||||
interval := s.pollInterval()
|
||||
maxAttempts := s.pollMaxAttempts()
|
||||
lastPing := time.Now()
|
||||
@@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
switch strings.ToLower(status.Status) {
|
||||
case "completed", "succeeded":
|
||||
return status.URLs, nil
|
||||
return status, nil
|
||||
case "failed":
|
||||
if status.ErrorMsg != "" {
|
||||
return nil, errors.New(status.ErrorMsg)
|
||||
@@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
return "", "", "", ""
|
||||
}
|
||||
if v, ok := body["remix_target_id"].(string); ok {
|
||||
remixTargetID = v
|
||||
remixTargetID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := body["image"].(string); ok {
|
||||
imageInput = v
|
||||
@@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
prompt = builder.String()
|
||||
}
|
||||
}
|
||||
if remixTargetID == "" {
|
||||
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
|
||||
}
|
||||
prompt = cleanRemixLinkFromPrompt(prompt)
|
||||
return prompt, imageInput, videoInput, remixTargetID
|
||||
}
|
||||
|
||||
@@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
|
||||
}
|
||||
}
|
||||
|
||||
func isSoraStoryboardPrompt(prompt string) bool {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return false
|
||||
}
|
||||
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
|
||||
}
|
||||
|
||||
func formatSoraStoryboardPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
|
||||
if len(matches) == 0 {
|
||||
return prompt
|
||||
}
|
||||
firstBracketPos := strings.Index(prompt, "[")
|
||||
instructions := ""
|
||||
if firstBracketPos > 0 {
|
||||
instructions = strings.TrimSpace(prompt[:firstBracketPos])
|
||||
}
|
||||
shots := make([]string, 0, len(matches))
|
||||
for i, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
duration := strings.TrimSpace(match[1])
|
||||
scene := strings.TrimSpace(match[2])
|
||||
if scene == "" {
|
||||
continue
|
||||
}
|
||||
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
|
||||
}
|
||||
if len(shots) == 0 {
|
||||
return prompt
|
||||
}
|
||||
timeline := strings.Join(shots, "\n\n")
|
||||
if instructions == "" {
|
||||
return timeline
|
||||
}
|
||||
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
|
||||
}
|
||||
|
||||
func extractRemixTargetIDFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
|
||||
}
|
||||
|
||||
func cleanRemixLinkFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return prompt
|
||||
}
|
||||
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
|
||||
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.Join(strings.Fields(cleaned), " ")
|
||||
return strings.TrimSpace(cleaned)
|
||||
}
|
||||
|
||||
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
@@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
}
|
||||
meta := parts[0]
|
||||
payload := parts[1]
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraImageInput(ctx, raw)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("invalid base64 image")
|
||||
}
|
||||
return decoded, "image.png", nil
|
||||
}
|
||||
|
||||
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty video input")
|
||||
}
|
||||
if strings.HasPrefix(raw, "data:") {
|
||||
parts := strings.SplitN(raw, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid video data url")
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraVideoInput(ctx, raw)
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
parsed, err := validateSoraImageURL(rawURL)
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
if len(via) >= soraImageInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraImageURLValue(req.URL)
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
@@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: soraVideoInputTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= soraVideoInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty video content")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
|
||||
if maxBytes <= 0 {
|
||||
return nil, errors.New("invalid max bytes limit")
|
||||
}
|
||||
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||
limited := io.LimitReader(decoder, maxBytes+1)
|
||||
data, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func validateSoraRemoteURL(raw string) (*url.URL, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("empty image url")
|
||||
return nil, errors.New("empty remote url")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||
return nil, fmt.Errorf("invalid remote url: %w", err)
|
||||
}
|
||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||
if err := validateSoraRemoteURLValue(parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||
func validateSoraRemoteURLValue(parsed *url.URL) error {
|
||||
if parsed == nil {
|
||||
return errors.New("invalid image url")
|
||||
return errors.New("invalid remote url")
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return errors.New("only http/https image url is allowed")
|
||||
return errors.New("only http/https remote url is allowed")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return errors.New("image url cannot contain userinfo")
|
||||
return errors.New("remote url cannot contain userinfo")
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return errors.New("image url missing host")
|
||||
return errors.New("remote url missing host")
|
||||
}
|
||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve image url failed: %w", err)
|
||||
return fmt.Errorf("resolve remote url failed: %w", err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -25,6 +26,11 @@ type stubSoraClientForPoll struct {
|
||||
videoCalls int
|
||||
enhanced string
|
||||
enhanceErr error
|
||||
storyboard bool
|
||||
videoReq SoraVideoRequest
|
||||
parseErr error
|
||||
postCalls int
|
||||
deleteCalls int
|
||||
}
|
||||
|
||||
func (s *stubSoraClientForPoll) Enabled() bool { return true }
|
||||
@@ -35,8 +41,54 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
|
||||
s.videoReq = req
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
|
||||
s.storyboard = true
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
return &SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
|
||||
s.postCalls++
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
|
||||
s.deleteCalls++
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
|
||||
if s.parseErr != nil {
|
||||
return "", s.parseErr
|
||||
}
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
if s.enhanced != "" {
|
||||
return s.enhanced, s.enhanceErr
|
||||
@@ -102,6 +154,109 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/v.mp4"},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, client.storyboard)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "prompt", result.MediaType)
|
||||
require.Equal(t, 0, client.videoCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
parseErr: errors.New("parse failed"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 0, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
Status: "completed",
|
||||
URLs: []string{"https://example.com/original.mp4"},
|
||||
GenerationID: "gen_1",
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
|
||||
body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
|
||||
|
||||
result, err := svc.Forward(context.Background(), nil, account, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
|
||||
require.Equal(t, 1, client.postCalls)
|
||||
require.Equal(t, 1, client.deleteCalls)
|
||||
}
|
||||
|
||||
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
client := &stubSoraClientForPoll{
|
||||
videoStatus: &SoraVideoTaskStatus{
|
||||
@@ -119,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
|
||||
}
|
||||
service := NewSoraGatewayService(client, nil, nil, cfg)
|
||||
|
||||
urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, urls)
|
||||
require.Nil(t, status)
|
||||
require.Contains(t, err.Error(), "reject")
|
||||
require.Equal(t, 1, client.videoCalls)
|
||||
}
|
||||
@@ -325,3 +480,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
||||
require.NotEmpty(t, data)
|
||||
require.Contains(t, filename, ".png")
|
||||
}
|
||||
|
||||
func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
|
||||
data, err := decodeBase64WithLimit("aGVsbG8=", 3)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, data)
|
||||
}
|
||||
|
||||
func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"watermark_free": float64(1),
|
||||
"watermark_fallback_on_failure": float64(0),
|
||||
}
|
||||
opts := parseSoraWatermarkOptions(body)
|
||||
require.True(t, opts.Enabled)
|
||||
require.False(t, opts.FallbackOnFailure)
|
||||
}
|
||||
|
||||
170
backend/internal/util/soraerror/soraerror.go
Normal file
170
backend/internal/util/soraerror/soraerror.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
|
||||
cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
|
||||
htmlChallenge = []string{
|
||||
"window._cf_chl_opt",
|
||||
"just a moment",
|
||||
"enable javascript and cookies to continue",
|
||||
"__cf_chl_",
|
||||
"challenge-platform",
|
||||
}
|
||||
)
|
||||
|
||||
// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
|
||||
func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
|
||||
return true
|
||||
}
|
||||
|
||||
preview := strings.ToLower(TruncateBody(body, 4096))
|
||||
for _, marker := range htmlChallenge {
|
||||
if strings.Contains(preview, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
contentType := ""
|
||||
if headers != nil {
|
||||
contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
|
||||
}
|
||||
if strings.Contains(contentType, "text/html") &&
|
||||
(strings.Contains(preview, "<html") || strings.Contains(preview, "<!doctype html")) &&
|
||||
(strings.Contains(preview, "cloudflare") || strings.Contains(preview, "challenge")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractCloudflareRayID extracts cf-ray from headers or response body.
|
||||
func ExtractCloudflareRayID(headers http.Header, body []byte) string {
|
||||
if headers != nil {
|
||||
rayID := strings.TrimSpace(headers.Get("cf-ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
rayID = strings.TrimSpace(headers.Get("Cf-Ray"))
|
||||
if rayID != "" {
|
||||
return rayID
|
||||
}
|
||||
}
|
||||
|
||||
preview := TruncateBody(body, 8192)
|
||||
if matches := cfRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
|
||||
return strings.TrimSpace(matches[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FormatCloudflareChallengeMessage appends cf-ray info when available.
|
||||
func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
rayID := ExtractCloudflareRayID(headers, body)
|
||||
if rayID == "" {
|
||||
return base
|
||||
}
|
||||
return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
|
||||
func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
trimmed := strings.TrimSpace(string(body))
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if !json.Valid([]byte(trimmed)) {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
|
||||
return "", truncateMessage(trimmed, 256)
|
||||
}
|
||||
|
||||
code := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "code"),
|
||||
extractRootString(payload, "code"),
|
||||
)
|
||||
message := firstNonEmpty(
|
||||
extractNestedString(payload, "error", "message"),
|
||||
extractRootString(payload, "message"),
|
||||
extractNestedString(payload, "error", "detail"),
|
||||
extractRootString(payload, "detail"),
|
||||
)
|
||||
return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
|
||||
}
|
||||
|
||||
// TruncateBody truncates body text for logging/inspection.
|
||||
func TruncateBody(body []byte, max int) string {
|
||||
if max <= 0 {
|
||||
max = 512
|
||||
}
|
||||
raw := strings.TrimSpace(string(body))
|
||||
if len(raw) <= max {
|
||||
return raw
|
||||
}
|
||||
return raw[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
func truncateMessage(s string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractRootString(m map[string]any, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func extractNestedString(m map[string]any, parent, key string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
node, ok := m[parent]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
child, ok := node.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, _ := child[key].(string)
|
||||
return s
|
||||
}
|
||||
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
47
backend/internal/util/soraerror/soraerror_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package soraerror
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudflareChallengeResponse(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
|
||||
|
||||
require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>`)))
|
||||
require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`<!DOCTYPE html><title>Just a moment...</title>`)))
|
||||
}
|
||||
|
||||
func TestExtractCloudflareRayID(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
|
||||
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
|
||||
}
|
||||
|
||||
func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
|
||||
code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
|
||||
require.Equal(t, "cf_shield_429", code)
|
||||
require.Equal(t, "rate limited", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
|
||||
require.Equal(t, "unsupported_country_code", code)
|
||||
require.Equal(t, "not available", msg)
|
||||
|
||||
code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
|
||||
require.Equal(t, "", code)
|
||||
require.Equal(t, "plain text", msg)
|
||||
}
|
||||
|
||||
func TestFormatCloudflareChallengeMessage(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
|
||||
require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
|
||||
}
|
||||
Reference in New Issue
Block a user