From c2567831d949e1fc355ae835b5c5f82b08e699f4 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Mon, 23 Feb 2026 16:06:04 +0800 Subject: [PATCH] =?UTF-8?q?fix(service):=20=E4=BD=BF=E7=94=A8=20os.Root=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20Sora=20=E5=AD=98=E5=82=A8=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E5=91=8A=E8=AD=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将媒体写入和删除切换为 os.Root 沙箱 API - 移除旧的路径拼接校验辅助函数并收敛删除逻辑 - 调整并新增相关单元测试覆盖删除行为 Co-Authored-By: Claude Opus 4.6 --- .../internal/service/sora_media_storage.go | 46 +++++++------------ .../service/sora_media_storage_test.go | 30 +++++++----- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index 8294af62..8b83cb76 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -220,17 +220,20 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength) } - datePath := time.Now().Format("2006/01/02") - destDir := filepath.Join(root, filepath.FromSlash(datePath)) - if err := os.MkdirAll(destDir, 0o755); err != nil { - return "", err - } - filename := uuid.NewString() + ext - destPath, err := joinPathWithinDir(destDir, filename) + storageRoot, err := os.OpenRoot(root) if err != nil { return "", err } - out, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + defer func() { _ = storageRoot.Close() }() + + datePath := time.Now().Format("2006/01/02") + datePathFS := filepath.FromSlash(datePath) + if err := storageRoot.MkdirAll(datePathFS, 0o755); err != nil { + return "", err + } + filename := uuid.NewString() + ext + filePath := filepath.Join(datePathFS, filename) + out, err := storageRoot.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) if err != nil { return "", err } @@ -239,11 +242,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1) written, err := io.Copy(out, limited) if err != nil { - removePartialDownload(destDir, filename) + removePartialDownload(storageRoot, filePath) return "", err } if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes { - removePartialDownload(destDir, filename) + removePartialDownload(storageRoot, filePath) return "", fmt.Errorf("download size exceeds limit: %d", written) } @@ -296,26 +299,9 @@ func normalizeSoraFileExt(ext string) string { } } -func joinPathWithinDir(baseDir, filename string) (string, error) { - baseDir = filepath.Clean(baseDir) - if strings.TrimSpace(filename) == "" { - return "", errors.New("empty filename") - } - joined := filepath.Clean(filepath.Join(baseDir, filename)) - rel, err := filepath.Rel(baseDir, joined) - if err != nil { - return "", fmt.Errorf("resolve path rel: %w", err) - } - if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { - return "", fmt.Errorf("path traversal detected: %s", filename) - } - return joined, nil -} - -func removePartialDownload(baseDir, filename string) { - target, err := joinPathWithinDir(baseDir, filename) - if err != nil { +func removePartialDownload(root *os.Root, filePath string) { + if root == nil || strings.TrimSpace(filePath) == "" { return } - _ = os.Remove(target) + _ = root.Remove(filePath) } diff --git a/backend/internal/service/sora_media_storage_test.go b/backend/internal/service/sora_media_storage_test.go index 630a971b..5359f4e6 100644 --- a/backend/internal/service/sora_media_storage_test.go +++ b/backend/internal/service/sora_media_storage_test.go @@ -6,6 +6,7 @@ import ( "context" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" "testing" @@ -92,20 +93,27 @@ func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) { require.Error(t, err) } -func TestJoinPathWithinDir(t *testing.T) { - baseDir := t.TempDir() - - path1, err := joinPathWithinDir(baseDir, "ok.png") - require.NoError(t, err) - require.Equal(t, filepath.Join(baseDir, "ok.png"), path1) - - _, err = joinPathWithinDir(baseDir, "../escape.png") - require.Error(t, err) -} - func TestNormalizeSoraFileExt(t *testing.T) { require.Equal(t, ".png", normalizeSoraFileExt(".PNG")) require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4")) require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd")) require.Equal(t, "", normalizeSoraFileExt(".php")) } + +func TestRemovePartialDownload(t *testing.T) { + tmpDir := t.TempDir() + root, err := os.OpenRoot(tmpDir) + require.NoError(t, err) + defer func() { _ = root.Close() }() + + filePath := "partial.bin" + f, err := root.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + require.NoError(t, err) + _, _ = f.WriteString("partial") + _ = f.Close() + + removePartialDownload(root, filePath) + _, err = root.Stat(filePath) + require.Error(t, err) + require.True(t, os.IsNotExist(err)) +}