feat(sora): 对齐sora2api分镜角色去水印与挑战错误治理

This commit is contained in:
yangjianbo
2026-02-19 20:04:10 +08:00
parent 440b87094a
commit 40498aac9d
12 changed files with 1994 additions and 202 deletions

View File

@@ -8,10 +8,12 @@ import (
"fmt"
"io"
"log"
"math"
"mime"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
@@ -23,6 +25,9 @@ import (
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
const soraVideoInputMaxBytes = 200 << 20
const soraVideoInputMaxRedirects = 3
const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -61,6 +66,32 @@ type SoraGatewayService struct {
cfg *config.Config
}
type soraWatermarkOptions struct {
Enabled bool
ParseMethod string
ParseURL string
ParseToken string
FallbackOnFailure bool
DeletePost bool
}
type soraCharacterOptions struct {
SetPublic bool
DeleteAfterGenerate bool
}
type soraCharacterFlowResult struct {
CameoID string
CharacterID string
Username string
DisplayName string
}
var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
type soraPreflightChecker interface {
PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
}
@@ -117,20 +148,34 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
if strings.TrimSpace(prompt) == "" {
prompt = strings.TrimSpace(prompt)
imageInput = strings.TrimSpace(imageInput)
videoInput = strings.TrimSpace(videoInput)
remixTargetID = strings.TrimSpace(remixTargetID)
if videoInput != "" && modelCfg.Type != "video" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
return nil, errors.New("video input only supports video models")
}
if videoInput != "" && imageInput != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
return nil, errors.New("image input and video input cannot be used together")
}
characterOnly := videoInput != "" && prompt == ""
if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
if strings.TrimSpace(videoInput) != "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
return nil, errors.New("video input not supported")
if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
if checker, ok := s.soraClient.(soraPreflightChecker); ok {
if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
}
@@ -166,9 +211,69 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}, nil
}
characterOpts := parseSoraCharacterOptions(reqBody)
watermarkOpts := parseSoraWatermarkOptions(reqBody)
var characterResult *soraCharacterFlowResult
if videoInput != "" {
videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
if videoErr != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
return nil, videoErr
}
characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
if videoErr != nil {
return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
}
if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
characterID := strings.TrimSpace(characterResult.CharacterID)
defer func() {
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
defer cancelCleanup()
if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
}
}()
}
if characterOnly {
content := "角色创建成功"
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
}
var firstTokenMs *int
if clientStream {
ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
if streamErr != nil {
return nil, streamErr
}
firstTokenMs = ms
} else if c != nil {
resp := buildSoraNonStreamResponse(content, reqModel)
if characterResult != nil {
resp["character_id"] = characterResult.CharacterID
resp["cameo_id"] = characterResult.CameoID
resp["character_username"] = characterResult.Username
resp["character_display_name"] = characterResult.DisplayName
}
c.JSON(http.StatusOK, resp)
}
return &ForwardResult{
RequestID: "",
Model: reqModel,
Stream: clientStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{},
MediaType: "prompt",
}, nil
}
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
}
}
var imageData []byte
imageFilename := ""
if strings.TrimSpace(imageInput) != "" {
if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
@@ -198,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
MediaID: mediaID,
})
case "video":
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
})
if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
Prompt: formatSoraStoryboardPrompt(prompt),
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
})
} else {
taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
Prompt: prompt,
Orientation: modelCfg.Orientation,
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
})
}
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
@@ -219,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
var mediaURLs []string
videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
@@ -232,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
mediaURLs = urls
if videoStatus != nil {
mediaURLs = videoStatus.URLs
videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
}
default:
mediaType = "prompt"
}
watermarkPostID := ""
if modelCfg.Type == "video" && watermarkOpts.Enabled {
watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
if watermarkErr != nil {
if !watermarkOpts.FallbackOnFailure {
return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
}
log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
} else if strings.TrimSpace(watermarkURL) != "" {
mediaURLs = []string{strings.TrimSpace(watermarkURL)}
watermarkPostID = strings.TrimSpace(postID)
}
}
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
@@ -251,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
}
}
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
@@ -299,6 +439,267 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
opts := soraWatermarkOptions{
Enabled: parseBoolWithDefault(body, "watermark_free", false),
ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
}
if opts.ParseMethod == "" {
opts.ParseMethod = "third_party"
}
return opts
}
func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
return soraCharacterOptions{
SetPublic: parseBoolWithDefault(body, "character_set_public", true),
DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
}
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case bool:
return typed
case int:
return typed != 0
case int32:
return typed != 0
case int64:
return typed != 0
case float64:
return typed != 0
case string:
typed = strings.ToLower(strings.TrimSpace(typed))
if typed == "true" || typed == "1" || typed == "yes" {
return true
}
if typed == "false" || typed == "0" || typed == "no" {
return false
}
}
return def
}
func parseStringWithDefault(body map[string]any, key, def string) string {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
if str, ok := val.(string); ok {
return str
}
return def
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
}
raw, ok := body["cameo_ids"]
if !ok {
return nil
}
switch typed := raw.(type) {
case []string:
out := make([]string, 0, len(typed))
for _, item := range typed {
item = strings.TrimSpace(item)
if item != "" {
out = append(out, item)
}
}
return out
case []any:
out := make([]string, 0, len(typed))
for _, item := range typed {
str, ok := item.(string)
if !ok {
continue
}
str = strings.TrimSpace(str)
if str != "" {
out = append(out, str)
}
}
return out
default:
return nil
}
}
func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
if err != nil {
return nil, err
}
cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
if err != nil {
return nil, err
}
username := processSoraCharacterUsername(cameoStatus.UsernameHint)
displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
if displayName == "" {
displayName = "Character"
}
profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
if profileAssetURL == "" {
return nil, errors.New("profile asset url not found in cameo status")
}
avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
if err != nil {
return nil, err
}
assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
if err != nil {
return nil, err
}
instructionSet := cameoStatus.InstructionSetHint
if instructionSet == nil {
instructionSet = cameoStatus.InstructionSet
}
characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
CameoID: strings.TrimSpace(cameoID),
Username: username,
DisplayName: displayName,
ProfileAssetPointer: assetPointer,
InstructionSet: instructionSet,
})
if err != nil {
return nil, err
}
if opts.SetPublic {
if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
return nil, err
}
}
return &soraCharacterFlowResult{
CameoID: strings.TrimSpace(cameoID),
CharacterID: strings.TrimSpace(characterID),
Username: strings.TrimSpace(username),
DisplayName: displayName,
}, nil
}
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
timeout := 10 * time.Minute
interval := 5 * time.Second
maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
if maxAttempts < 1 {
maxAttempts = 1
}
var lastErr error
consecutiveErrors := 0
for attempt := 0; attempt < maxAttempts; attempt++ {
status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
if err != nil {
lastErr = err
consecutiveErrors++
if consecutiveErrors >= 3 {
break
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
consecutiveErrors = 0
if status == nil {
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
continue
}
currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
statusMessage := strings.TrimSpace(status.StatusMessage)
if currentStatus == "failed" {
if statusMessage == "" {
statusMessage = "character creation failed"
}
return nil, errors.New(statusMessage)
}
if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
return status, nil
}
if attempt < maxAttempts-1 {
if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
return nil, sleepErr
}
}
}
if lastErr != nil {
return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
}
return nil, errors.New("cameo processing timeout")
}
func processSoraCharacterUsername(usernameHint string) string {
usernameHint = strings.TrimSpace(usernameHint)
if usernameHint == "" {
usernameHint = "character"
}
if strings.Contains(usernameHint, ".") {
parts := strings.Split(usernameHint, ".")
usernameHint = strings.TrimSpace(parts[len(parts)-1])
}
if usernameHint == "" {
usernameHint = "character"
}
return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
}
func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
generationID = strings.TrimSpace(generationID)
if generationID == "" {
return "", "", errors.New("generation id is required for watermark-free mode")
}
postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
if err != nil {
return "", "", err
}
postID = strings.TrimSpace(postID)
if postID == "" {
return "", "", errors.New("watermark-free publish returned empty post id")
}
switch opts.ParseMethod {
case "custom":
urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
if parseErr != nil {
return "", postID, parseErr
}
return strings.TrimSpace(urlVal), postID, nil
case "", "third_party":
return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
default:
return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
}
}
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 402, 403, 404, 429, 529:
@@ -554,7 +955,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
return nil, errors.New("sora image generation timeout")
}
func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
@@ -565,7 +966,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
return status.URLs, nil
return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
@@ -669,7 +1070,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
remixTargetID = v
remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
@@ -710,6 +1111,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
prompt = builder.String()
}
}
if remixTargetID == "" {
remixTargetID = extractRemixTargetIDFromPrompt(prompt)
}
prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
@@ -757,6 +1162,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
}
}
func isSoraStoryboardPrompt(prompt string) bool {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return false
}
return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
}
func formatSoraStoryboardPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
if len(matches) == 0 {
return prompt
}
firstBracketPos := strings.Index(prompt, "[")
instructions := ""
if firstBracketPos > 0 {
instructions = strings.TrimSpace(prompt[:firstBracketPos])
}
shots := make([]string, 0, len(matches))
for i, match := range matches {
if len(match) < 3 {
continue
}
duration := strings.TrimSpace(match[1])
scene := strings.TrimSpace(match[2])
if scene == "" {
continue
}
shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
}
if len(shots) == 0 {
return prompt
}
timeline := strings.Join(shots, "\n\n")
if instructions == "" {
return timeline
}
return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
}
func extractRemixTargetIDFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ""
}
return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
}
func cleanRemixLinkFromPrompt(prompt string) string {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return prompt
}
cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
cleaned = strings.Join(strings.Fields(cleaned), " ")
return strings.TrimSpace(cleaned)
}
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
@@ -769,7 +1237,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
meta := parts[0]
payload := parts[1]
decoded, err := base64.StdEncoding.DecodeString(payload)
decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
@@ -788,15 +1256,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
decoded, err := base64.StdEncoding.DecodeString(raw)
decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
raw := strings.TrimSpace(input)
if raw == "" {
return nil, errors.New("empty video input")
}
if strings.HasPrefix(raw, "data:") {
parts := strings.SplitN(raw, ",", 2)
if len(parts) != 2 {
return nil, errors.New("invalid video data url")
}
decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraVideoInput(ctx, raw)
}
decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
if err != nil {
return nil, errors.New("invalid base64 video")
}
if len(decoded) == 0 {
return nil, errors.New("empty video data")
}
return decoded, nil
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
parsed, err := validateSoraImageURL(rawURL)
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
@@ -810,7 +1310,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraImageURLValue(req.URL)
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
@@ -833,51 +1333,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
return data, filename, nil
}
func validateSoraImageURL(raw string) (*url.URL, error) {
func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, err
}
client := &http.Client{
Timeout: soraVideoInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraVideoInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, errors.New("empty video content")
}
return data, nil
}
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
if maxBytes <= 0 {
return nil, errors.New("invalid max bytes limit")
}
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
limited := io.LimitReader(decoder, maxBytes+1)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
}
return data, nil
}
func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty image url")
return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid image url: %w", err)
return nil, fmt.Errorf("invalid remote url: %w", err)
}
if err := validateSoraImageURLValue(parsed); err != nil {
if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraImageURLValue(parsed *url.URL) error {
func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid image url")
return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https image url is allowed")
return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
return errors.New("image url cannot contain userinfo")
return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("image url missing host")
return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve image url failed: %w", err)
return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
return errors.New("remote url is not allowed")
}
}
return nil