feat(sync): full code sync from release
This commit is contained in:
875
backend/internal/service/sora_generation_service_test.go
Normal file
875
backend/internal/service/sora_generation_service_test.go
Normal file
@@ -0,0 +1,875 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ==================== Stub: SoraGenerationRepository ====================
|
||||
|
||||
var _ SoraGenerationRepository = (*stubGenRepo)(nil)
|
||||
|
||||
type stubGenRepo struct {
|
||||
gens map[int64]*SoraGeneration
|
||||
nextID int64
|
||||
createErr error
|
||||
getErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
listErr error
|
||||
countErr error
|
||||
countValue int64
|
||||
}
|
||||
|
||||
func newStubGenRepo() *stubGenRepo {
|
||||
return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.createErr != nil {
|
||||
return r.createErr
|
||||
}
|
||||
gen.ID = r.nextID
|
||||
gen.CreatedAt = time.Now()
|
||||
r.nextID++
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
if gen, ok := r.gens[id]; ok {
|
||||
return gen, nil
|
||||
}
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.gens[gen.ID] = gen
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
|
||||
if r.deleteErr != nil {
|
||||
return r.deleteErr
|
||||
}
|
||||
delete(r.gens, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
|
||||
if r.listErr != nil {
|
||||
return nil, 0, r.listErr
|
||||
}
|
||||
var result []*SoraGeneration
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID != params.UserID {
|
||||
continue
|
||||
}
|
||||
if params.Status != "" && gen.Status != params.Status {
|
||||
continue
|
||||
}
|
||||
if params.StorageType != "" && gen.StorageType != params.StorageType {
|
||||
continue
|
||||
}
|
||||
if params.MediaType != "" && gen.MediaType != params.MediaType {
|
||||
continue
|
||||
}
|
||||
result = append(result, gen)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if r.countErr != nil {
|
||||
return 0, r.countErr
|
||||
}
|
||||
if r.countValue > 0 {
|
||||
return r.countValue, nil
|
||||
}
|
||||
var count int64
|
||||
statusSet := make(map[string]struct{})
|
||||
for _, s := range statuses {
|
||||
statusSet[s] = struct{}{}
|
||||
}
|
||||
for _, gen := range r.gens {
|
||||
if gen.UserID == userID {
|
||||
if _, ok := statusSet[gen.Status]; ok {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
|
||||
|
||||
var _ UserRepository = (*stubUserRepoForQuota)(nil)
|
||||
|
||||
type stubUserRepoForQuota struct {
|
||||
users map[int64]*User
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func newStubUserRepoForQuota() *stubUserRepoForQuota {
|
||||
return &stubUserRepoForQuota{users: make(map[int64]*User)}
|
||||
}
|
||||
|
||||
func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
|
||||
if u, ok := r.users[id]; ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
|
||||
func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
|
||||
func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
|
||||
|
||||
// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
|
||||
// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
|
||||
func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
|
||||
storage := &SoraS3Storage{}
|
||||
storage.cfg = &SoraS3Settings{
|
||||
Enabled: true,
|
||||
Bucket: "test-bucket",
|
||||
CDNURL: cdnURL,
|
||||
}
|
||||
// 需要 non-nil client 使 getClient 命中缓存
|
||||
storage.client = s3.New(s3.Options{})
|
||||
return storage
|
||||
}
|
||||
|
||||
// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
|
||||
// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
|
||||
func newS3StorageFailingDelete() *SoraS3Storage {
|
||||
return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
|
||||
}
|
||||
|
||||
// ==================== CreatePending ====================
|
||||
|
||||
func TestCreatePending_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, int64(1), gen.UserID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
require.Equal(t, "一只猫跳舞", gen.Prompt)
|
||||
require.Equal(t, "video", gen.MediaType)
|
||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
||||
require.Equal(t, SoraStorageTypeNone, gen.StorageType)
|
||||
require.Nil(t, gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_WithAPIKeyID(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
apiKeyID := int64(42)
|
||||
gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen.APIKeyID)
|
||||
require.Equal(t, int64(42), *gen.APIKeyID)
|
||||
}
|
||||
|
||||
func TestCreatePending_RepoError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.createErr = fmt.Errorf("db write error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "create generation")
|
||||
}
|
||||
|
||||
// ==================== MarkGenerating ====================
|
||||
|
||||
func TestMarkGenerating_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
|
||||
require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 999, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkGenerating_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkGenerating(context.Background(), 1, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCompleted ====================
|
||||
|
||||
func TestMarkCompleted_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1,
|
||||
"https://cdn.example.com/video.mp4",
|
||||
[]string{"https://cdn.example.com/video.mp4"},
|
||||
SoraStorageTypeS3,
|
||||
[]string{"sora/1/2024/01/01/uuid.mp4"},
|
||||
1048576,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusCompleted, gen.Status)
|
||||
require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
|
||||
require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
|
||||
require.Equal(t, SoraStorageTypeS3, gen.StorageType)
|
||||
require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
|
||||
require.Equal(t, int64(1048576), gen.FileSizeBytes)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCompleted_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkFailed ====================
|
||||
|
||||
func TestMarkFailed_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
|
||||
require.NoError(t, err)
|
||||
gen := repo.gens[1]
|
||||
require.Equal(t, SoraGenStatusFailed, gen.Status)
|
||||
require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
|
||||
require.NotNil(t, gen.CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkFailed_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 999, "error")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkFailed_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkFailed(context.Background(), 1, "err")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== MarkCancelled ====================
|
||||
|
||||
func TestMarkCancelled_Pending(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
require.NotNil(t, repo.gens[1].CompletedAt)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Generating(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Completed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrSoraGenerationNotActive)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_Failed(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMarkCancelled_UpdateError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.updateErr = fmt.Errorf("update failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.MarkCancelled(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== GetByID ====================
|
||||
|
||||
func TestGetByID_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), gen.ID)
|
||||
require.Equal(t, "sora2-landscape-10s", gen.Model)
|
||||
}
|
||||
|
||||
func TestGetByID_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
require.Contains(t, err.Error(), "无权访问")
|
||||
}
|
||||
|
||||
func TestGetByID_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, err := svc.GetByID(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gen)
|
||||
}
|
||||
|
||||
// ==================== List ====================
|
||||
|
||||
func TestList_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
|
||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gens, 2) // 只有 userID=1 的
|
||||
require.Equal(t, int64(2), total)
|
||||
}
|
||||
|
||||
func TestList_DefaultPagination(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// page=0, pageSize=0 → 应修正为 page=1, pageSize=20
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_MaxPageSize(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// pageSize > 100 → 应限制为 100
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestList_Error(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.listErr = fmt.Errorf("db error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
_, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== Delete ====================
|
||||
|
||||
func TestDelete_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDelete_WrongUser(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权删除")
|
||||
}
|
||||
|
||||
func TestDelete_NotFound(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDelete_S3Cleanup_NilS3(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // s3Storage 为 nil,跳过清理
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // quotaService 为 nil,跳过释放
|
||||
}
|
||||
|
||||
func TestDelete_NonS3NoCleanup(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDelete_DeleteRepoError(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
|
||||
repo.deleteErr = fmt.Errorf("delete failed")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== CountActiveByUser ====================
|
||||
|
||||
func TestCountActiveByUser_Success(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
|
||||
repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
|
||||
repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestCountActiveByUser_NoActive(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
count, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestCountActiveByUser_Error(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
repo.countErr = fmt.Errorf("db error")
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
_, err := svc.CountActiveByUser(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// ==================== ResolveMediaURLs ====================
|
||||
|
||||
func TestResolveMediaURLs_NilGen(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_NonS3(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_Local(t *testing.T) {
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
|
||||
gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
|
||||
require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
|
||||
require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
// ==================== 状态流转完整测试 ====================
|
||||
|
||||
func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
// 1. 创建 pending
|
||||
gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusPending, gen.Status)
|
||||
|
||||
// 2. 标记 generating
|
||||
err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
|
||||
|
||||
// 3. 标记 completed
|
||||
err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
||||
|
||||
err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
|
||||
require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
|
||||
}
|
||||
|
||||
func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
_ = svc.MarkGenerating(context.Background(), gen.ID, "")
|
||||
err := svc.MarkCancelled(context.Background(), gen.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
|
||||
}
|
||||
|
||||
// ==================== 权限隔离测试 ====================
|
||||
|
||||
func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
|
||||
// 用户 2 尝试访问用户 1 的记录
|
||||
_, err := svc.GetByID(context.Background(), gen.ID, 2)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权访问")
|
||||
}
|
||||
|
||||
func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
|
||||
repo := newStubGenRepo()
|
||||
svc := NewSoraGenerationService(repo, nil, nil)
|
||||
|
||||
gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
|
||||
|
||||
err := svc.Delete(context.Background(), gen.ID, 2)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "无权删除")
|
||||
}
|
||||
|
||||
// ==================== Delete: S3 清理 + 配额释放路径 ====================
|
||||
|
||||
func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
|
||||
// S3 存储存在但 deleteObjects 会失败(settingService=nil),
|
||||
// 验证 Delete 仍然成功(S3 错误只是记录日志)
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
|
||||
}
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
svc := NewSoraGenerationService(repo, s3Storage, nil)
|
||||
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err) // S3 清理失败不影响删除
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
|
||||
// 有配额服务时,删除 S3 类型记录会释放配额
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
FileSizeBytes: 1048576, // 1MB
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
// 配额应被释放: 2MB - 1MB = 1MB
|
||||
require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
|
||||
// S3 清理 + 配额释放同时触发
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"key1"},
|
||||
FileSizeBytes: 512,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
|
||||
svc := NewSoraGenerationService(repo, s3Storage, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
_, exists := repo.gens[1]
|
||||
require.False(t, exists)
|
||||
require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
|
||||
// 本地存储同样需要释放配额
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeLocal,
|
||||
FileSizeBytes: 1024,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
|
||||
// FileSizeBytes=0 跳过配额释放
|
||||
repo := newStubGenRepo()
|
||||
repo.gens[1] = &SoraGeneration{
|
||||
ID: 1, UserID: 1,
|
||||
StorageType: SoraStorageTypeS3,
|
||||
FileSizeBytes: 0,
|
||||
}
|
||||
|
||||
userRepo := newStubUserRepoForQuota()
|
||||
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
|
||||
quotaService := NewSoraQuotaService(userRepo, nil, nil)
|
||||
|
||||
svc := NewSoraGenerationService(repo, nil, quotaService)
|
||||
err := svc.Delete(context.Background(), 1, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
|
||||
}
|
||||
|
||||
// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
|
||||
|
||||
func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{
|
||||
"sora/1/2024/01/01/img1.png",
|
||||
"sora/1/2024/01/01/img2.png",
|
||||
"sora/1/2024/01/01/img3.png",
|
||||
},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
// 主 URL 更新为第一个 key 的 CDN URL
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
|
||||
// 多图 URLs 全部更新
|
||||
require.Len(t, gen.MediaURLs, 3)
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
|
||||
require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
|
||||
s3Storage := newS3StorageWithCDN("https://cdn.example.com")
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original", gen.MediaURL) // 不变
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
|
||||
// 使用无 settingService 的 S3 Storage,getClient 会失败
|
||||
s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.Error(t, err) // GetAccessURL 失败应传播错误
|
||||
}
|
||||
|
||||
func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
|
||||
// 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
|
||||
s3Storage := newS3StorageFailingDelete()
|
||||
svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
|
||||
|
||||
gen := &SoraGeneration{
|
||||
StorageType: SoraStorageTypeS3,
|
||||
S3ObjectKeys: []string{
|
||||
"sora/1/2024/01/01/img1.png",
|
||||
"sora/1/2024/01/01/img2.png",
|
||||
},
|
||||
MediaURL: "original",
|
||||
}
|
||||
err := svc.ResolveMediaURLs(context.Background(), gen)
|
||||
require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
|
||||
}
|
||||
Reference in New Issue
Block a user