feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理
This commit is contained in:
@@ -8,10 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,6 +25,9 @@ import (
|
||||
// Limits applied to user-supplied media inputs (inline base64 or remote URL).
const (
	// Image inputs: at most 20 MiB, 3 redirects, 20 seconds per download.
	soraImageInputMaxBytes     = 20 << 20
	soraImageInputMaxRedirects = 3
	soraImageInputTimeout      = 20 * time.Second

	// Video inputs: at most 200 MiB, 3 redirects, 60 seconds per download.
	soraVideoInputMaxBytes     = 200 << 20
	soraVideoInputMaxRedirects = 3
	soraVideoInputTimeout      = 60 * time.Second
)
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
@@ -61,6 +66,32 @@ type SoraGatewayService struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// soraWatermarkOptions carries the per-request settings that control how a
// watermark-free video URL is resolved after generation.
type soraWatermarkOptions struct {
	// Enabled turns watermark-free resolution on (request key "watermark_free").
	Enabled bool
	// ParseMethod selects the resolver: "third_party" (default) or "custom".
	ParseMethod string
	// ParseURL and ParseToken are forwarded to the custom resolver.
	ParseURL   string
	ParseToken string
	// FallbackOnFailure keeps the original media URLs when resolution fails
	// instead of failing the request.
	FallbackOnFailure bool
	// DeletePost requests deletion of the post published for resolution.
	DeletePost bool
}
|
||||
|
||||
// soraCharacterOptions carries the per-request settings for the
// video-to-character flow.
type soraCharacterOptions struct {
	// SetPublic makes the character public once it has been finalized.
	SetPublic bool
	// DeleteAfterGenerate requests cleanup of the character after the
	// generation that used it has finished.
	DeleteAfterGenerate bool
}
|
||||
|
||||
// soraCharacterFlowResult describes a character created from an uploaded video.
type soraCharacterFlowResult struct {
	// CameoID is the identifier returned by the character video upload.
	CameoID string
	// CharacterID is the identifier returned when the character is finalized.
	CharacterID string
	// Username is the generated handle; it is used as an "@username" mention
	// in prompts.
	Username string
	// DisplayName is the human-readable character name.
	DisplayName string
}
|
||||
|
||||
// Patterns used to recognise storyboard timing markers and remix targets
// embedded in a prompt.
var (
	// soraStoryboardPattern matches a timing marker such as "[5s]" or "[2.5s]".
	soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
	// soraStoryboardShotPattern captures a marker's duration (group 1) and
	// the scene text that follows it up to the next "[" (group 2).
	soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
	// soraRemixTargetPattern matches a bare remix target id: "s_" followed by
	// 32 lowercase hex characters.
	soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
	// soraRemixTargetInURLPattern matches a full Sora share link that
	// contains a remix target id.
	soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
)
|
||||
|
||||
type soraPreflightChecker interface {
|
||||
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
|
||||
}
|
||||
@@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
return nil, fmt.Errorf("unsupported model: %s", reqModel)
|
||||
}
|
||||
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
imageInput = strings.TrimSpace(imageInput)
|
||||
videoInput = strings.TrimSpace(videoInput)
|
||||
remixTargetID = strings.TrimSpace(remixTargetID)
|
||||
|
||||
if videoInput != "" && modelCfg.Type != "video" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
|
||||
return nil, errors.New("video input only supports video models")
|
||||
}
|
||||
if videoInput != "" && imageInput != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
|
||||
return nil, errors.New("image input and video input cannot be used together")
|
||||
}
|
||||
characterOnly := videoInput != "" && prompt == ""
|
||||
if modelCfg.Type == "prompt_enhance" && prompt == "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
if strings.TrimSpace(videoInput) != "" {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
|
||||
return nil, errors.New("video input not supported")
|
||||
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
|
||||
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
|
||||
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
|
||||
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
|
||||
}
|
||||
@@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}, nil
|
||||
}
|
||||
|
||||
characterOpts := parseSoraCharacterOptions(reqBody)
|
||||
watermarkOpts := parseSoraWatermarkOptions(reqBody)
|
||||
var characterResult *soraCharacterFlowResult
|
||||
if videoInput != "" {
|
||||
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
|
||||
if videoErr != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
|
||||
return nil, videoErr
|
||||
}
|
||||
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
|
||||
if videoErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
|
||||
}
|
||||
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
|
||||
characterID := strings.TrimSpace(characterResult.CharacterID)
|
||||
defer func() {
|
||||
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancelCleanup()
|
||||
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
|
||||
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if characterOnly {
|
||||
content := "角色创建成功"
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
|
||||
}
|
||||
var firstTokenMs *int
|
||||
if clientStream {
|
||||
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
firstTokenMs = ms
|
||||
} else if c != nil {
|
||||
resp := buildSoraNonStreamResponse(content, reqModel)
|
||||
if characterResult != nil {
|
||||
resp["character_id"] = characterResult.CharacterID
|
||||
resp["cameo_id"] = characterResult.CameoID
|
||||
resp["character_username"] = characterResult.Username
|
||||
resp["character_display_name"] = characterResult.DisplayName
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Model: reqModel,
|
||||
Stream: clientStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{},
|
||||
MediaType: "prompt",
|
||||
}, nil
|
||||
}
|
||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var imageData []byte
|
||||
imageFilename := ""
|
||||
if strings.TrimSpace(imageInput) != "" {
|
||||
if imageInput != "" {
|
||||
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
|
||||
@@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
MediaID: mediaID,
|
||||
})
|
||||
case "video":
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
})
|
||||
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
|
||||
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
|
||||
Prompt: formatSoraStoryboardPrompt(prompt),
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
})
|
||||
} else {
|
||||
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
|
||||
Prompt: prompt,
|
||||
Orientation: modelCfg.Orientation,
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||
})
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
|
||||
}
|
||||
@@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
|
||||
var mediaURLs []string
|
||||
videoGenerationID := ""
|
||||
mediaType := modelCfg.Type
|
||||
imageCount := 0
|
||||
imageSize := ""
|
||||
@@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
imageCount = len(urls)
|
||||
imageSize = soraImageSizeFromModel(reqModel)
|
||||
case "video":
|
||||
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
|
||||
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
|
||||
if pollErr != nil {
|
||||
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
|
||||
}
|
||||
mediaURLs = urls
|
||||
if videoStatus != nil {
|
||||
mediaURLs = videoStatus.URLs
|
||||
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
|
||||
}
|
||||
default:
|
||||
mediaType = "prompt"
|
||||
}
|
||||
|
||||
watermarkPostID := ""
|
||||
if modelCfg.Type == "video" && watermarkOpts.Enabled {
|
||||
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
|
||||
if watermarkErr != nil {
|
||||
if !watermarkOpts.FallbackOnFailure {
|
||||
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
|
||||
}
|
||||
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
|
||||
} else if strings.TrimSpace(watermarkURL) != "" {
|
||||
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
|
||||
watermarkPostID = strings.TrimSpace(postID)
|
||||
}
|
||||
}
|
||||
|
||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
@@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
}
|
||||
}
|
||||
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||
}
|
||||
}
|
||||
|
||||
content := buildSoraContent(mediaType, finalURLs)
|
||||
var firstTokenMs *int
|
||||
@@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
|
||||
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
|
||||
}
|
||||
|
||||
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
|
||||
opts := soraWatermarkOptions{
|
||||
Enabled: parseBoolWithDefault(body, "watermark_free", false),
|
||||
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
|
||||
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
|
||||
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
|
||||
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
|
||||
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
|
||||
}
|
||||
if opts.ParseMethod == "" {
|
||||
opts.ParseMethod = "third_party"
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||
return soraCharacterOptions{
|
||||
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
|
||||
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
|
||||
}
|
||||
}
|
||||
|
||||
// parseBoolWithDefault interprets body[key] as a boolean.
//
// Accepted forms are native bools, numeric values (non-zero means true) and
// the case-insensitive strings "true"/"1"/"yes" and "false"/"0"/"no". def is
// returned when the key is missing, the body is nil, or the value cannot be
// interpreted.
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
	// Indexing a nil map is safe and reports the key as absent.
	raw, present := body[key]
	if !present {
		return def
	}

	switch v := raw.(type) {
	case bool:
		return v
	case int:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case float64:
		return v != 0
	case string:
		switch strings.ToLower(strings.TrimSpace(v)) {
		case "true", "1", "yes":
			return true
		case "false", "0", "no":
			return false
		}
	}
	return def
}
|
||||
|
||||
// parseStringWithDefault returns body[key] when it holds a string, and def
// when the body is nil, the key is missing, or the value has another type.
// The string is returned as-is, without trimming.
func parseStringWithDefault(body map[string]any, key, def string) string {
	// A nil map or a missing key yields a nil interface, which fails the
	// assertion exactly like a non-string value does.
	if s, isString := body[key].(string); isString {
		return s
	}
	return def
}
|
||||
|
||||
// extractSoraCameoIDs returns the trimmed, non-empty entries of
// body["cameo_ids"].
//
// The value may be a []string or a []any; non-string elements of a []any are
// ignored. It returns nil when the key is absent or has any other type, and a
// non-nil (possibly empty) slice when a list was supplied.
func extractSoraCameoIDs(body map[string]any) []string {
	// Indexing a nil map is safe and reports the key as absent.
	raw, present := body["cameo_ids"]
	if !present {
		return nil
	}

	var ids []string
	keep := func(candidate string) {
		if candidate = strings.TrimSpace(candidate); candidate != "" {
			ids = append(ids, candidate)
		}
	}

	switch list := raw.(type) {
	case []string:
		ids = make([]string, 0, len(list))
		for _, entry := range list {
			keep(entry)
		}
	case []any:
		ids = make([]string, 0, len(list))
		for _, entry := range list {
			if text, isString := entry.(string); isString {
				keep(text)
			}
		}
	}
	return ids
}
|
||||
|
||||
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
|
||||
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
|
||||
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
|
||||
if displayName == "" {
|
||||
displayName = "Character"
|
||||
}
|
||||
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
|
||||
if profileAssetURL == "" {
|
||||
return nil, errors.New("profile asset url not found in cameo status")
|
||||
}
|
||||
|
||||
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
instructionSet := cameoStatus.InstructionSetHint
|
||||
if instructionSet == nil {
|
||||
instructionSet = cameoStatus.InstructionSet
|
||||
}
|
||||
|
||||
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
ProfileAssetPointer: assetPointer,
|
||||
InstructionSet: instructionSet,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.SetPublic {
|
||||
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &soraCharacterFlowResult{
|
||||
CameoID: strings.TrimSpace(cameoID),
|
||||
CharacterID: strings.TrimSpace(characterID),
|
||||
Username: strings.TrimSpace(username),
|
||||
DisplayName: displayName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
|
||||
timeout := 10 * time.Minute
|
||||
interval := 5 * time.Second
|
||||
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
|
||||
if maxAttempts < 1 {
|
||||
maxAttempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
consecutiveErrors := 0
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
consecutiveErrors++
|
||||
if consecutiveErrors >= 3 {
|
||||
break
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
consecutiveErrors = 0
|
||||
if status == nil {
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
|
||||
statusMessage := strings.TrimSpace(status.StatusMessage)
|
||||
if currentStatus == "failed" {
|
||||
if statusMessage == "" {
|
||||
statusMessage = "character creation failed"
|
||||
}
|
||||
return nil, errors.New(statusMessage)
|
||||
}
|
||||
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
|
||||
return status, nil
|
||||
}
|
||||
if attempt < maxAttempts-1 {
|
||||
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
|
||||
return nil, sleepErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
|
||||
}
|
||||
return nil, errors.New("cameo processing timeout")
|
||||
}
|
||||
|
||||
func processSoraCharacterUsername(usernameHint string) string {
|
||||
usernameHint = strings.TrimSpace(usernameHint)
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
if strings.Contains(usernameHint, ".") {
|
||||
parts := strings.Split(usernameHint, ".")
|
||||
usernameHint = strings.TrimSpace(parts[len(parts)-1])
|
||||
}
|
||||
if usernameHint == "" {
|
||||
usernameHint = "character"
|
||||
}
|
||||
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
|
||||
generationID = strings.TrimSpace(generationID)
|
||||
if generationID == "" {
|
||||
return "", "", errors.New("generation id is required for watermark-free mode")
|
||||
}
|
||||
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
postID = strings.TrimSpace(postID)
|
||||
if postID == "" {
|
||||
return "", "", errors.New("watermark-free publish returned empty post id")
|
||||
}
|
||||
|
||||
switch opts.ParseMethod {
|
||||
case "custom":
|
||||
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
|
||||
if parseErr != nil {
|
||||
return "", postID, parseErr
|
||||
}
|
||||
return strings.TrimSpace(urlVal), postID, nil
|
||||
case "", "third_party":
|
||||
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
|
||||
default:
|
||||
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 402, 403, 404, 429, 529:
|
||||
@@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
|
||||
return nil, errors.New("sora image generation timeout")
|
||||
}
|
||||
|
||||
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
|
||||
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
|
||||
interval := s.pollInterval()
|
||||
maxAttempts := s.pollMaxAttempts()
|
||||
lastPing := time.Now()
|
||||
@@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
switch strings.ToLower(status.Status) {
|
||||
case "completed", "succeeded":
|
||||
return status.URLs, nil
|
||||
return status, nil
|
||||
case "failed":
|
||||
if status.ErrorMsg != "" {
|
||||
return nil, errors.New(status.ErrorMsg)
|
||||
@@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
return "", "", "", ""
|
||||
}
|
||||
if v, ok := body["remix_target_id"].(string); ok {
|
||||
remixTargetID = v
|
||||
remixTargetID = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := body["image"].(string); ok {
|
||||
imageInput = v
|
||||
@@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
|
||||
prompt = builder.String()
|
||||
}
|
||||
}
|
||||
if remixTargetID == "" {
|
||||
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
|
||||
}
|
||||
prompt = cleanRemixLinkFromPrompt(prompt)
|
||||
return prompt, imageInput, videoInput, remixTargetID
|
||||
}
|
||||
|
||||
@@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
|
||||
}
|
||||
}
|
||||
|
||||
func isSoraStoryboardPrompt(prompt string) bool {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return false
|
||||
}
|
||||
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
|
||||
}
|
||||
|
||||
func formatSoraStoryboardPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
|
||||
if len(matches) == 0 {
|
||||
return prompt
|
||||
}
|
||||
firstBracketPos := strings.Index(prompt, "[")
|
||||
instructions := ""
|
||||
if firstBracketPos > 0 {
|
||||
instructions = strings.TrimSpace(prompt[:firstBracketPos])
|
||||
}
|
||||
shots := make([]string, 0, len(matches))
|
||||
for i, match := range matches {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
duration := strings.TrimSpace(match[1])
|
||||
scene := strings.TrimSpace(match[2])
|
||||
if scene == "" {
|
||||
continue
|
||||
}
|
||||
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
|
||||
}
|
||||
if len(shots) == 0 {
|
||||
return prompt
|
||||
}
|
||||
timeline := strings.Join(shots, "\n\n")
|
||||
if instructions == "" {
|
||||
return timeline
|
||||
}
|
||||
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
|
||||
}
|
||||
|
||||
func extractRemixTargetIDFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
|
||||
}
|
||||
|
||||
func cleanRemixLinkFromPrompt(prompt string) string {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
return prompt
|
||||
}
|
||||
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
|
||||
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.Join(strings.Fields(cleaned), " ")
|
||||
return strings.TrimSpace(cleaned)
|
||||
}
|
||||
|
||||
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
@@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
}
|
||||
meta := parts[0]
|
||||
payload := parts[1]
|
||||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||||
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraImageInput(ctx, raw)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("invalid base64 image")
|
||||
}
|
||||
return decoded, "image.png", nil
|
||||
}
|
||||
|
||||
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
return nil, errors.New("empty video input")
|
||||
}
|
||||
if strings.HasPrefix(raw, "data:") {
|
||||
parts := strings.SplitN(raw, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid video data url")
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
|
||||
return downloadSoraVideoInput(ctx, raw)
|
||||
}
|
||||
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid base64 video")
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
return nil, errors.New("empty video data")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
parsed, err := validateSoraImageURL(rawURL)
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
if len(via) >= soraImageInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraImageURLValue(req.URL)
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
@@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
|
||||
parsed, err := validateSoraRemoteURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: soraVideoInputTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= soraVideoInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraRemoteURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty video content")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// decodeBase64WithLimit decodes standard (padded) base64 and fails when the
// decoded payload is larger than maxBytes.
//
// Decoding is streamed and capped at maxBytes+1 bytes, so an oversized input
// is rejected without decoding all of it. maxBytes must be positive.
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
	if maxBytes < 1 {
		return nil, errors.New("invalid max bytes limit")
	}

	// Allow one byte beyond the limit: seeing it proves the input is too big.
	stream := io.LimitReader(
		base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)),
		maxBytes+1,
	)
	decoded, readErr := io.ReadAll(stream)
	switch {
	case readErr != nil:
		return nil, readErr
	case int64(len(decoded)) > maxBytes:
		return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
	}
	return decoded, nil
}
|
||||
|
||||
func validateSoraRemoteURL(raw string) (*url.URL, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("empty image url")
|
||||
return nil, errors.New("empty remote url")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||
return nil, fmt.Errorf("invalid remote url: %w", err)
|
||||
}
|
||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||
if err := validateSoraRemoteURLValue(parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||
func validateSoraRemoteURLValue(parsed *url.URL) error {
|
||||
if parsed == nil {
|
||||
return errors.New("invalid image url")
|
||||
return errors.New("invalid remote url")
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return errors.New("only http/https image url is allowed")
|
||||
return errors.New("only http/https remote url is allowed")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return errors.New("image url cannot contain userinfo")
|
||||
return errors.New("remote url cannot contain userinfo")
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return errors.New("image url missing host")
|
||||
return errors.New("remote url missing host")
|
||||
}
|
||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve image url failed: %w", err)
|
||||
return fmt.Errorf("resolve remote url failed: %w", err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
return errors.New("remote url is not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user