fix(Sora): 加固直连安全与下载限制
补充图片输入 SSRF 防护与重定向限制\n增加媒体下载超时/大小上限配置并更新示例\n完善 recent_tasks 轮询回退策略与相关测试\n\n测试: go test ./... -tags=unit
This commit is contained in:
@@ -218,6 +218,8 @@ type SoraClientConfig struct {
|
|||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||||
|
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||||
|
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||||
Debug bool `mapstructure:"debug"`
|
Debug bool `mapstructure:"debug"`
|
||||||
Headers map[string]string `mapstructure:"headers"`
|
Headers map[string]string `mapstructure:"headers"`
|
||||||
UserAgent string `mapstructure:"user_agent"`
|
UserAgent string `mapstructure:"user_agent"`
|
||||||
@@ -230,6 +232,8 @@ type SoraStorageConfig struct {
|
|||||||
LocalPath string `mapstructure:"local_path"`
|
LocalPath string `mapstructure:"local_path"`
|
||||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||||
|
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
||||||
|
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
||||||
Debug bool `mapstructure:"debug"`
|
Debug bool `mapstructure:"debug"`
|
||||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||||
}
|
}
|
||||||
@@ -929,6 +933,8 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora.client.max_retries", 3)
|
viper.SetDefault("sora.client.max_retries", 3)
|
||||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||||
|
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||||
|
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||||
viper.SetDefault("sora.client.debug", false)
|
viper.SetDefault("sora.client.debug", false)
|
||||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||||
@@ -938,6 +944,8 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora.storage.local_path", "")
|
viper.SetDefault("sora.storage.local_path", "")
|
||||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||||
|
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
||||||
|
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
||||||
viper.SetDefault("sora.storage.debug", false)
|
viper.SetDefault("sora.storage.debug", false)
|
||||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||||
@@ -1205,9 +1213,25 @@ func (c *Config) Validate() error {
|
|||||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.Sora.Client.RecentTaskLimit < 0 {
|
||||||
|
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
||||||
|
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
||||||
|
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||||
|
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||||
|
}
|
||||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||||
}
|
}
|
||||||
|
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
||||||
|
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
||||||
|
}
|
||||||
|
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
||||||
|
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
||||||
|
}
|
||||||
if c.Sora.Storage.Cleanup.Enabled {
|
if c.Sora.Storage.Cleanup.Enabled {
|
||||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||||
|
|||||||
@@ -359,18 +359,43 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||||
token, err := c.getAccessToken(ctx, account)
|
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
if found {
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false)
|
return status, nil
|
||||||
|
}
|
||||||
|
maxLimit := c.recentTaskLimitMax()
|
||||||
|
if maxLimit > 0 && maxLimit != c.recentTaskLimit() {
|
||||||
|
status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) {
|
||||||
|
token, err := c.getAccessToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
|
}
|
||||||
|
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
|
||||||
|
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
}
|
}
|
||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
taskResponses, _ := resp["task_responses"].([]any)
|
taskResponses, _ := resp["task_responses"].([]any)
|
||||||
for _, item := range taskResponses {
|
for _, item := range taskResponses {
|
||||||
@@ -401,10 +426,30 @@ func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, t
|
|||||||
Status: status,
|
Status: status,
|
||||||
ProgressPct: progress,
|
ProgressPct: progress,
|
||||||
URLs: urls,
|
URLs: urls,
|
||||||
}, nil
|
}, true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
|
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) recentTaskLimit() int {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return 20
|
||||||
|
}
|
||||||
|
if c.cfg.Sora.Client.RecentTaskLimit > 0 {
|
||||||
|
return c.cfg.Sora.Client.RecentTaskLimit
|
||||||
|
}
|
||||||
|
return 20
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SoraDirectClient) recentTaskLimitMax() int {
|
||||||
|
if c == nil || c.cfg == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if c.cfg.Sora.Client.RecentTaskLimitMax > 0 {
|
||||||
|
return c.cfg.Sora.Client.RecentTaskLimitMax
|
||||||
|
}
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||||
|
|||||||
@@ -52,3 +52,36 @@ func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
|
|||||||
require.Equal(t, "yes", headers.Get("X-Test"))
|
require.Equal(t, "yes", headers.Get("X-Test"))
|
||||||
require.Empty(t, headers.Get("openai-sentinel-token"))
|
require.Empty(t, headers.Get("openai-sentinel-token"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
limit := r.URL.Query().Get("limit")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch limit {
|
||||||
|
case "1":
|
||||||
|
_, _ = w.Write([]byte(`{"task_responses":[]}`))
|
||||||
|
case "2":
|
||||||
|
_, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`))
|
||||||
|
default:
|
||||||
|
_, _ = w.Write([]byte(`{"task_responses":[]}`))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Client: config.SoraClientConfig{
|
||||||
|
BaseURL: server.URL,
|
||||||
|
RecentTaskLimit: 1,
|
||||||
|
RecentTaskLimitMax: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewSoraDirectClient(cfg, nil, nil)
|
||||||
|
account := &Account{Credentials: map[string]any{"access_token": "token"}}
|
||||||
|
|
||||||
|
status, err := client.GetImageTask(context.Background(), account, "task-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "completed", status.Status)
|
||||||
|
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
"mime"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
|||||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||||
|
|
||||||
const soraRewriteBufferLimit = 2048
|
const soraRewriteBufferLimit = 2048
|
||||||
|
const soraImageInputMaxBytes = 20 << 20
|
||||||
|
const soraImageInputMaxRedirects = 3
|
||||||
|
const soraImageInputTimeout = 20 * time.Second
|
||||||
|
|
||||||
var soraImageSizeMap = map[string]string{
|
var soraImageSizeMap = map[string]string{
|
||||||
"gpt-image": "360",
|
"gpt-image": "360",
|
||||||
@@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{
|
|||||||
"gpt-image-portrait": "540",
|
"gpt-image-portrait": "540",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var soraBlockedHostnames = map[string]struct{}{
|
||||||
|
"localhost": {},
|
||||||
|
"localhost.localdomain": {},
|
||||||
|
"metadata.google.internal": {},
|
||||||
|
"metadata.google.internal.": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
var soraBlockedCIDRs = mustParseCIDRs([]string{
|
||||||
|
"0.0.0.0/8",
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"100.64.0.0/10",
|
||||||
|
"127.0.0.0/8",
|
||||||
|
"169.254.0.0/16",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"224.0.0.0/4",
|
||||||
|
"240.0.0.0/4",
|
||||||
|
"::/128",
|
||||||
|
"::1/128",
|
||||||
|
"fc00::/7",
|
||||||
|
"fe80::/10",
|
||||||
|
})
|
||||||
|
|
||||||
type soraStreamingResult struct {
|
type soraStreamingResult struct {
|
||||||
mediaType string
|
mediaType string
|
||||||
mediaURLs []string
|
mediaURLs []string
|
||||||
@@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
parsed, err := validateSoraImageURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
resp, err := http.DefaultClient.Do(req)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: soraImageInputTimeout,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
if len(via) >= soraImageInputMaxRedirects {
|
||||||
|
return errors.New("too many redirects")
|
||||||
|
}
|
||||||
|
return validateSoraImageURLValue(req.URL)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
@@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
|
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ext := fileExtFromURL(rawURL)
|
ext := fileExtFromURL(parsed.String())
|
||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||||
}
|
}
|
||||||
filename := "image" + ext
|
filename := "image" + ext
|
||||||
return data, filename, nil
|
return data, filename, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||||
|
if strings.TrimSpace(raw) == "" {
|
||||||
|
return nil, errors.New("empty image url")
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||||
|
}
|
||||||
|
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||||
|
if parsed == nil {
|
||||||
|
return errors.New("invalid image url")
|
||||||
|
}
|
||||||
|
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||||
|
if scheme != "http" && scheme != "https" {
|
||||||
|
return errors.New("only http/https image url is allowed")
|
||||||
|
}
|
||||||
|
if parsed.User != nil {
|
||||||
|
return errors.New("image url cannot contain userinfo")
|
||||||
|
}
|
||||||
|
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||||
|
if host == "" {
|
||||||
|
return errors.New("image url missing host")
|
||||||
|
}
|
||||||
|
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||||
|
return errors.New("image url is not allowed")
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
if isSoraBlockedIP(ip) {
|
||||||
|
return errors.New("image url is not allowed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ips, err := net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolve image url failed: %w", err)
|
||||||
|
}
|
||||||
|
for _, ip := range ips {
|
||||||
|
if isSoraBlockedIP(ip) {
|
||||||
|
return errors.New("image url is not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSoraBlockedIP(ip net.IP) bool {
|
||||||
|
if ip == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, cidr := range soraBlockedCIDRs {
|
||||||
|
if cidr.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseCIDRs(values []string) []*net.IPNet {
|
||||||
|
out := make([]*net.IPNet, 0, len(values))
|
||||||
|
for _, val := range values {
|
||||||
|
_, cidr, err := net.ParseCIDR(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, cidr)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,3 +97,16 @@ func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
|
|||||||
require.Contains(t, url, "expires=")
|
require.Contains(t, url, "expires=")
|
||||||
require.Contains(t, url, "sig=")
|
require.Contains(t, url, "sig=")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
|
||||||
|
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
||||||
|
encoded := "data:image/png;base64,aGVsbG8="
|
||||||
|
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, data)
|
||||||
|
require.Contains(t, filename, ".png")
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ type SoraMediaStorage struct {
|
|||||||
imageRoot string
|
imageRoot string
|
||||||
videoRoot string
|
videoRoot string
|
||||||
maxConcurrent int
|
maxConcurrent int
|
||||||
|
downloadTimeout time.Duration
|
||||||
|
maxDownloadBytes int64
|
||||||
fallbackToUpstream bool
|
fallbackToUpstream bool
|
||||||
debug bool
|
debug bool
|
||||||
sem chan struct{}
|
sem chan struct{}
|
||||||
@@ -92,6 +94,17 @@ func (s *SoraMediaStorage) refreshConfig() {
|
|||||||
maxConcurrent = 4
|
maxConcurrent = 4
|
||||||
}
|
}
|
||||||
s.maxConcurrent = maxConcurrent
|
s.maxConcurrent = maxConcurrent
|
||||||
|
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
|
||||||
|
if timeoutSeconds <= 0 {
|
||||||
|
timeoutSeconds = 120
|
||||||
|
}
|
||||||
|
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
|
||||||
|
|
||||||
|
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 200 << 20
|
||||||
|
}
|
||||||
|
s.maxDownloadBytes = maxBytes
|
||||||
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
|
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
|
||||||
s.debug = s.cfg.Sora.Storage.Debug
|
s.debug = s.cfg.Sora.Storage.Debug
|
||||||
s.sem = make(chan struct{}, maxConcurrent)
|
s.sem = make(chan struct{}, maxConcurrent)
|
||||||
@@ -180,7 +193,8 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
resp, err := http.DefaultClient.Do(req)
|
client := &http.Client{Timeout: s.downloadTimeout}
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -198,6 +212,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
|||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = ".bin"
|
ext = ".bin"
|
||||||
}
|
}
|
||||||
|
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
|
||||||
|
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
|
||||||
|
}
|
||||||
|
|
||||||
datePath := time.Now().Format("2006/01/02")
|
datePath := time.Now().Format("2006/01/02")
|
||||||
destDir := filepath.Join(root, filepath.FromSlash(datePath))
|
destDir := filepath.Join(root, filepath.FromSlash(datePath))
|
||||||
@@ -212,10 +229,16 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
|||||||
}
|
}
|
||||||
defer func() { _ = out.Close() }()
|
defer func() { _ = out.Close() }()
|
||||||
|
|
||||||
if _, err := io.Copy(out, resp.Body); err != nil {
|
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
|
||||||
|
written, err := io.Copy(out, limited)
|
||||||
|
if err != nil {
|
||||||
_ = os.Remove(destPath)
|
_ = os.Remove(destPath)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
|
||||||
|
_ = os.Remove(destPath)
|
||||||
|
return "", fmt.Errorf("download size exceeds limit: %d", written)
|
||||||
|
}
|
||||||
|
|
||||||
relative := path.Join("/", mediaType, datePath, filename)
|
relative := path.Join("/", mediaType, datePath, filename)
|
||||||
if s.debug {
|
if s.debug {
|
||||||
|
|||||||
@@ -67,3 +67,27 @@ func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []string{url}, urls)
|
require.Equal(t, []string{url}, urls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "image/png")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("too-large"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Sora: config.SoraConfig{
|
||||||
|
Storage: config.SoraStorageConfig{
|
||||||
|
Type: "local",
|
||||||
|
LocalPath: tmpDir,
|
||||||
|
MaxDownloadBytes: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := NewSoraMediaStorage(cfg)
|
||||||
|
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -269,6 +269,12 @@ sora:
|
|||||||
# Max poll attempts
|
# Max poll attempts
|
||||||
# 最大轮询次数
|
# 最大轮询次数
|
||||||
max_poll_attempts: 600
|
max_poll_attempts: 600
|
||||||
|
# Recent task query limit (image)
|
||||||
|
# 最近任务查询数量(图片轮询)
|
||||||
|
recent_task_limit: 50
|
||||||
|
# Recent task query max limit (fallback)
|
||||||
|
# 最近任务查询最大数量(回退)
|
||||||
|
recent_task_limit_max: 200
|
||||||
# Enable debug logs for Sora upstream requests
|
# Enable debug logs for Sora upstream requests
|
||||||
# 启用 Sora 直连调试日志
|
# 启用 Sora 直连调试日志
|
||||||
debug: false
|
debug: false
|
||||||
@@ -294,6 +300,12 @@ sora:
|
|||||||
# Max concurrent downloads
|
# Max concurrent downloads
|
||||||
# 并发下载上限
|
# 并发下载上限
|
||||||
max_concurrent_downloads: 4
|
max_concurrent_downloads: 4
|
||||||
|
# Download timeout (seconds)
|
||||||
|
# 下载超时(秒)
|
||||||
|
download_timeout_seconds: 120
|
||||||
|
# Max download bytes
|
||||||
|
# 最大下载字节数
|
||||||
|
max_download_bytes: 209715200
|
||||||
# Enable debug logs for media storage
|
# Enable debug logs for media storage
|
||||||
# 启用媒体存储调试日志
|
# 启用媒体存储调试日志
|
||||||
debug: false
|
debug: false
|
||||||
|
|||||||
Reference in New Issue
Block a user