Files
sub2api/backend/internal/service/sora_s3_storage.go
2026-02-28 15:01:20 +08:00

393 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"fmt"
"io"
"net/http"
"path"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/google/uuid"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// SoraS3Storage handles S3 storage operations for Sora media files.
// It reads the S3 configuration from the settings table (through settingService),
// builds an S3 client from it, and caches that client until RefreshClient is called.
type SoraS3Storage struct {
	settingService *SettingService
	// mu guards the cached client/config pair and the health-check state below.
	mu              sync.RWMutex
	client          *s3.Client
	cfg             *SoraS3Settings // snapshot of the configuration the cached client was built from
	healthCheckedAt time.Time       // time of the last IsHealthy probe; zero means never probed (or reset by RefreshClient)
	healthErr       error           // result of the last IsHealthy probe; nil means healthy
	healthTTL       time.Duration   // how long an IsHealthy result is reused; <= 0 falls back to defaultSoraS3HealthTTL
}

// defaultSoraS3HealthTTL is the health-cache duration assigned by NewSoraS3Storage
// and used by IsHealthy whenever healthTTL is not positive.
const defaultSoraS3HealthTTL = 30 * time.Second
// UpstreamDownloadError reports that fetching media from the upstream URL
// returned a non-success HTTP status; StatusCode holds that status.
type UpstreamDownloadError struct {
	StatusCode int
}

