fix(service): 使用 os.Root 修复 Sora 存储路径告警

- 将媒体写入和删除切换为 os.Root 沙箱 API
- 移除旧的路径拼接校验辅助函数并收敛删除逻辑
- 调整并新增相关单元测试覆盖删除行为

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-23 16:06:04 +08:00
parent e8671fd7c2
commit c2567831d9
2 changed files with 35 additions and 41 deletions

View File

@@ -220,17 +220,20 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength) return "", fmt.Errorf("download size exceeds limit: %d", resp.ContentLength)
} }
datePath := time.Now().Format("2006/01/02") storageRoot, err := os.OpenRoot(root)
destDir := filepath.Join(root, filepath.FromSlash(datePath))
if err := os.MkdirAll(destDir, 0o755); err != nil {
return "", err
}
filename := uuid.NewString() + ext
destPath, err := joinPathWithinDir(destDir, filename)
if err != nil { if err != nil {
return "", err return "", err
} }
out, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) defer func() { _ = storageRoot.Close() }()
datePath := time.Now().Format("2006/01/02")
datePathFS := filepath.FromSlash(datePath)
if err := storageRoot.MkdirAll(datePathFS, 0o755); err != nil {
return "", err
}
filename := uuid.NewString() + ext
filePath := filepath.Join(datePathFS, filename)
out, err := storageRoot.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -239,11 +242,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra
limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1) limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1)
written, err := io.Copy(out, limited) written, err := io.Copy(out, limited)
if err != nil { if err != nil {
removePartialDownload(destDir, filename) removePartialDownload(storageRoot, filePath)
return "", err return "", err
} }
if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes { if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes {
removePartialDownload(destDir, filename) removePartialDownload(storageRoot, filePath)
return "", fmt.Errorf("download size exceeds limit: %d", written) return "", fmt.Errorf("download size exceeds limit: %d", written)
} }
@@ -296,26 +299,9 @@ func normalizeSoraFileExt(ext string) string {
} }
} }
func joinPathWithinDir(baseDir, filename string) (string, error) { func removePartialDownload(root *os.Root, filePath string) {
baseDir = filepath.Clean(baseDir) if root == nil || strings.TrimSpace(filePath) == "" {
if strings.TrimSpace(filename) == "" {
return "", errors.New("empty filename")
}
joined := filepath.Clean(filepath.Join(baseDir, filename))
rel, err := filepath.Rel(baseDir, joined)
if err != nil {
return "", fmt.Errorf("resolve path rel: %w", err)
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s", filename)
}
return joined, nil
}
func removePartialDownload(baseDir, filename string) {
target, err := joinPathWithinDir(baseDir, filename)
if err != nil {
return return
} }
_ = os.Remove(target) _ = root.Remove(filePath)
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
@@ -92,20 +93,27 @@ func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestJoinPathWithinDir(t *testing.T) {
baseDir := t.TempDir()
path1, err := joinPathWithinDir(baseDir, "ok.png")
require.NoError(t, err)
require.Equal(t, filepath.Join(baseDir, "ok.png"), path1)
_, err = joinPathWithinDir(baseDir, "../escape.png")
require.Error(t, err)
}
func TestNormalizeSoraFileExt(t *testing.T) { func TestNormalizeSoraFileExt(t *testing.T) {
require.Equal(t, ".png", normalizeSoraFileExt(".PNG")) require.Equal(t, ".png", normalizeSoraFileExt(".PNG"))
require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4")) require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4"))
require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd")) require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd"))
require.Equal(t, "", normalizeSoraFileExt(".php")) require.Equal(t, "", normalizeSoraFileExt(".php"))
} }
func TestRemovePartialDownload(t *testing.T) {
tmpDir := t.TempDir()
root, err := os.OpenRoot(tmpDir)
require.NoError(t, err)
defer func() { _ = root.Close() }()
filePath := "partial.bin"
f, err := root.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
require.NoError(t, err)
_, _ = f.WriteString("partial")
_ = f.Close()
removePartialDownload(root, filePath)
_, err = root.Stat(filePath)
require.Error(t, err)
require.True(t, os.IsNotExist(err))
}