feat(sync): full code sync from release
This commit is contained in:
392
backend/internal/service/sora_s3_storage.go
Normal file
392
backend/internal/service/sora_s3_storage.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
|
||||
// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
|
||||
type SoraS3Storage struct {
|
||||
settingService *SettingService
|
||||
|
||||
mu sync.RWMutex
|
||||
client *s3.Client
|
||||
cfg *SoraS3Settings // 上次加载的配置快照
|
||||
|
||||
healthCheckedAt time.Time
|
||||
healthErr error
|
||||
healthTTL time.Duration
|
||||
}
|
||||
|
||||
const defaultSoraS3HealthTTL = 30 * time.Second
|
||||
|
||||
// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
|
||||
type UpstreamDownloadError struct {
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *UpstreamDownloadError) Error() string {
|
||||
if e == nil {
|
||||
return "upstream download failed"
|
||||
}
|
||||
return fmt.Sprintf("upstream returned %d", e.StatusCode)
|
||||
}
|
||||
|
||||
// NewSoraS3Storage 创建 S3 存储服务实例。
|
||||
func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
|
||||
return &SoraS3Storage{
|
||||
settingService: settingService,
|
||||
healthTTL: defaultSoraS3HealthTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled 返回 S3 存储是否已启用且配置有效。
|
||||
func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
|
||||
cfg, err := s.getConfig(ctx)
|
||||
if err != nil || cfg == nil {
|
||||
return false
|
||||
}
|
||||
return cfg.Enabled && cfg.Bucket != ""
|
||||
}
|
||||
|
||||
// getConfig 获取当前 S3 配置(从 settings 表读取)。
|
||||
func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
|
||||
if s.settingService == nil {
|
||||
return nil, fmt.Errorf("setting service not available")
|
||||
}
|
||||
return s.settingService.GetSoraS3Settings(ctx)
|
||||
}
|
||||
|
||||
// getClient 获取或初始化 S3 客户端(带缓存)。
|
||||
// 配置变更时调用 RefreshClient 清除缓存。
|
||||
func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
|
||||
s.mu.RLock()
|
||||
if s.client != nil && s.cfg != nil {
|
||||
client, cfg := s.client, s.cfg
|
||||
s.mu.RUnlock()
|
||||
return client, cfg, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
return s.initClient(ctx)
|
||||
}
|
||||
|
||||
func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if s.client != nil && s.cfg != nil {
|
||||
return s.client, s.cfg, nil
|
||||
}
|
||||
|
||||
cfg, err := s.getConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("load s3 config: %w", err)
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return nil, nil, fmt.Errorf("sora s3 storage is disabled")
|
||||
}
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
client, region, err := buildSoraS3Client(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.client = client
|
||||
s.cfg = cfg
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
|
||||
return client, cfg, nil
|
||||
}
|
||||
|
||||
// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
|
||||
// 应在系统设置中 S3 配置变更时调用。
|
||||
func (s *SoraS3Storage) RefreshClient() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.client = nil
|
||||
s.cfg = nil
|
||||
s.healthCheckedAt = time.Time{}
|
||||
s.healthErr = nil
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
|
||||
}
|
||||
|
||||
// TestConnection 测试 S3 连接(HeadBucket)。
|
||||
func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
|
||||
func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.RLock()
|
||||
lastCheck := s.healthCheckedAt
|
||||
lastErr := s.healthErr
|
||||
ttl := s.healthTTL
|
||||
s.mu.RUnlock()
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultSoraS3HealthTTL
|
||||
}
|
||||
if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl {
|
||||
return lastErr == nil
|
||||
}
|
||||
|
||||
err := s.TestConnection(ctx)
|
||||
s.mu.Lock()
|
||||
s.healthCheckedAt = time.Now()
|
||||
s.healthErr = err
|
||||
s.mu.Unlock()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
|
||||
func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("s3 config is required")
|
||||
}
|
||||
if !cfg.Enabled {
|
||||
return fmt.Errorf("sora s3 storage is disabled")
|
||||
}
|
||||
if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
client, _, err := buildSoraS3Client(ctx, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateObjectKey 生成 S3 object key。
|
||||
// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
|
||||
func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
|
||||
if !strings.HasPrefix(ext, ".") {
|
||||
ext = "." + ext
|
||||
}
|
||||
datePath := time.Now().Format("2006/01/02")
|
||||
key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
|
||||
if prefix != "" {
|
||||
prefix = strings.TrimRight(prefix, "/") + "/"
|
||||
key = prefix + key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// UploadFromURL 从上游 URL 下载并流式上传到 S3。
|
||||
// 返回 S3 object key。
|
||||
func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
// 下载源文件
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
httpClient := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("download from upstream: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 推断文件扩展名
|
||||
ext := fileExtFromURL(sourceURL)
|
||||
if ext == "" {
|
||||
ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
|
||||
objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)
|
||||
|
||||
// 检测 Content-Type
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
reader, writer := io.Pipe()
|
||||
uploadErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(uploadErrCh)
|
||||
input := &s3.PutObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &objectKey,
|
||||
Body: reader,
|
||||
ContentType: &contentType,
|
||||
}
|
||||
if resp.ContentLength >= 0 {
|
||||
input.ContentLength = &resp.ContentLength
|
||||
}
|
||||
_, uploadErr := client.PutObject(ctx, input)
|
||||
uploadErrCh <- uploadErr
|
||||
}()
|
||||
|
||||
written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
|
||||
_ = writer.CloseWithError(copyErr)
|
||||
uploadErr := <-uploadErrCh
|
||||
if copyErr != nil {
|
||||
return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
|
||||
}
|
||||
if uploadErr != nil {
|
||||
return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
|
||||
return objectKey, written, nil
|
||||
}
|
||||
|
||||
func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
|
||||
if cfg == nil {
|
||||
return nil, "", fmt.Errorf("s3 config is required")
|
||||
}
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
// 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
return client, region, nil
|
||||
}
|
||||
|
||||
// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
|
||||
func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
|
||||
if len(objectKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, key := range objectKeys {
|
||||
k := key
|
||||
_, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &k,
|
||||
})
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetAccessURL 获取 S3 文件的访问 URL。
|
||||
// CDN URL 优先,否则生成 24h 预签名 URL。
|
||||
func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
|
||||
_, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// CDN URL 优先
|
||||
if cfg.CDNURL != "" {
|
||||
cdnBase := strings.TrimRight(cfg.CDNURL, "/")
|
||||
return cdnBase + "/" + objectKey, nil
|
||||
}
|
||||
|
||||
// 生成 24h 预签名 URL
|
||||
return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour)
|
||||
}
|
||||
|
||||
// GeneratePresignedURL 生成预签名 URL。
|
||||
func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
|
||||
client, cfg, err := s.getClient(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
presignClient := s3.NewPresignClient(client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &cfg.Bucket,
|
||||
Key: &objectKey,
|
||||
}, s3.WithPresignExpires(ttl))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
// GetMediaType 从 object key 推断媒体类型(image/video)。
|
||||
func GetMediaTypeFromKey(objectKey string) string {
|
||||
ext := strings.ToLower(path.Ext(objectKey))
|
||||
switch ext {
|
||||
case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
|
||||
return "video"
|
||||
default:
|
||||
return "image"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user