// Error implements the error interface. It is safe to call on a nil receiver.
func (e *UpstreamDownloadError) Error() string {
	if e != nil {
		return fmt.Sprintf("upstream returned %d", e.StatusCode)
	}
	return "upstream download failed"
}
// NewSoraS3Storage returns a storage service that reads its configuration from
// settingService. No S3 client is created here; it is built lazily on first use.
func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
	storage := new(SoraS3Storage)
	storage.settingService = settingService
	storage.healthTTL = defaultSoraS3HealthTTL
	return storage
}
// Enabled reports whether S3 storage is switched on and has a bucket configured.
// Any error while loading the settings (or missing settings) counts as disabled.
func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
	if cfg, err := s.getConfig(ctx); err == nil && cfg != nil {
		return cfg.Enabled && cfg.Bucket != ""
	}
	return false
}
// getConfig loads the current S3 settings from the settings service.
// It fails when no settings service was supplied to the storage.
func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
	if svc := s.settingService; svc != nil {
		return svc.GetSoraS3Settings(ctx)
	}
	return nil, fmt.Errorf("setting service not available")
}
// getClient returns the cached S3 client and the settings it was built from,
// initializing both on first use. Call RefreshClient after a configuration
// change to force a rebuild.
func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
	// Fast path: snapshot the cache under the read lock.
	s.mu.RLock()
	cachedClient, cachedCfg := s.client, s.cfg
	s.mu.RUnlock()

	if cachedClient != nil && cachedCfg != nil {
		return cachedClient, cachedCfg, nil
	}
	// Slow path: initClient takes the write lock and re-checks the cache.
	return s.initClient(ctx)
}
// initClient builds and caches the S3 client under the write lock.
//
// It loads the settings, validates that storage is enabled and that bucket and
// credentials are present, then builds the client with buildSoraS3Client. On
// success both the client and the settings snapshot are cached. Every failure
// leaves the cache untouched, so the next call retries.
func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Double-check: another goroutine may have initialized while we waited for the lock.
	if s.client != nil && s.cfg != nil {
		return s.client, s.cfg, nil
	}

	cfg, err := s.getConfig(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("load s3 config: %w", err)
	}
	// getConfig may legitimately yield no settings (Enabled guards this too);
	// dereferencing nil below would panic while the write lock is held.
	if cfg == nil {
		return nil, nil, fmt.Errorf("load s3 config: settings not found")
	}
	if !cfg.Enabled {
		return nil, nil, fmt.Errorf("sora s3 storage is disabled")
	}
	if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
		return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
	}

	client, region, err := buildSoraS3Client(ctx, cfg)
	if err != nil {
		return nil, nil, err
	}
	s.client = client
	s.cfg = cfg
	logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
	return client, cfg, nil
}
// RefreshClient drops the cached S3 client, its settings snapshot and the cached
// health result, so the next operation rebuilds everything from current settings.
// Call it whenever the S3 section of the system settings changes.
func (s *SoraS3Storage) RefreshClient() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.client, s.cfg = nil, nil
	s.healthCheckedAt, s.healthErr = time.Time{}, nil
	logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
}
// TestConnection checks connectivity to the configured bucket by issuing a
// HeadBucket request with the cached client (initializing it if necessary).
func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
	client, cfg, err := s.getClient(ctx)
	if err != nil {
		return err
	}
	input := &s3.HeadBucketInput{Bucket: &cfg.Bucket}
	if _, err := client.HeadBucket(ctx, input); err != nil {
		return fmt.Errorf("s3 HeadBucket failed: %w", err)
	}
	return nil
}
// IsHealthy reports whether S3 is reachable. The result of the last probe
// (TestConnection / HeadBucket) is reused for healthTTL, falling back to
// defaultSoraS3HealthTTL when healthTTL is not positive, so callers can ask
// on every request without hitting S3 each time. A nil receiver is unhealthy.
func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
	if s == nil {
		return false
	}

	s.mu.RLock()
	lastCheck, lastErr, ttl := s.healthCheckedAt, s.healthErr, s.healthTTL
	s.mu.RUnlock()

	if ttl <= 0 {
		ttl = defaultSoraS3HealthTTL
	}
	if !lastCheck.IsZero() && time.Since(lastCheck) < ttl {
		return lastErr == nil
	}

	err := s.TestConnection(ctx)
	// If the caller's context ended during the probe, the failure says nothing
	// about S3 itself. Caching it would report S3 as unhealthy to every caller
	// for a full TTL because one request was aborted.
	if err != nil && ctx.Err() != nil {
		return false
	}

	s.mu.Lock()
	s.healthCheckedAt = time.Now()
	s.healthErr = err
	s.mu.Unlock()
	return err == nil
}
// TestConnectionWithSettings validates cfg and probes its bucket with HeadBucket,
// using a throwaway client so the cached client and settings are left untouched.
// Unlike the cached path, an explicit endpoint is required here.
func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
	switch {
	case cfg == nil:
		return fmt.Errorf("s3 config is required")
	case !cfg.Enabled:
		return fmt.Errorf("sora s3 storage is disabled")
	case cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "":
		return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
	}

	probe, _, err := buildSoraS3Client(ctx, cfg)
	if err != nil {
		return err
	}
	if _, err = probe.HeadBucket(ctx, &s3.HeadBucketInput{Bucket: &cfg.Bucket}); err != nil {
		return fmt.Errorf("s3 HeadBucket failed: %w", err)
	}
	return nil
}
// GenerateObjectKey builds a unique S3 object key of the form
//
//	{prefix/}sora/{userID}/{YYYY/MM/DD}/{uuid}{ext}
//
// The date comes from time.Now() in the server's local time zone. A leading "."
// is added to ext when missing; an empty ext yields no extension (rather than a
// dangling "."). Trailing slashes on prefix are normalized to exactly one, and a
// prefix made only of slashes is ignored so keys never start with "/".
func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
	if ext != "" && !strings.HasPrefix(ext, ".") {
		ext = "." + ext
	}
	datePath := time.Now().Format("2006/01/02")
	key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
	if p := strings.TrimRight(prefix, "/"); p != "" {
		key = p + "/" + key
	}
	return key
}
// UploadFromURL downloads sourceURL and streams the body straight into S3
// through an io.Pipe, without buffering the whole file in memory.
//
// It returns the generated object key and the number of bytes copied from
// upstream. A non-200 upstream response is returned as *UpstreamDownloadError.
// The object extension is taken from the URL, then from the response
// Content-Type, and falls back to ".bin".
func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
	client, cfg, err := s.getClient(ctx)
	if err != nil {
		return "", 0, err
	}

	// Download the source file.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
	if err != nil {
		return "", 0, fmt.Errorf("create download request: %w", err)
	}
	httpClient := &http.Client{Timeout: 5 * time.Minute}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", 0, fmt.Errorf("download from upstream: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
	}

	// Infer the file extension.
	ext := fileExtFromURL(sourceURL)
	if ext == "" {
		ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
	}
	if ext == "" {
		ext = ".bin"
	}
	objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)

	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/octet-stream"
	}

	// errUploadDone closes the pipe's read half once PutObject returns. A Write
	// blocked on the pipe then fails with this exact value, which lets us tell
	// "S3 stopped reading" apart from a genuine upstream read error.
	errUploadDone := fmt.Errorf("s3 upload finished before upstream body was fully copied")

	reader, writer := io.Pipe()
	uploadErrCh := make(chan error, 1)
	go func() {
		defer close(uploadErrCh)
		input := &s3.PutObjectInput{
			Bucket:      &cfg.Bucket,
			Key:         &objectKey,
			Body:        reader,
			ContentType: &contentType,
		}
		if resp.ContentLength >= 0 {
			input.ContentLength = &resp.ContentLength
		}
		_, uploadErr := client.PutObject(ctx, input)
		// PutObject can return (e.g. on an early error response) without reading
		// the pipe to EOF. Without this close, the copy below would block forever
		// in PipeWriter.Write.
		_ = reader.CloseWithError(errUploadDone)
		uploadErrCh <- uploadErr
	}()

	written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
	// nil signals EOF to PutObject; a non-nil error aborts the upload.
	_ = writer.CloseWithError(copyErr)
	uploadErr := <-uploadErrCh

	// If the copy only failed because PutObject already gave up, the upload
	// error is the root cause, so report that instead.
	if copyErr != nil && (copyErr != errUploadDone || uploadErr == nil) {
		return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
	}
	if uploadErr != nil {
		return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
	}
	logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
	return objectKey, written, nil
}
// buildSoraS3Client creates a new S3 client from cfg and returns it together
// with the effective region ("us-east-1" when cfg.Region is empty).
// Credentials are static (access key + secret). A custom endpoint and
// path-style addressing are applied when configured.
func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
	if cfg == nil {
		return nil, "", fmt.Errorf("s3 config is required")
	}

	region := "us-east-1"
	if cfg.Region != "" {
		region = cfg.Region
	}

	staticCreds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
	awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
		awsconfig.WithRegion(region),
		awsconfig.WithCredentialsProvider(staticCreds),
	)
	if err != nil {
		return nil, "", fmt.Errorf("load aws config: %w", err)
	}

	customize := func(o *s3.Options) {
		if cfg.Endpoint != "" {
			o.BaseEndpoint = &cfg.Endpoint
		}
		if cfg.ForcePathStyle {
			o.UsePathStyle = true
		}
		// Sign with UNSIGNED-PAYLOAD rather than hashing the body up front, so the
		// non-seekable io.Pipe body streamed by UploadFromURL can be sent.
		o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
		// Only compute request checksums when an operation requires them. This
		// keeps streaming pipe uploads working against non-TLS endpoints (e.g. MinIO).
		o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
	}
	return s3.NewFromConfig(awsCfg, customize), region, nil
}
// DeleteObjects deletes each key with its own DeleteObject call. Failures are
// logged per key and do not stop the loop; the last failure (if any) is returned.
// An empty key list is a no-op and does not initialize the client.
func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
	if len(objectKeys) == 0 {
		return nil
	}
	client, cfg, err := s.getClient(ctx)
	if err != nil {
		return err
	}

	var lastErr error
	for i := range objectKeys {
		key := objectKeys[i]
		input := &s3.DeleteObjectInput{Bucket: &cfg.Bucket, Key: &key}
		if _, delErr := client.DeleteObject(ctx, input); delErr != nil {
			logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, delErr)
			lastErr = delErr
		}
	}
	return lastErr
}
// GetAccessURL returns a URL through which objectKey can be fetched.
// When a CDN base URL is configured it is used directly (trailing slashes
// trimmed); otherwise a presigned GET URL valid for 24 hours is generated.
func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
	_, cfg, err := s.getClient(ctx)
	if err != nil {
		return "", err
	}
	if cdnBase := strings.TrimRight(cfg.CDNURL, "/"); cfg.CDNURL != "" {
		return cdnBase + "/" + objectKey, nil
	}
	const presignTTL = 24 * time.Hour
	return s.GeneratePresignedURL(ctx, objectKey, presignTTL)
}
// GeneratePresignedURL returns a presigned GET URL for objectKey in the
// configured bucket that expires after ttl.
func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
	client, cfg, err := s.getClient(ctx)
	if err != nil {
		return "", err
	}
	presigned, err := s3.NewPresignClient(client).PresignGetObject(ctx,
		&s3.GetObjectInput{Bucket: &cfg.Bucket, Key: &objectKey},
		s3.WithPresignExpires(ttl),
	)
	if err != nil {
		return "", fmt.Errorf("presign url: %w", err)
	}
	return presigned.URL, nil
}
// GetMediaTypeFromKey classifies an object by the extension of its key
// (case-insensitive). It returns "video" for known video container extensions
// and "image" for anything else, including keys without an extension.
func GetMediaTypeFromKey(objectKey string) string {
	switch strings.ToLower(path.Ext(objectKey)) {
	case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
		return "video"
	}
	return "image"
}