feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
246
backend/internal/service/sora_cache_service.go
Normal file
246
backend/internal/service/sora_cache_service.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/uuidv7"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
// SoraCacheService 提供 Sora 视频缓存能力。
|
||||
type SoraCacheService struct {
|
||||
cfg *config.Config
|
||||
cacheRepo SoraCacheFileRepository
|
||||
settingService *SettingService
|
||||
accountRepo AccountRepository
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewSoraCacheService 创建 SoraCacheService。
|
||||
func NewSoraCacheService(cfg *config.Config, cacheRepo SoraCacheFileRepository, settingService *SettingService, accountRepo AccountRepository, httpUpstream HTTPUpstream) *SoraCacheService {
|
||||
return &SoraCacheService{
|
||||
cfg: cfg,
|
||||
cacheRepo: cacheRepo,
|
||||
settingService: settingService,
|
||||
accountRepo: accountRepo,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) CacheVideo(ctx context.Context, accountID, userID int64, taskID, mediaURL string) (*SoraCacheFile, error) {
|
||||
cfg := s.getSoraConfig(ctx)
|
||||
if !cfg.Cache.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(mediaURL)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
allowedHosts := cfg.Cache.AllowedHosts
|
||||
useAllowlist := true
|
||||
if len(allowedHosts) == 0 {
|
||||
if s.cfg != nil {
|
||||
allowedHosts = s.cfg.Security.URLAllowlist.UpstreamHosts
|
||||
useAllowlist = s.cfg.Security.URLAllowlist.Enabled
|
||||
} else {
|
||||
useAllowlist = false
|
||||
}
|
||||
}
|
||||
|
||||
if useAllowlist {
|
||||
if _, err := urlvalidator.ValidateHTTPSURL(trimmed, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: allowedHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg != nil && s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
|
||||
}
|
||||
} else {
|
||||
allowInsecure := false
|
||||
if s.cfg != nil {
|
||||
allowInsecure = s.cfg.Security.URLAllowlist.AllowInsecureHTTP
|
||||
}
|
||||
if _, err := urlvalidator.ValidateURLFormat(trimmed, allowInsecure); err != nil {
|
||||
return nil, fmt.Errorf("缓存下载地址不合法: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
videoDir := strings.TrimSpace(cfg.Cache.VideoDir)
|
||||
if videoDir == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cfg.Cache.MaxBytes > 0 {
|
||||
size, err := dirSize(videoDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if size >= cfg.Cache.MaxBytes {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
relativeDir := ""
|
||||
if cfg.Cache.UserDirEnabled && userID > 0 {
|
||||
relativeDir = fmt.Sprintf("u_%d", userID)
|
||||
}
|
||||
|
||||
targetDir := filepath.Join(videoDir, relativeDir)
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uuid, err := uuidv7.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := deriveFileName(trimmed)
|
||||
if name == "" {
|
||||
name = "video.mp4"
|
||||
}
|
||||
name = sanitizeFileName(name)
|
||||
filename := uuid + "_" + name
|
||||
cachePath := filepath.Join(targetDir, filename)
|
||||
|
||||
resp, err := s.downloadMedia(ctx, accountID, trimmed, time.Duration(cfg.Timeout)*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("缓存下载失败: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
out, err := os.Create(cachePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
written, err := io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheURL := buildCacheURL(relativeDir, filename)
|
||||
|
||||
record := &SoraCacheFile{
|
||||
TaskID: taskID,
|
||||
AccountID: accountID,
|
||||
UserID: userID,
|
||||
MediaType: "video",
|
||||
OriginalURL: trimmed,
|
||||
CachePath: cachePath,
|
||||
CacheURL: cacheURL,
|
||||
SizeBytes: written,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if s.cacheRepo != nil {
|
||||
if err := s.cacheRepo.Create(ctx, record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func buildCacheURL(relativeDir, filename string) string {
|
||||
base := "/data/video"
|
||||
if relativeDir != "" {
|
||||
return path.Join(base, relativeDir, filename)
|
||||
}
|
||||
return path.Join(base, filename)
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) getSoraConfig(ctx context.Context) config.SoraConfig {
|
||||
if s.settingService != nil {
|
||||
return s.settingService.GetSoraConfig(ctx)
|
||||
}
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Sora
|
||||
}
|
||||
return config.SoraConfig{}
|
||||
}
|
||||
|
||||
func (s *SoraCacheService) downloadMedia(ctx context.Context, accountID int64, mediaURL string, timeout time.Duration) (*http.Response, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", mediaURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
|
||||
if s.httpUpstream == nil {
|
||||
client := &http.Client{Timeout: timeout}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
var accountConcurrency int
|
||||
proxyURL := ""
|
||||
if s.accountRepo != nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account != nil {
|
||||
accountConcurrency = account.Concurrency
|
||||
if account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
}
|
||||
}
|
||||
enableTLS := false
|
||||
if s.cfg != nil {
|
||||
enableTLS = s.cfg.Gateway.TLSFingerprint.Enabled
|
||||
}
|
||||
return s.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS)
|
||||
}
|
||||
|
||||
func deriveFileName(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
name := path.Base(parsed.Path)
|
||||
if name == "/" || name == "." {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func sanitizeFileName(name string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
sanitized := strings.Map(func(r rune) rune {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
return r
|
||||
case r >= 'A' && r <= 'Z':
|
||||
return r
|
||||
case r >= '0' && r <= '9':
|
||||
return r
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
return r
|
||||
case r == ' ': // 空格替换为下划线
|
||||
return '_'
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}, name)
|
||||
return strings.TrimLeft(sanitized, ".")
|
||||
}
|
||||
Reference in New Issue
Block a user