diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 147cc3e9..7d1b10e8 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -218,6 +218,8 @@ type SoraClientConfig struct { MaxRetries int `mapstructure:"max_retries"` PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` Debug bool `mapstructure:"debug"` Headers map[string]string `mapstructure:"headers"` UserAgent string `mapstructure:"user_agent"` @@ -230,6 +232,8 @@ type SoraStorageConfig struct { LocalPath string `mapstructure:"local_path"` FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` + MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` Debug bool `mapstructure:"debug"` Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` } @@ -929,6 +933,8 @@ func setDefaults() { viper.SetDefault("sora.client.max_retries", 3) viper.SetDefault("sora.client.poll_interval_seconds", 2) viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.recent_task_limit", 50) + viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.debug", false) viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") @@ -938,6 +944,8 @@ func setDefaults() { viper.SetDefault("sora.storage.local_path", "") viper.SetDefault("sora.storage.fallback_to_upstream", true) viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.download_timeout_seconds", 120) + viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) viper.SetDefault("sora.storage.debug", false) viper.SetDefault("sora.storage.cleanup.enabled", true) viper.SetDefault("sora.storage.cleanup.retention_days", 7) @@ -1205,9 +1213,25 @@ func (c *Config) Validate() error { if c.Sora.Client.MaxPollAttempts < 0 { return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") } + if c.Sora.Client.RecentTaskLimit < 0 { + return fmt.Errorf("sora.client.recent_task_limit must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax < 0 { + return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && + c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { + c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit + } if c.Sora.Storage.MaxConcurrentDownloads < 0 { return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") } + if c.Sora.Storage.DownloadTimeoutSeconds < 0 { + return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") + } + if c.Sora.Storage.MaxDownloadBytes < 0 { + return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") + } if c.Sora.Storage.Cleanup.Enabled { if c.Sora.Storage.Cleanup.RetentionDays <= 0 { return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 9ecb4688..f3a71a79 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -359,18 +359,43 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account } func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { - token, err := c.getAccessToken(ctx, account) + status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) if err != nil { return nil, err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) - respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false) + if found { + return status, nil + } + maxLimit := c.recentTaskLimitMax() + if maxLimit > 0 && maxLimit != c.recentTaskLimit() { + status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit) + if err != nil { + return nil, err + } + if found { + return status, nil + } + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) { + token, err := c.getAccessToken(ctx, account) if err != nil { - return nil, err + return nil, false, err + } + headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + if limit <= 0 { + limit = 20 + } + endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit) + respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false) + if err != nil { + return nil, false, err } var resp map[string]any if err := json.Unmarshal(respBody, &resp); err != nil { - return nil, err + return nil, false, err } taskResponses, _ := resp["task_responses"].([]any) for _, item := range taskResponses { @@ -401,10 +426,30 @@ func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, t Status: status, ProgressPct: progress, URLs: urls, - }, nil + }, true, nil } } - return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil +} + +func (c *SoraDirectClient) recentTaskLimit() int { + if c == nil || c.cfg == nil { + return 20 + } + if c.cfg.Sora.Client.RecentTaskLimit > 0 { + return c.cfg.Sora.Client.RecentTaskLimit + } + return 20 +} + +func (c *SoraDirectClient) recentTaskLimitMax() int { + if c == nil || c.cfg == nil { + return 0 + } + if c.cfg.Sora.Client.RecentTaskLimitMax > 0 { + return c.cfg.Sora.Client.RecentTaskLimitMax + } + return 0 } func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index abbe47a1..a6bf71cd 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -52,3 +52,36 @@ func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) { require.Equal(t, "yes", headers.Get("X-Test")) require.Empty(t, headers.Get("openai-sentinel-token")) } + +func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limit := r.URL.Query().Get("limit") + w.Header().Set("Content-Type", "application/json") + switch limit { + case "1": + _, _ = w.Write([]byte(`{"task_responses":[]}`)) + case "2": + _, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`)) + default: + _, _ = w.Write([]byte(`{"task_responses":[]}`)) + } + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + RecentTaskLimit: 1, + RecentTaskLimitMax: 2, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{Credentials: map[string]any{"access_token": "token"}} + + status, err := client.GetImageTask(context.Background(), account, "task-1") + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 49cd7bba..ea696d63 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "mime" + "net" "net/http" "net/url" "regexp" @@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`) var soraVideoHTMLRe = regexp.MustCompile(`(?i)]+src=['"]([^'"]+)['"]`) const soraRewriteBufferLimit = 2048 +const soraImageInputMaxBytes = 20 << 20 +const soraImageInputMaxRedirects = 3 +const soraImageInputTimeout = 20 * time.Second var soraImageSizeMap = map[string]string{ "gpt-image": "360", @@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{ "gpt-image-portrait": "540", } +var soraBlockedHostnames = map[string]struct{}{ + "localhost": {}, + "localhost.localdomain": {}, + "metadata.google.internal": {}, + "metadata.google.internal.": {}, +} + +var soraBlockedCIDRs = mustParseCIDRs([]string{ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.168.0.0/16", + "224.0.0.0/4", + "240.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", +}) + type soraStreamingResult struct { mediaType string mediaURLs []string @@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er } func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + parsed, err := validateSoraImageURL(rawURL) if err != nil { return nil, "", err } - resp, err := http.DefaultClient.Do(req) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Timeout: soraImageInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraImageInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraImageURLValue(req.URL) + }, + } + resp, err := client.Do(req) if err != nil { return nil, "", err } @@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, if resp.StatusCode != http.StatusOK { return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) } - data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20)) + data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) if err != nil { return nil, "", err } - ext := fileExtFromURL(rawURL) + ext := fileExtFromURL(parsed.String()) if ext == "" { ext = fileExtFromContentType(resp.Header.Get("Content-Type")) } filename := "image" + ext return data, filename, nil } + +func validateSoraImageURL(raw string) (*url.URL, error) { + if strings.TrimSpace(raw) == "" { + return nil, errors.New("empty image url") + } + parsed, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("invalid image url: %w", err) + } + if err := validateSoraImageURLValue(parsed); err != nil { + return nil, err + } + return parsed, nil +} + +func validateSoraImageURLValue(parsed *url.URL) error { + if parsed == nil { + return errors.New("invalid image url") + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme != "http" && scheme != "https" { + return errors.New("only http/https image url is allowed") + } + if parsed.User != nil { + return errors.New("image url cannot contain userinfo") + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return errors.New("image url missing host") + } + if _, blocked := soraBlockedHostnames[host]; blocked { + return errors.New("image url is not allowed") + } + if ip := net.ParseIP(host); ip != nil { + if isSoraBlockedIP(ip) { + return errors.New("image url is not allowed") + } + return nil + } + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("resolve image url failed: %w", err) + } + for _, ip := range ips { + if isSoraBlockedIP(ip) { + return errors.New("image url is not allowed") + } + } + return nil +} + +func isSoraBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + for _, cidr := range soraBlockedCIDRs { + if cidr.Contains(ip) { + return true + } + } + return false +} + +func mustParseCIDRs(values []string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(values)) + for _, val := range values { + _, cidr, err := net.ParseCIDR(val) + if err != nil { + continue + } + out = append(out, cidr) + } + return out +} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index e4de8256..caa10427 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -97,3 +97,16 @@ func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { require.Contains(t, url, "expires=") require.Contains(t, url, "sig=") } + +func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) { + _, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png") + require.Error(t, err) +} + +func TestDecodeSoraImageInput_DataURL(t *testing.T) { + encoded := "data:image/png;base64,aGVsbG8=" + data, filename, err := decodeSoraImageInput(context.Background(), encoded) + require.NoError(t, err) + require.NotEmpty(t, data) + require.Contains(t, filename, ".png") +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index 53214bb7..4359f78d 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -30,6 +30,8 @@ type SoraMediaStorage struct { imageRoot string videoRoot string maxConcurrent int + downloadTimeout time.Duration + maxDownloadBytes int64 fallbackToUpstream bool debug bool sem chan struct{} @@ -92,6 +94,17 @@ func (s *SoraMediaStorage) refreshConfig() { maxConcurrent = 4 } s.maxConcurrent = maxConcurrent + timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds + if timeoutSeconds <= 0 { + timeoutSeconds = 120 + } + s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second + + maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes + if maxBytes <= 0 { + maxBytes = 200 << 20 + } + s.maxDownloadBytes = maxBytes s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream s.debug = s.cfg.Sora.Storage.Debug s.sem = make(chan struct{}, maxConcurrent) @@ -180,7 +193,8 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra if err != nil { return "", err } - resp, err := http.DefaultClient.Do(req) + client := &http.Client{Timeout: s.downloadTimeout} + resp, err := client.Do(req) if err != nil { return "", err } @@ -198,6 +212,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra if ext == "" { ext = ".bin" } + if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes { + return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength) + } datePath := time.Now().Format("2006/01/02") destDir := filepath.Join(root, filepath.FromSlash(datePath)) @@ -212,10 +229,16 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra } defer func() { _ = out.Close() }() - if _, err := io.Copy(out, resp.Body); err != nil { + limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1) + written, err := io.Copy(out, limited) + if err != nil { _ = os.Remove(destPath) return "", err } + if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes { + _ = os.Remove(destPath) + return "", fmt.Errorf("download size exceeds limit: %d", written) + } relative := path.Join("/", mediaType, datePath, filename) if s.debug { diff --git a/backend/internal/service/sora_media_storage_test.go b/backend/internal/service/sora_media_storage_test.go index f86234d2..0050afed 100644 --- a/backend/internal/service/sora_media_storage_test.go +++ b/backend/internal/service/sora_media_storage_test.go @@ -67,3 +67,27 @@ func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) { require.NoError(t, err) require.Equal(t, []string{url}, urls) } + +func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) { + tmpDir := t.TempDir() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("too-large")) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + MaxDownloadBytes: 1, + }, + }, + } + + storage := NewSoraMediaStorage(cfg) + _, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"}) + require.Error(t, err) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 2c7a1778..9ef3ee5c 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -269,6 +269,12 @@ sora: # Max poll attempts # 最大轮询次数 max_poll_attempts: 600 + # Recent task query limit (image) + # 最近任务查询数量(图片轮询) + recent_task_limit: 50 + # Recent task query max limit (fallback) + # 最近任务查询最大数量(回退) + recent_task_limit_max: 200 # Enable debug logs for Sora upstream requests # 启用 Sora 直连调试日志 debug: false @@ -294,6 +300,12 @@ sora: # Max concurrent downloads # 并发下载上限 max_concurrent_downloads: 4 + # Download timeout (seconds) + # 下载超时(秒) + download_timeout_seconds: 120 + # Max download bytes + # 最大下载字节数 + max_download_bytes: 209715200 # Enable debug logs for media storage # 启用媒体存储调试日志 debug: false