fix(Sora): 加固直连安全与下载限制

补充图片输入 SSRF 防护与重定向限制\n增加媒体下载超时/大小上限配置并更新示例\n完善 recent_tasks 轮询回退策略与相关测试\n\n测试: go test ./... -tags=unit
This commit is contained in:
yangjianbo
2026-02-01 22:10:15 +08:00
parent dcf5f60237
commit 99250ec527
8 changed files with 301 additions and 13 deletions

View File

@@ -218,6 +218,8 @@ type SoraClientConfig struct {
MaxRetries int `mapstructure:"max_retries"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
Debug bool `mapstructure:"debug"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
@@ -230,6 +232,8 @@ type SoraStorageConfig struct {
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
@@ -929,6 +933,8 @@ func setDefaults() {
viper.SetDefault("sora.client.max_retries", 3)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
@@ -938,6 +944,8 @@ func setDefaults() {
viper.SetDefault("sora.storage.local_path", "")
viper.SetDefault("sora.storage.fallback_to_upstream", true)
viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
viper.SetDefault("sora.storage.download_timeout_seconds", 120)
viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
viper.SetDefault("sora.storage.debug", false)
viper.SetDefault("sora.storage.cleanup.enabled", true)
viper.SetDefault("sora.storage.cleanup.retention_days", 7)
@@ -1205,9 +1213,25 @@ func (c *Config) Validate() error {
if c.Sora.Client.MaxPollAttempts < 0 {
return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
}
if c.Sora.Client.RecentTaskLimit < 0 {
return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax < 0 {
return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
}
if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
}
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
}
if c.Sora.Storage.MaxDownloadBytes < 0 {
return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
}
if c.Sora.Storage.Cleanup.Enabled {
if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")

View File

@@ -359,18 +359,43 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
}
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
token, err := c.getAccessToken(ctx, account)
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil {
return nil, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/v2/recent_tasks?limit=20"), headers, nil, false)
if found {
return status, nil
}
maxLimit := c.recentTaskLimitMax()
if maxLimit > 0 && maxLimit != c.recentTaskLimit() {
status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit)
if err != nil {
return nil, err
}
if found {
return status, nil
}
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
}
func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
return nil, err
return nil, false, err
}
headers := c.buildBaseHeaders(token, c.defaultUserAgent())
if limit <= 0 {
limit = 20
}
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
if err != nil {
return nil, false, err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return nil, err
return nil, false, err
}
taskResponses, _ := resp["task_responses"].([]any)
for _, item := range taskResponses {
@@ -401,10 +426,30 @@ func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, t
Status: status,
ProgressPct: progress,
URLs: urls,
}, nil
}, true, nil
}
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
}
func (c *SoraDirectClient) recentTaskLimit() int {
if c == nil || c.cfg == nil {
return 20
}
if c.cfg.Sora.Client.RecentTaskLimit > 0 {
return c.cfg.Sora.Client.RecentTaskLimit
}
return 20
}
func (c *SoraDirectClient) recentTaskLimitMax() int {
if c == nil || c.cfg == nil {
return 0
}
if c.cfg.Sora.Client.RecentTaskLimitMax > 0 {
return c.cfg.Sora.Client.RecentTaskLimitMax
}
return 0
}
func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) {

View File

@@ -52,3 +52,36 @@ func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) {
require.Equal(t, "yes", headers.Get("X-Test"))
require.Empty(t, headers.Get("openai-sentinel-token"))
}
func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limit := r.URL.Query().Get("limit")
w.Header().Set("Content-Type", "application/json")
switch limit {
case "1":
_, _ = w.Write([]byte(`{"task_responses":[]}`))
case "2":
_, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`))
default:
_, _ = w.Write([]byte(`{"task_responses":[]}`))
}
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
BaseURL: server.URL,
RecentTaskLimit: 1,
RecentTaskLimitMax: 2,
},
},
}
client := NewSoraDirectClient(cfg, nil, nil)
account := &Account{Credentials: map[string]any{"access_token": "token"}}
status, err := client.GetImageTask(context.Background(), account, "task-1")
require.NoError(t, err)
require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
}

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"regexp"
@@ -26,6 +27,9 @@ var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
const soraRewriteBufferLimit = 2048
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -33,6 +37,29 @@ var soraImageSizeMap = map[string]string{
"gpt-image-portrait": "540",
}
var soraBlockedHostnames = map[string]struct{}{
"localhost": {},
"localhost.localdomain": {},
"metadata.google.internal": {},
"metadata.google.internal.": {},
}
var soraBlockedCIDRs = mustParseCIDRs([]string{
"0.0.0.0/8",
"10.0.0.0/8",
"100.64.0.0/10",
"127.0.0.0/8",
"169.254.0.0/16",
"172.16.0.0/12",
"192.168.0.0/16",
"224.0.0.0/4",
"240.0.0.0/4",
"::/128",
"::1/128",
"fc00::/7",
"fe80::/10",
})
type soraStreamingResult struct {
mediaType string
mediaURLs []string
@@ -1233,11 +1260,24 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
parsed, err := validateSoraImageURL(rawURL)
if err != nil {
return nil, "", err
}
resp, err := http.DefaultClient.Do(req)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
if err != nil {
return nil, "", err
}
client := &http.Client{
Timeout: soraImageInputTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
return validateSoraImageURLValue(req.URL)
},
}
resp, err := client.Do(req)
if err != nil {
return nil, "", err
}
@@ -1245,14 +1285,88 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, 20<<20))
data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes))
if err != nil {
return nil, "", err
}
ext := fileExtFromURL(rawURL)
ext := fileExtFromURL(parsed.String())
if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
}
filename := "image" + ext
return data, filename, nil
}
func validateSoraImageURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("empty image url")
}
parsed, err := url.Parse(raw)
if err != nil {
return nil, fmt.Errorf("invalid image url: %w", err)
}
if err := validateSoraImageURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
func validateSoraImageURLValue(parsed *url.URL) error {
if parsed == nil {
return errors.New("invalid image url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
return errors.New("only http/https image url is allowed")
}
if parsed.User != nil {
return errors.New("image url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return errors.New("image url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
return errors.New("image url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("resolve image url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
return errors.New("image url is not allowed")
}
}
return nil
}
func isSoraBlockedIP(ip net.IP) bool {
if ip == nil {
return true
}
for _, cidr := range soraBlockedCIDRs {
if cidr.Contains(ip) {
return true
}
}
return false
}
func mustParseCIDRs(values []string) []*net.IPNet {
out := make([]*net.IPNet, 0, len(values))
for _, val := range values {
_, cidr, err := net.ParseCIDR(val)
if err != nil {
continue
}
out = append(out, cidr)
}
return out
}

View File

@@ -97,3 +97,16 @@ func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) {
require.Contains(t, url, "expires=")
require.Contains(t, url, "sig=")
}
func TestDecodeSoraImageInput_BlockPrivateURL(t *testing.T) {
_, _, err := decodeSoraImageInput(context.Background(), "http://127.0.0.1/internal.png")
require.Error(t, err)
}
func TestDecodeSoraImageInput_DataURL(t *testing.T) {
encoded := "data:image/png;base64,aGVsbG8="
data, filename, err := decodeSoraImageInput(context.Background(), encoded)
require.NoError(t, err)
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}

View File

@@ -30,6 +30,8 @@ type SoraMediaStorage struct {
imageRoot string
videoRoot string
maxConcurrent int
downloadTimeout time.Duration
maxDownloadBytes int64
fallbackToUpstream bool
debug bool
sem chan struct{}
@@ -92,6 +94,17 @@ func (s *SoraMediaStorage) refreshConfig() {
maxConcurrent = 4
}
s.maxConcurrent = maxConcurrent
timeoutSeconds := s.cfg.Sora.Storage.DownloadTimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 120
}
s.downloadTimeout = time.Duration(timeoutSeconds) * time.Second
maxBytes := s.cfg.Sora.Storage.MaxDownloadBytes
if maxBytes <= 0 {
maxBytes = 200 << 20
}
s.maxDownloadBytes = maxBytes
s.fallbackToUpstream = s.cfg.Sora.Storage.FallbackToUpstream
s.debug = s.cfg.Sora.Storage.Debug
s.sem = make(chan struct{}, maxConcurrent)
@@ -180,7 +193,8 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
if err != nil {
return "", err
}
resp, err := http.DefaultClient.Do(req)
client := &http.Client{Timeout: s.downloadTimeout}
resp, err := client.Do(req)
if err != nil {
return "", err
}
@@ -198,6 +212,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
if ext == "" {
ext = ".bin"
}
if s.maxDownloadBytes > 0 && resp.ContentLength > s.maxDownloadBytes {
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
}
datePath := time.Now().Format("2006/01/02")
destDir := filepath.Join(root, filepath.FromSlash(datePath))
@@ -212,10 +229,16 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
}
defer func() { _ = out.Close() }()
if _, err := io.Copy(out, resp.Body); err != nil {
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
written, err := io.Copy(out, limited)
if err != nil {
_ = os.Remove(destPath)
return "", err
}
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
_ = os.Remove(destPath)
return "", fmt.Errorf("download size exceeds limit: %d", written)
}
relative := path.Join("/", mediaType, datePath, filename)
if s.debug {

View File

@@ -67,3 +67,27 @@ func TestSoraMediaStorage_FallbackToUpstream(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []string{url}, urls)
}
func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
tmpDir := t.TempDir()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("too-large"))
}))
defer server.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
MaxDownloadBytes: 1,
},
},
}
storage := NewSoraMediaStorage(cfg)
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.Error(t, err)
}

View File

@@ -269,6 +269,12 @@ sora:
# Max poll attempts
# 最大轮询次数
max_poll_attempts: 600
# Recent task query limit (image)
# 最近任务查询数量(图片轮询)
recent_task_limit: 50
# Recent task query max limit (fallback)
# 最近任务查询最大数量(回退)
recent_task_limit_max: 200
# Enable debug logs for Sora upstream requests
# 启用 Sora 直连调试日志
debug: false
@@ -294,6 +300,12 @@ sora:
# Max concurrent downloads
# 并发下载上限
max_concurrent_downloads: 4
# Download timeout (seconds)
# 下载超时(秒)
download_timeout_seconds: 120
# Max download bytes
# 最大下载字节数
max_download_bytes: 209715200
# Enable debug logs for media storage
# 启用媒体存储调试日志
debug: false