fix(service): 修复 Sora 媒体落地路径穿越风险

- 新增安全路径拼接校验,确保目标文件仍在下载目录内
- 清理失败下载文件时复用安全校验,避免不安全删除路径
- 增加扩展名白名单归一化与相关单元测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-23 14:42:07 +08:00
parent 4950ee48a0
commit e8671fd7c2
2 changed files with 68 additions and 6 deletions

View File

@@ -84,6 +84,12 @@ func (s *SoraMediaStorage) refreshConfig() {
if root == "" { if root == "" {
root = soraStorageDefaultRoot root = soraStorageDefaultRoot
} }
root = filepath.Clean(root)
if !filepath.IsAbs(root) {
if absRoot, err := filepath.Abs(root); err == nil {
root = absRoot
}
}
s.root = root s.root = root
s.imageRoot = filepath.Join(root, "image") s.imageRoot = filepath.Join(root, "image")
s.videoRoot = filepath.Join(root, "video") s.videoRoot = filepath.Join(root, "video")
@@ -203,9 +209,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body)) return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body))
} }
ext := fileExtFromURL(rawURL) ext := normalizeSoraFileExt(fileExtFromURL(rawURL))
if ext == "" { if ext == "" {
ext = fileExtFromContentType(resp.Header.Get("Content-Type")) ext = normalizeSoraFileExt(fileExtFromContentType(resp.Header.Get("Content-Type")))
} }
if ext == "" { if ext == "" {
ext = ".bin" ext = ".bin"
@@ -220,8 +226,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
return "", err return "", err
} }
filename := uuid.NewString() + ext filename := uuid.NewString() + ext
destPath := filepath.Join(destDir, filename) destPath, err := joinPathWithinDir(destDir, filename)
out, err := os.Create(destPath) if err != nil {
return "", err
}
out, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -230,11 +239,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1) limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
written, err := io.Copy(out, limited) written, err := io.Copy(out, limited)
if err != nil { if err != nil {
_ = os.Remove(destPath) removePartialDownload(destDir, filename)
return "", err return "", err
} }
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes { if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
_ = os.Remove(destPath) removePartialDownload(destDir, filename)
return "", fmt.Errorf("download size exceeds limit: %d", written) return "", fmt.Errorf("download size exceeds limit: %d", written)
} }
@@ -275,3 +284,38 @@ func fileExtFromContentType(ct string) string {
} }
return "" return ""
} }
func normalizeSoraFileExt(ext string) string {
ext = strings.ToLower(strings.TrimSpace(ext))
switch ext {
case ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg", ".tif", ".tiff", ".heic",
".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
return ext
default:
return ""
}
}
func joinPathWithinDir(baseDir, filename string) (string, error) {
baseDir = filepath.Clean(baseDir)
if strings.TrimSpace(filename) == "" {
return "", errors.New("empty filename")
}
joined := filepath.Clean(filepath.Join(baseDir, filename))
rel, err := filepath.Rel(baseDir, joined)
if err != nil {
return "", fmt.Errorf("resolve path rel: %w", err)
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s", filename)
}
return joined, nil
}
func removePartialDownload(baseDir, filename string) {
target, err := joinPathWithinDir(baseDir, filename)
if err != nil {
return
}
_ = os.Remove(target)
}

View File

@@ -91,3 +91,21 @@ func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
_, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"}) _, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"})
require.Error(t, err) require.Error(t, err)
} }
func TestJoinPathWithinDir(t *testing.T) {
baseDir := t.TempDir()
path1, err := joinPathWithinDir(baseDir, "ok.png")
require.NoError(t, err)
require.Equal(t, filepath.Join(baseDir, "ok.png"), path1)
_, err = joinPathWithinDir(baseDir, "../escape.png")
require.Error(t, err)
}
func TestNormalizeSoraFileExt(t *testing.T) {
require.Equal(t, ".png", normalizeSoraFileExt(".PNG"))
require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4"))
require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd"))
require.Equal(t, "", normalizeSoraFileExt(".php"))
}