diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index 562ded46..8294af62 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -84,6 +84,12 @@ func (s *SoraMediaStorage) refreshConfig() { if root == "" { root = soraStorageDefaultRoot } + root = filepath.Clean(root) + if !filepath.IsAbs(root) { + if absRoot, err := filepath.Abs(root); err == nil { + root = absRoot + } + } s.root = root s.imageRoot = filepath.Join(root, "image") s.videoRoot = filepath.Join(root, "video") @@ -203,9 +209,9 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra return "", fmt.Errorf("download failed: %d %s", resp.StatusCode, string(body)) } - ext := fileExtFromURL(rawURL) + ext := normalizeSoraFileExt(fileExtFromURL(rawURL)) if ext == "" { - ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + ext = normalizeSoraFileExt(fileExtFromContentType(resp.Header.Get("Content-Type"))) } if ext == "" { ext = ".bin" @@ -220,8 +226,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra return "", err } filename := uuid.NewString() + ext - destPath := filepath.Join(destDir, filename) - out, err := os.Create(destPath) + destPath, err := joinPathWithinDir(destDir, filename) + if err != nil { + return "", err + } + out, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) if err != nil { return "", err } @@ -230,11 +239,11 @@ func (s *SoraMediaStorage) downloadOnce(ctx context.Context, root, mediaType, ra limited := io.LimitReader(resp.Body, s.maxDownloadBytes+1) written, err := io.Copy(out, limited) if err != nil { - _ = os.Remove(destPath) + removePartialDownload(destDir, filename) return "", err } if s.maxDownloadBytes > 0 && written > s.maxDownloadBytes { - _ = os.Remove(destPath) + removePartialDownload(destDir, filename) return "", fmt.Errorf("download size exceeds limit: %d", written) } @@ -275,3 +284,38 @@ func fileExtFromContentType(ct string) string { } return "" } + +func normalizeSoraFileExt(ext string) string { + ext = strings.ToLower(strings.TrimSpace(ext)) + switch ext { + case ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg", ".tif", ".tiff", ".heic", + ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv": + return ext + default: + return "" + } +} + +func joinPathWithinDir(baseDir, filename string) (string, error) { + baseDir = filepath.Clean(baseDir) + if strings.TrimSpace(filename) == "" { + return "", errors.New("empty filename") + } + joined := filepath.Clean(filepath.Join(baseDir, filename)) + rel, err := filepath.Rel(baseDir, joined) + if err != nil { + return "", fmt.Errorf("resolve path rel: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("path traversal detected: %s", filename) + } + return joined, nil +} + +func removePartialDownload(baseDir, filename string) { + target, err := joinPathWithinDir(baseDir, filename) + if err != nil { + return + } + _ = os.Remove(target) +} diff --git a/backend/internal/service/sora_media_storage_test.go b/backend/internal/service/sora_media_storage_test.go index 0050afed..630a971b 100644 --- a/backend/internal/service/sora_media_storage_test.go +++ b/backend/internal/service/sora_media_storage_test.go @@ -91,3 +91,21 @@ func TestSoraMediaStorage_MaxDownloadBytes(t *testing.T) { _, err := storage.StoreFromURLs(context.Background(), "image", []string{server.URL + "/img.png"}) require.Error(t, err) } + +func TestJoinPathWithinDir(t *testing.T) { + baseDir := t.TempDir() + + path1, err := joinPathWithinDir(baseDir, "ok.png") + require.NoError(t, err) + require.Equal(t, filepath.Join(baseDir, "ok.png"), path1) + + _, err = joinPathWithinDir(baseDir, "../escape.png") + require.Error(t, err) +} + +func TestNormalizeSoraFileExt(t *testing.T) { + require.Equal(t, ".png", normalizeSoraFileExt(".PNG")) + require.Equal(t, ".mp4", normalizeSoraFileExt(".mp4")) + require.Equal(t, "", normalizeSoraFileExt("../../etc/passwd")) + require.Equal(t, "", normalizeSoraFileExt(".php")) +}