fix(Sora): 加固直连安全与下载限制
补充图片输入 SSRF 防护与重定向限制
增加媒体下载超时/大小上限配置并更新示例
完善 recent_tasks 轮询回退策略与相关测试

测试: go test ./... -tags=unit
This commit is contained in:
@@ -218,6 +218,8 @@ type SoraClientConfig struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
|
||||
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
|
||||
RecentTaskLimit int `mapstructure:"recent_task_limit"`
|
||||
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Headers map[string]string `mapstructure:"headers"`
|
||||
UserAgent string `mapstructure:"user_agent"`
|
||||
@@ -230,6 +232,8 @@ type SoraStorageConfig struct {
|
||||
LocalPath string `mapstructure:"local_path"`
|
||||
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
|
||||
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
|
||||
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
|
||||
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
|
||||
}
|
||||
@@ -929,6 +933,8 @@ func setDefaults() {
|
||||
viper.SetDefault("sora.client.max_retries", 3)
|
||||
viper.SetDefault("sora.client.poll_interval_seconds", 2)
|
||||
viper.SetDefault("sora.client.max_poll_attempts", 600)
|
||||
viper.SetDefault("sora.client.recent_task_limit", 50)
|
||||
viper.SetDefault("sora.client.recent_task_limit_max", 200)
|
||||
viper.SetDefault("sora.client.debug", false)
|
||||
viper.SetDefault("sora.client.headers", map[string]string{})
|
||||
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
|
||||
@@ -938,6 +944,8 @@ func setDefaults() {
|
||||
viper.SetDefault("sora.storage.local_path", "")
|
||||
viper.SetDefault("sora.storage.fallback_to_upstream", true)
|
||||
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
|
||||
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
|
||||
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
|
||||
viper.SetDefault("sora.storage.debug", false)
|
||||
viper.SetDefault("sora.storage.cleanup.enabled", true)
|
||||
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
|
||||
@@ -1205,9 +1213,25 @@ func (c *Config) Validate() error {
|
||||
if c.Sora.Client.MaxPollAttempts < 0 {
|
||||
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimit < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax < 0 {
|
||||
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
|
||||
}
|
||||
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
|
||||
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
|
||||
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
|
||||
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
|
||||
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.MaxDownloadBytes < 0 {
|
||||
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
|
||||
}
|
||||
if c.Sora.Storage.Cleanup.Enabled {
|
||||
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
|
||||
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
|
||||
|
||||
@@ -359,18 +359,43 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false)
|
||||
if found {
|
||||
return status, nil
|
||||
}
|
||||
maxLimit := c.recentTaskLimitMax()
|
||||
if maxLimit > 0 && maxLimit != c.recentTaskLimit() {
|
||||
status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found {
|
||||
return status, nil
|
||||
}
|
||||
}
|
||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) {
|
||||
token, err := c.getAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
|
||||
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
taskResponses, _ := resp["task_responses"].([]any)
|
||||
for _, item := range taskResponses {
|
||||
@@ -401,10 +426,30 @@ func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, t
|
||||
Status: status,
|
||||
ProgressPct: progress,
|
||||
URLs: urls,
|
||||
}, nil
|
||||
}, true, nil
|
||||
}
|
||||
}
|
||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recentTaskLimit() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 20
|
||||
}
|
||||
if c.cfg.Sora.Client.RecentTaskLimit > 0 {
|
||||
return c.cfg.Sora.Client.RecentTaskLimit
|
||||
}
|
||||
return 20
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) recentTaskLimitMax() int {
|
||||
if c == nil || c.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
if c.cfg.Sora.Client.RecentTaskLimitMax > 0 {
|
||||
return c.cfg.Sora.Client.RecentTaskLimitMax
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {
|
||||
|
||||
@@ -52,3 +52,36 @@ func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
|
||||
require.Equal(t, "yes", headers.Get("X-Test"))
|
||||
require.Empty(t, headers.Get("openai-sentinel-token"))
|
||||
}
|
||||
|
||||
func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
limit := r.URL.Query().Get("limit")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch limit {
|
||||
case "1":
|
||||
_, _ = w.Write([]byte(`{"task_responses":[]}`))
|
||||
case "2":
|
||||
_, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`))
|
||||
default:
|
||||
_, _ = w.Write([]byte(`{"task_responses":[]}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: server.URL,
|
||||
RecentTaskLimit: 1,
|
||||
RecentTaskLimitMax: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := NewSoraDirectClient(cfg, nil, nil)
|
||||
account := &Account{Credentials: map[string]any{"access_token": "token"}}
|
||||
|
||||
status, err := client.GetImageTask(context.Background(), account, "task-1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", status.Status)
|
||||
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||
|
||||
const soraRewriteBufferLimit = 2048
|
||||
const soraImageInputMaxBytes = 20 << 20
|
||||
const soraImageInputMaxRedirects = 3
|
||||
const soraImageInputTimeout = 20 * time.Second
|
||||
|
||||
var soraImageSizeMap = map[string]string{
|
||||
"gpt-image": "360",
|
||||
@@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{
|
||||
"gpt-image-portrait": "540",
|
||||
}
|
||||
|
||||
var soraBlockedHostnames = map[string]struct{}{
|
||||
"localhost": {},
|
||||
"localhost.localdomain": {},
|
||||
"metadata.google.internal": {},
|
||||
"metadata.google.internal.": {},
|
||||
}
|
||||
|
||||
var soraBlockedCIDRs = mustParseCIDRs([]string{
|
||||
"0.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"100.64.0.0/10",
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"224.0.0.0/4",
|
||||
"240.0.0.0/4",
|
||||
"::/128",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
})
|
||||
|
||||
type soraStreamingResult struct {
|
||||
mediaType string
|
||||
mediaURLs []string
|
||||
@@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
|
||||
}
|
||||
|
||||
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
parsed, err := validateSoraImageURL(rawURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: soraImageInputTimeout,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= soraImageInputMaxRedirects {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return validateSoraImageURLValue(req.URL)
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ext := fileExtFromURL(rawURL)
|
||||
ext := fileExtFromURL(parsed.String())
|
||||
if ext == "" {
|
||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
filename := "image" + ext
|
||||
return data, filename, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURL(raw string) (*url.URL, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("empty image url")
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid image url: %w", err)
|
||||
}
|
||||
if err := validateSoraImageURLValue(parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateSoraImageURLValue(parsed *url.URL) error {
|
||||
if parsed == nil {
|
||||
return errors.New("invalid image url")
|
||||
}
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return errors.New("only http/https image url is allowed")
|
||||
}
|
||||
if parsed.User != nil {
|
||||
return errors.New("image url cannot contain userinfo")
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return errors.New("image url missing host")
|
||||
}
|
||||
if _, blocked := soraBlockedHostnames[host]; blocked {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve image url failed: %w", err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isSoraBlockedIP(ip) {
|
||||
return errors.New("image url is not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSoraBlockedIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return true
|
||||
}
|
||||
for _, cidr := range soraBlockedCIDRs {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mustParseCIDRs parses every entry of values as CIDR notation and returns the
// resulting networks in input order.
//
// It panics on a malformed entry. The result is used as a security blocklist,
// so silently dropping a range would leave it unblocked; failing loudly at
// package initialization is the safe behavior for a Must-style helper.
func mustParseCIDRs(values []string) []*net.IPNet {
	out := make([]*net.IPNet, 0, len(values))
	for _, val := range values {
		_, cidr, err := net.ParseCIDR(val)
		if err != nil {
			panic(fmt.Sprintf("invalid CIDR %q: %v", val, err))
		}
		out = append(out, cidr)
	}
	return out
}
|
||||
|
||||
@@ -97,3 +97,16 @@ func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
|
||||
require.Contains(t, url, "expires=")
|
||||
require.Contains(t, url, "sig=")
|
||||
}
|
||||
|
||||
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
|
||||
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
|
||||
encoded := "data:image/png;base64,aGVsbG8="
|
||||
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, data)
|
||||
require.Contains(t, filename, ".png")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ type SoraMediaStorage struct {
|
||||
imageRoot string
|
||||
videoRoot string
|
||||
maxConcurrent int
|
||||
downloadTimeout time.Duration
|
||||
maxDownloadBytes int64
|
||||
fallbackToUpstream bool
|
||||
debug bool
|
||||
sem chan struct{}
|
||||
@@ -92,6 +94,17 @@ func (s *SoraMediaStorage) refreshConfig() {
|
||||
maxConcurrent = 4
|
||||
}
|
||||
s.maxConcurrent = maxConcurrent
|
||||
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
|
||||
if timeoutSeconds <= 0 {
|
||||
timeoutSeconds = 120
|
||||
}
|
||||
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
|
||||
|
||||
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 200 << 20
|
||||
}
|
||||
s.maxDownloadBytes = maxBytes
|
||||
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
|
||||
s.debug = s.cfg.Sora.Storage.Debug
|
||||
s.sem = make(chan struct{}, maxConcurrent)
|
||||
@@ -180,7 +193,8 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
client := &http.Client{Timeout: s.downloadTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -198,6 +212,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
|
||||
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
|
||||
}
|
||||
|
||||
datePath := time.Now().Format("2006/01/02")
|
||||
destDir := filepath.Join(root, filepath.FromSlash(datePath))
|
||||
@@ -212,10 +229,16 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
_ = os.Remove(destPath)
|
||||
return "", err
|
||||
}
|
||||
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
|
||||
_ = os.Remove(destPath)
|
||||
return "", fmt.Errorf("download size exceeds limit: %d", written)
|
||||
}
|
||||
|
||||
relative := path.Join("/", mediaType, datePath, filename)
|
||||
if s.debug {
|
||||
|
||||
@@ -67,3 +67,27 @@ func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{url}, urls)
|
||||
}
|
||||
|
||||
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/png")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("too-large"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Sora: config.SoraConfig{
|
||||
Storage: config.SoraStorageConfig{
|
||||
Type: "local",
|
||||
LocalPath: tmpDir,
|
||||
MaxDownloadBytes: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
storage := NewSoraMediaStorage(cfg)
|
||||
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -269,6 +269,12 @@ sora:
|
||||
# Max poll attempts
|
||||
# 最大轮询次数
|
||||
max_poll_attempts: 600
|
||||
# Recent task query limit (image)
|
||||
# 最近任务查询数量(图片轮询)
|
||||
recent_task_limit: 50
|
||||
# Max recent task query limit, used as a fallback when the task is not found within recent_task_limit
|
||||
# 最近任务查询最大数量(回退)
|
||||
recent_task_limit_max: 200
|
||||
# Enable debug logs for Sora upstream requests
|
||||
# 启用 Sora 直连调试日志
|
||||
debug: false
|
||||
@@ -294,6 +300,12 @@ sora:
|
||||
# Max concurrent downloads
|
||||
# 并发下载上限
|
||||
max_concurrent_downloads: 4
|
||||
# Download timeout (seconds)
|
||||
# 下载超时(秒)
|
||||
download_timeout_seconds: 120
|
||||
# Max size of a single media download, in bytes (209715200 = 200 MiB)
|
||||
# 最大下载字节数
|
||||
max_download_bytes: 209715200
|
||||
# Enable debug logs for media storage
|
||||
# 启用媒体存储调试日志
|
||||
debug: false
|
||||
|
||||
Reference in New Issue
Block a user