package service import ( "context" "fmt" "io" "net/http" "path" "strings" "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/uuid" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。 // 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。 type SoraS3Storage struct { settingService *SettingService mu sync.RWMutex client *s3.Client cfg *SoraS3Settings // 上次加载的配置快照 healthCheckedAt time.Time healthErr error healthTTL time.Duration } const defaultSoraS3HealthTTL = 30 * time.Second // UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。 type UpstreamDownloadError struct { StatusCode int } func (e *UpstreamDownloadError) Error() string { if e == nil { return "upstream download failed" } return fmt.Sprintf("upstream returned %d", e.StatusCode) } // NewSoraS3Storage 创建 S3 存储服务实例。 func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage { return &SoraS3Storage{ settingService: settingService, healthTTL: defaultSoraS3HealthTTL, } } // Enabled 返回 S3 存储是否已启用且配置有效。 func (s *SoraS3Storage) Enabled(ctx context.Context) bool { cfg, err := s.getConfig(ctx) if err != nil || cfg == nil { return false } return cfg.Enabled && cfg.Bucket != "" } // getConfig 获取当前 S3 配置(从 settings 表读取)。 func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) { if s.settingService == nil { return nil, fmt.Errorf("setting service not available") } return s.settingService.GetSoraS3Settings(ctx) } // getClient 获取或初始化 S3 客户端(带缓存)。 // 配置变更时调用 RefreshClient 清除缓存。 func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { s.mu.RLock() if s.client != nil && s.cfg != nil { client, cfg := s.client, s.cfg s.mu.RUnlock() return client, cfg, nil } s.mu.RUnlock() return s.initClient(ctx) } func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { s.mu.Lock() defer s.mu.Unlock() // 双重检查 if s.client != nil && s.cfg != nil { return s.client, s.cfg, nil } cfg, err := s.getConfig(ctx) if err != nil { return nil, nil, fmt.Errorf("load s3 config: %w", err) } if !cfg.Enabled { return nil, nil, fmt.Errorf("sora s3 storage is disabled") } if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required") } client, region, err := buildSoraS3Client(ctx, cfg) if err != nil { return nil, nil, err } s.client = client s.cfg = cfg logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region) return client, cfg, nil } // RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。 // 应在系统设置中 S3 配置变更时调用。 func (s *SoraS3Storage) RefreshClient() { s.mu.Lock() defer s.mu.Unlock() s.client = nil s.cfg = nil s.healthCheckedAt = time.Time{} s.healthErr = nil logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化") } // TestConnection 测试 S3 连接(HeadBucket)。 func (s *SoraS3Storage) TestConnection(ctx context.Context) error { client, cfg, err := s.getClient(ctx) if err != nil { return err } _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ Bucket: &cfg.Bucket, }) if err != nil { return fmt.Errorf("s3 HeadBucket failed: %w", err) } return nil } // IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。 func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool { if s == nil { return false } now := time.Now() s.mu.RLock() lastCheck := s.healthCheckedAt lastErr := s.healthErr ttl := s.healthTTL s.mu.RUnlock() if ttl <= 0 { ttl = defaultSoraS3HealthTTL } if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl { return lastErr == nil } err := s.TestConnection(ctx) s.mu.Lock() s.healthCheckedAt = time.Now() s.healthErr = err s.mu.Unlock() return err == nil } // TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。 func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error { if cfg == nil { return fmt.Errorf("s3 config is required") } if !cfg.Enabled { return fmt.Errorf("sora s3 storage is disabled") } if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required") } client, _, err := buildSoraS3Client(ctx, cfg) if err != nil { return err } _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ Bucket: &cfg.Bucket, }) if err != nil { return fmt.Errorf("s3 HeadBucket failed: %w", err) } return nil } // GenerateObjectKey 生成 S3 object key。 // 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext} func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string { if !strings.HasPrefix(ext, ".") { ext = "." + ext } datePath := time.Now().Format("2006/01/02") key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext) if prefix != "" { prefix = strings.TrimRight(prefix, "/") + "/" key = prefix + key } return key } // UploadFromURL 从上游 URL 下载并流式上传到 S3。 // 返回 S3 object key。 func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) { client, cfg, err := s.getClient(ctx) if err != nil { return "", 0, err } // 下载源文件 req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil) if err != nil { return "", 0, fmt.Errorf("create download request: %w", err) } httpClient := &http.Client{Timeout: 5 * time.Minute} resp, err := httpClient.Do(req) if err != nil { return "", 0, fmt.Errorf("download from upstream: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode} } // 推断文件扩展名 ext := fileExtFromURL(sourceURL) if ext == "" { ext = fileExtFromContentType(resp.Header.Get("Content-Type")) } if ext == "" { ext = ".bin" } objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext) // 检测 Content-Type contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/octet-stream" } reader, writer := io.Pipe() uploadErrCh := make(chan error, 1) go func() { defer close(uploadErrCh) input := &s3.PutObjectInput{ Bucket: &cfg.Bucket, Key: &objectKey, Body: reader, ContentType: &contentType, } if resp.ContentLength >= 0 { input.ContentLength = &resp.ContentLength } _, uploadErr := client.PutObject(ctx, input) uploadErrCh <- uploadErr }() written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024)) _ = writer.CloseWithError(copyErr) uploadErr := <-uploadErrCh if copyErr != nil { return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr) } if uploadErr != nil { return "", 0, fmt.Errorf("s3 upload: %w", uploadErr) } logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written) return objectKey, written, nil } func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) { if cfg == nil { return nil, "", fmt.Errorf("s3 config is required") } region := cfg.Region if region == "" { region = "us-east-1" } awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region), awsconfig.WithCredentialsProvider( credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), ), ) if err != nil { return nil, "", fmt.Errorf("load aws config: %w", err) } client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { if cfg.Endpoint != "" { o.BaseEndpoint = &cfg.Endpoint } if cfg.ForcePathStyle { o.UsePathStyle = true } o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) // 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败 o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired }) return client, region, nil } // DeleteObjects 删除一组 S3 object(遍历逐一删除)。 func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error { if len(objectKeys) == 0 { return nil } client, cfg, err := s.getClient(ctx) if err != nil { return err } var lastErr error for _, key := range objectKeys { k := key _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{ Bucket: &cfg.Bucket, Key: &k, }) if err != nil { logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err) lastErr = err } } return lastErr } // GetAccessURL 获取 S3 文件的访问 URL。 // CDN URL 优先,否则生成 24h 预签名 URL。 func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) { _, cfg, err := s.getClient(ctx) if err != nil { return "", err } // CDN URL 优先 if cfg.CDNURL != "" { cdnBase := strings.TrimRight(cfg.CDNURL, "/") return cdnBase + "/" + objectKey, nil } // 生成 24h 预签名 URL return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour) } // GeneratePresignedURL 生成预签名 URL。 func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) { client, cfg, err := s.getClient(ctx) if err != nil { return "", err } presignClient := s3.NewPresignClient(client) result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ Bucket: &cfg.Bucket, Key: &objectKey, }, s3.WithPresignExpires(ttl)) if err != nil { return "", fmt.Errorf("presign url: %w", err) } return result.URL, nil } // GetMediaType 从 object key 推断媒体类型(image/video)。 func GetMediaTypeFromKey(objectKey string) string { ext := strings.ToLower(path.Ext(objectKey)) switch ext { case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv": return "video" default: return "image" } }