//go:build unit package handler import ( "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "os" "strings" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) func init() { gin.SetMode(gin.TestMode) } // ==================== Stub: SoraGenerationRepository ==================== var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) type stubSoraGenRepo struct { gens map[int64]*service.SoraGeneration nextID int64 createErr error getErr error updateErr error deleteErr error listErr error countErr error countValue int64 // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 updateCallCount *int32 updateFailAfterN int32 // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus getByIDCallCount int32 getByIDOverrideAfterN int32 // 0 = 不覆盖 getByIDOverrideStatus string } func newStubSoraGenRepo() *stubSoraGenRepo { return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} } func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { if r.createErr != nil { return r.createErr } gen.ID = r.nextID r.nextID++ r.gens[gen.ID] = gen return nil } func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { if r.getErr != nil { return nil, r.getErr } gen, ok := r.gens[id] if !ok { return nil, fmt.Errorf("not found") } // 条件性状态覆盖:模拟外部取消等场景 if r.getByIDOverrideAfterN > 0 { n := atomic.AddInt32(&r.getByIDCallCount, 1) if n > r.getByIDOverrideAfterN { cp := *gen cp.Status = r.getByIDOverrideStatus return &cp, nil } } return gen, nil } func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { // 条件性失败:前 N 次成功,之后失败 if r.updateCallCount != nil { n := atomic.AddInt32(r.updateCallCount, 
1) if n > r.updateFailAfterN { return fmt.Errorf("conditional update error (call #%d)", n) } } if r.updateErr != nil { return r.updateErr } r.gens[gen.ID] = gen return nil } func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { if r.deleteErr != nil { return r.deleteErr } delete(r.gens, id) return nil } func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { if r.listErr != nil { return nil, 0, r.listErr } var result []*service.SoraGeneration for _, gen := range r.gens { if gen.UserID != params.UserID { continue } result = append(result, gen) } return result, int64(len(result)), nil } func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { if r.countErr != nil { return 0, r.countErr } return r.countValue, nil } // ==================== 辅助函数 ==================== func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { genService := service.NewSoraGenerationService(repo, nil, nil) return &SoraClientHandler{genService: genService} } func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) if body != "" { c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) c.Request.Header.Set("Content-Type", "application/json") } else { c.Request = httptest.NewRequest(method, path, nil) } if userID > 0 { c.Set("user_id", userID) } return c, rec } func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { t.Helper() var resp map[string]any require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) return resp } // ==================== 纯函数测试: buildAsyncRequestBody ==================== func TestBuildAsyncRequestBody(t *testing.T) { body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) var parsed map[string]any require.NoError(t, json.Unmarshal(body, &parsed)) 
require.Equal(t, "sora2-landscape-10s", parsed["model"]) require.Equal(t, false, parsed["stream"]) msgs := parsed["messages"].([]any) require.Len(t, msgs, 1) msg := msgs[0].(map[string]any) require.Equal(t, "user", msg["role"]) require.Equal(t, "一只猫在跳舞", msg["content"]) } func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { body := buildAsyncRequestBody("gpt-image", "", "", 1) var parsed map[string]any require.NoError(t, json.Unmarshal(body, &parsed)) require.Equal(t, "gpt-image", parsed["model"]) msgs := parsed["messages"].([]any) msg := msgs[0].(map[string]any) require.Equal(t, "", msg["content"]) } func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) var parsed map[string]any require.NoError(t, json.Unmarshal(body, &parsed)) require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) } func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) var parsed map[string]any require.NoError(t, json.Unmarshal(body, &parsed)) require.Equal(t, float64(3), parsed["video_count"]) } func TestNormalizeVideoCount(t *testing.T) { require.Equal(t, 1, normalizeVideoCount("video", 0)) require.Equal(t, 2, normalizeVideoCount("video", 2)) require.Equal(t, 3, normalizeVideoCount("video", 5)) require.Equal(t, 1, normalizeVideoCount("image", 3)) } // ==================== 纯函数测试: parseMediaURLsFromBody ==================== func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) } func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) require.Equal(t, []string{"https://a.com/video.mp4"}, urls) } func TestParseMediaURLsFromBody_EmptyBody(t 
*testing.T) { require.Nil(t, parseMediaURLsFromBody(nil)) require.Nil(t, parseMediaURLsFromBody([]byte{})) } func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) } func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) } func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) } func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) } func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` urls := parseMediaURLsFromBody([]byte(body)) require.Len(t, urls, 2) require.Equal(t, "https://multi.com/a.mp4", urls[0]) } func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) } func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) } func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { // media_urls 不是 string 数组 require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) } func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) } // ==================== 纯函数测试: extractMediaURLsFromResult ==================== func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} recorder := httptest.NewRecorder() url, urls := extractMediaURLsFromResult(result, recorder) require.Equal(t, "https://oauth.com/video.mp4", url) require.Equal(t, 
[]string{"https://oauth.com/video.mp4"}, urls) } func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} recorder := httptest.NewRecorder() _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) url, urls := extractMediaURLsFromResult(result, recorder) require.Equal(t, "https://body.com/1.mp4", url) require.Len(t, urls, 2) } func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { recorder := httptest.NewRecorder() _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) url, urls := extractMediaURLsFromResult(nil, recorder) require.Equal(t, "https://upstream.com/video.mp4", url) require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) } func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { recorder := httptest.NewRecorder() url, urls := extractMediaURLsFromResult(nil, recorder) require.Empty(t, url) require.Nil(t, urls) } func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { result := &service.ForwardResult{MediaURL: ""} recorder := httptest.NewRecorder() url, urls := extractMediaURLsFromResult(result, recorder) require.Empty(t, url) require.Nil(t, urls) } // ==================== getUserIDFromContext ==================== func TestGetUserIDFromContext_Int64(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) c.Set("user_id", int64(42)) require.Equal(t, int64(42), getUserIDFromContext(c)) } func TestGetUserIDFromContext_AuthSubject(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) require.Equal(t, int64(777), getUserIDFromContext(c)) } func TestGetUserIDFromContext_Float64(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = 
httptest.NewRequest("GET", "/", nil) c.Set("user_id", float64(99)) require.Equal(t, int64(99), getUserIDFromContext(c)) } func TestGetUserIDFromContext_String(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) c.Set("user_id", "123") require.Equal(t, int64(123), getUserIDFromContext(c)) } func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) c.Set("userID", int64(55)) require.Equal(t, int64(55), getUserIDFromContext(c)) } func TestGetUserIDFromContext_NoID(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) require.Equal(t, int64(0), getUserIDFromContext(c)) } func TestGetUserIDFromContext_InvalidString(t *testing.T) { c, _ := gin.CreateTestContext(httptest.NewRecorder()) c.Request = httptest.NewRequest("GET", "/", nil) c.Set("user_id", "not-a-number") require.Equal(t, int64(0), getUserIDFromContext(c)) } // ==================== Handler: Generate ==================== func TestGenerate_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) h.Generate(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestGenerate_BadRequest_MissingModel(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) h.Generate(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { h := 
newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) h.Generate(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestGenerate_TooManyRequests(t *testing.T) { repo := newStubSoraGenRepo() repo.countValue = 3 h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusTooManyRequests, rec.Code) } func TestGenerate_CountError(t *testing.T) { repo := newStubSoraGenRepo() repo.countErr = fmt.Errorf("db error") h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusInternalServerError, rec.Code) } func TestGenerate_Success(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.NotZero(t, data["generation_id"]) require.Equal(t, "pending", data["status"]) } func TestGenerate_DefaultMediaType(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "video", repo.gens[1].MediaType) } func TestGenerate_ImageMediaType(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "image", repo.gens[1].MediaType) } func TestGenerate_CreatePendingError(t 
*testing.T) { repo := newStubSoraGenRepo() repo.createErr = fmt.Errorf("create failed") h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusInternalServerError, rec.Code) } func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) } func TestGenerate_APIKeyInContext(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) c.Set("api_key_id", int64(42)) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) require.NotNil(t, repo.gens[1].APIKeyID) require.Equal(t, int64(42), *repo.gens[1].APIKeyID) } func TestGenerate_NoAPIKeyInContext(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) require.Nil(t, repo.gens[1].APIKeyID) } func TestGenerate_ConcurrencyBoundary(t *testing.T) { // activeCount == 2 应该允许 repo := newStubSoraGenRepo() repo.countValue = 2 h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) h.Generate(c) require.Equal(t, http.StatusOK, rec.Code) } // ==================== Handler: ListGenerations ==================== func TestListGenerations_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) h.ListGenerations(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func 
TestListGenerations_Success(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} repo.nextID = 3 h := newTestSoraClientHandler(repo) c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) h.ListGenerations(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) items := data["data"].([]any) require.Len(t, items, 2) require.Equal(t, float64(2), data["total"]) } func TestListGenerations_ListError(t *testing.T) { repo := newStubSoraGenRepo() repo.listErr = fmt.Errorf("db error") h := newTestSoraClientHandler(repo) c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) h.ListGenerations(c) require.Equal(t, http.StatusInternalServerError, rec.Code) } func TestListGenerations_DefaultPagination(t *testing.T) { repo := newStubSoraGenRepo() h := newTestSoraClientHandler(repo) // 不传分页参数,应默认 page=1 page_size=20 c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) h.ListGenerations(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.Equal(t, float64(1), data["page"]) } // ==================== Handler: GetGeneration ==================== func TestGetGeneration_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) c.Params = gin.Params{{Key: "id", Value: "1"}} h.GetGeneration(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestGetGeneration_InvalidID(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) c.Params = gin.Params{{Key: "id", Value: "abc"}} h.GetGeneration(c) 
require.Equal(t, http.StatusBadRequest, rec.Code) } func TestGetGeneration_NotFound(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) c.Params = gin.Params{{Key: "id", Value: "999"}} h.GetGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestGetGeneration_WrongUser(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.GetGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestGetGeneration_Success(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.GetGeneration(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.Equal(t, float64(1), data["id"]) } // ==================== Handler: DeleteGeneration ==================== func TestDeleteGeneration_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestDeleteGeneration_InvalidID(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) c.Params = gin.Params{{Key: "id", Value: "abc"}} h.DeleteGeneration(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func 
TestDeleteGeneration_NotFound(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) c.Params = gin.Params{{Key: "id", Value: "999"}} h.DeleteGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestDeleteGeneration_WrongUser(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestDeleteGeneration_Success(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusOK, rec.Code) _, exists := repo.gens[1] require.False(t, exists) } // ==================== Handler: CancelGeneration ==================== func TestCancelGeneration_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestCancelGeneration_InvalidID(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "abc"}} h.CancelGeneration(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestCancelGeneration_NotFound(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) c.Params = 
gin.Params{{Key: "id", Value: "999"}} h.CancelGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestCancelGeneration_WrongUser(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } func TestCancelGeneration_Pending(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "cancelled", repo.gens[1].Status) } func TestCancelGeneration_Generating(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "cancelled", repo.gens[1].Status) } func TestCancelGeneration_Completed(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusConflict, rec.Code) } func TestCancelGeneration_Failed(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", 
"/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusConflict, rec.Code) } func TestCancelGeneration_Cancelled(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.CancelGeneration(c) require.Equal(t, http.StatusConflict, rec.Code) } // ==================== Handler: GetQuota ==================== func TestGetQuota_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) h.GetQuota(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestGetQuota_NilQuotaService(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) h.GetQuota(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.Equal(t, "unlimited", data["source"]) } // ==================== Handler: GetModels ==================== func TestGetModels(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) h.GetModels(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].([]any) require.Len(t, data, 4) // 验证类型分布 videoCount, imageCount := 0, 0 for _, item := range data { m := item.(map[string]any) if m["type"] == "video" { videoCount++ } else if m["type"] == "image" { imageCount++ } } require.Equal(t, 3, videoCount) require.Equal(t, 1, imageCount) } // ==================== Handler: GetStorageStatus ==================== func TestGetStorageStatus_NilS3(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("GET", 
"/api/v1/sora/storage-status", "", 0) h.GetStorageStatus(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.Equal(t, false, data["s3_enabled"]) require.Equal(t, false, data["s3_healthy"]) require.Equal(t, false, data["local_enabled"]) } func TestGetStorageStatus_LocalEnabled(t *testing.T) { tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") require.NoError(t, err) defer os.RemoveAll(tmpDir) cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: tmpDir, }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) h := &SoraClientHandler{mediaStorage: mediaStorage} c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) h.GetStorageStatus(c) require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) require.Equal(t, false, data["s3_enabled"]) require.Equal(t, false, data["s3_healthy"]) require.Equal(t, true, data["local_enabled"]) } // ==================== Handler: SaveToStorage ==================== func TestSaveToStorage_Unauthorized(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusUnauthorized, rec.Code) } func TestSaveToStorage_InvalidID(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "abc"}} h.SaveToStorage(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestSaveToStorage_NotFound(t *testing.T) { h := newTestSoraClientHandler(newStubSoraGenRepo()) c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "999"}} h.SaveToStorage(c) require.Equal(t, http.StatusNotFound, rec.Code) } func 
TestSaveToStorage_NotUpstream(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestSaveToStorage_EmptyMediaURL(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusBadRequest, rec.Code) } func TestSaveToStorage_S3Nil(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusServiceUnavailable, rec.Code) resp := parseResponse(t, rec) require.Contains(t, fmt.Sprint(resp["message"]), "云存储") } func TestSaveToStorage_WrongUser(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} h := newTestSoraClientHandler(repo) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusNotFound, rec.Code) } // ==================== storeMediaWithDegradation — nil guard 路径 ==================== func TestStoreMediaWithDegradation_NilS3NilMedia(t 
*testing.T) { h := &SoraClientHandler{} url, urls, storageType, keys, size := h.storeMediaWithDegradation( context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, ) require.Equal(t, service.SoraStorageTypeUpstream, storageType) require.Equal(t, "https://upstream.com/v.mp4", url) require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) require.Nil(t, keys) require.Equal(t, int64(0), size) } func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { h := &SoraClientHandler{} url, urls, storageType, keys, size := h.storeMediaWithDegradation( context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, ) require.Equal(t, service.SoraStorageTypeUpstream, storageType) require.Equal(t, "https://a.com/1.mp4", url) require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) require.Nil(t, keys) require.Equal(t, int64(0), size) } func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { h := &SoraClientHandler{} url, _, storageType, _, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, ) require.Equal(t, service.SoraStorageTypeUpstream, storageType) require.Equal(t, "https://upstream.com/v.mp4", url) } // ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== var _ service.UserRepository = (*stubUserRepoForHandler)(nil) type stubUserRepoForHandler struct { users map[int64]*service.User updateErr error } func newStubUserRepoForHandler() *stubUserRepoForHandler { return &stubUserRepoForHandler{users: make(map[int64]*service.User)} } func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { if u, ok := r.users[id]; ok { return u, nil } return nil, fmt.Errorf("user not found") } func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { if r.updateErr != nil { return r.updateErr } r.users[user.ID] = user 
return nil } func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { return nil, nil } func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { return nil, nil } func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } // ==================== NewSoraClientHandler ==================== func TestNewSoraClientHandler(t *testing.T) { h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) require.NotNil(t, h) } func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) require.NotNil(t, h) require.Nil(t, h.apiKeyService) } // ==================== Stub: 
APIKeyRepository (for API Key validation tests) ====================

var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)

// stubAPIKeyRepoForHandler is an in-memory APIKeyRepository stub. Only GetByID
// is backed by data (keys / getErr); every other method returns zero values.
type stubAPIKeyRepoForHandler struct {
	keys   map[int64]*service.APIKey
	getErr error
}

// newStubAPIKeyRepoForHandler returns a stub with an empty, non-nil key map.
func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
	return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
}

// GetByID returns getErr when set, otherwise the stored key or a not-found error.
func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
	if r.getErr != nil {
		return nil, r.getErr
	}
	if k, ok := r.keys[id]; ok {
		return k, nil
	}
	return nil, fmt.Errorf("api key not found: %d", id)
}

func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error {
	return nil
}

func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
	return "", 0, nil
}

func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error {
	return nil
}

func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error {
	return nil
}

func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
	return 0, nil
}

func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
	return false, nil
}

func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
	return nil, nil, nil
}

func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
	return 0, nil
}

func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
	return 0, nil
}

func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
	return nil, nil
}

func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
	return 0, nil
}

func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
	return nil
}

// newTestAPIKeyService builds an APIKeyService for tests, backed by the given
// stub repository; all other dependencies are nil and config is empty.
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
	return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
}

// ==================== Generate: API Key validation (frontend passes api_key_id) ====================

func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
	// Frontend passes api_key_id and validation passes → generation succeeds and
	// the generation record is linked to that api_key_id.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	groupID := int64(5)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:      42,
		UserID:  1,
		Status:  service.StatusAPIKeyActive,
		GroupID: &groupID,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	// Two-value assertion so a malformed response fails cleanly instead of panicking.
	data, ok := resp["data"].(map[string]any)
	require.True(t, ok, "response data should be a JSON object")
	require.NotZero(t, data["generation_id"])
	// The api_key_id must be recorded on the generation.
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.NotNil(t, gen.APIKeyID)
	require.Equal(t, int64(42), *gen.APIKeyID)
}

func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
	// Frontend passes a nonexistent api_key_id → 400.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusBadRequest, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
}

func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
	// Frontend passes an api_key_id owned by another user → 403.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:     42,
		UserID: 999, // owned by user 999
		Status: service.StatusAPIKeyActive,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusForbidden, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
}

func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
	// Frontend passes a disabled api_key_id → 403.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:     42,
		UserID: 1,
		Status: service.StatusAPIKeyDisabled,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusForbidden, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
}

func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
	// Frontend passes a quota-exhausted api_key_id → 403.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:     42,
		UserID: 1,
		Status: service.StatusAPIKeyQuotaExhausted,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusForbidden, rec.Code)
}

func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
	// Frontend passes an expired api_key_id → 403.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:     42,
		UserID: 1,
		Status: service.StatusAPIKeyExpired,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusForbidden, rec.Code)
}

func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
	// With a nil apiKeyService the api_key_id is ignored → generation succeeds
	// but no api_key_id is recorded.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService} // apiKeyService = nil
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	// apiKeyService is nil → validation skipped → api_key_id not recorded.
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.Nil(t, gen.APIKeyID)
}

func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
	// Valid api_key whose GroupID is nil → generation still succeeds.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:      42,
		UserID:  1,
		Status:  service.StatusAPIKeyActive,
		GroupID: nil, // no group
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.NotNil(t, gen.APIKeyID)
	require.Equal(t, int64(42), *gen.APIKeyID)
}

func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
	// Neither a body api_key_id nor a context api_key_id → api_key_id stays nil.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.Nil(t, gen.APIKeyID)
}

func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
	// Both a body api_key_id and a context api_key_id are present → the body value wins.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	groupID := int64(10)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyRepo.keys[42] = &service.APIKey{
		ID:      42,
		UserID:  1,
		Status:  service.StatusAPIKeyActive,
		GroupID: &groupID,
	}
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
	c.Set("api_key_id", int64(99)) // the context carries a different api_key_id
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	// Must use api_key_id=42 from the body, not 99 from the context.
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.NotNil(t, gen.APIKeyID)
	require.Equal(t, int64(42), *gen.APIKeyID)
}

func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
	// No body api_key_id but the context has one → the context value is used
	// (compatibility with gateway routing).
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
	c.Set("api_key_id", int64(99))
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
	// Must use api_key_id=99 from the context.
	gen := repo.gens[1]
	require.NotNil(t, gen, "generation record 1 should have been created")
	require.NotNil(t, gen.APIKeyID)
	require.Equal(t, int64(99), *gen.APIKeyID)
}

func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
	// Despite the test name, api_key_id=0 is NOT ignored: `omitempty` only affects
	// marshaling, so decoding `"api_key_id":0` yields a non-nil *int64 pointing at 0.
	// That triggers validation, and no key with ID 0 exists → 400.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyService := newTestAPIKeyService(apiKeyRepo)
	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusBadRequest, rec.Code)
}

// ==================== processGeneration: groupID propagation and ForcePlatform ====================

func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
	// Non-nil groupID → ForcePlatform is not set. gatewayService is nil, so the
	// task is marked failed; the test checks the failure is attributed to the
	// missing gatewayService.
	// NOTE(review): this asserts the same outcome as the nil-groupID case below, so
	// on its own it does not prove whether ForcePlatform was set.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}
	gid := int64(5)
	h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}

func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
	// nil groupID → ForcePlatform is set; gatewayService is nil → marked failed.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}
	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}

func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
	// The task state already changed (e.g. cancelled) → MarkGenerating reports a
	// state conflict and processing is skipped.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}
	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	// "cancelled" is not a valid source state for MarkGenerating → status must stay "cancelled".
	require.Equal(t, "cancelled", repo.gens[1].Status)
}

// ==================== GenerateRequest JSON parsing ====================

func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
	// api_key_id in JSON decodes into a non-nil *int64.
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
	require.NoError(t, err)
	require.NotNil(t, req.APIKeyID)
	require.Equal(t, int64(42), *req.APIKeyID)
}

func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
	// api_key_id omitted → nil after decoding.
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
	require.NoError(t, err)
	require.Nil(t, req.APIKeyID)
}

func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
	// api_key_id: null → nil after decoding.
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
	require.NoError(t, err)
	require.Nil(t, req.APIKeyID)
}

func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
	// Decode every field.
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{
		"model":"sora2-landscape-10s",
		"prompt":"test prompt",
		"media_type":"video",
		"video_count":2,
		"image_input":"data:image/png;base64,abc",
		"api_key_id":100
	}`), &req)
	require.NoError(t, err)
	require.Equal(t, "sora2-landscape-10s", req.Model)
	require.Equal(t, "test prompt", req.Prompt)
	require.Equal(t, "video", req.MediaType)
	require.Equal(t, 2, req.VideoCount)
	require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
	require.NotNil(t, req.APIKeyID)
	require.Equal(t, int64(100), *req.APIKeyID)
}

func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
	// A nil api_key_id must be omitted when marshaling.
	req := GenerateRequest{Model: "sora2", Prompt: "test"}
	b, err := json.Marshal(req)
	require.NoError(t, err)
	var parsed map[string]any
	require.NoError(t, json.Unmarshal(b, &parsed))
	_, hasAPIKeyID := parsed["api_key_id"]
	require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
}

func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
	// A non-nil api_key_id must be included when marshaling.
	id := int64(42)
	req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
	b, err := json.Marshal(req)
	require.NoError(t, err)
	var parsed map[string]any
	require.NoError(t, json.Unmarshal(b, &parsed))
	require.Equal(t, float64(42), parsed["api_key_id"])
}

// ==================== GetQuota: with quota service ====================

func TestGetQuota_WithQuotaService_Success(t *testing.T) {
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID:                    1,
		SoraStorageQuotaBytes: 10 * 1024 * 1024,
		SoraStorageUsedBytes:  3 * 1024 * 1024,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)

	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{
		genService:   genService,
		quotaService: quotaService,
	}

	c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
	h.GetQuota(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	// Two-value assertion so a malformed response fails cleanly instead of panicking.
	data, ok := resp["data"].(map[string]any)
	require.True(t, ok, "response data should be a JSON object")
	require.Equal(t, "user", data["source"])
	require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
	require.Equal(t, float64(3*1024*1024), data["used_bytes"])
}

func TestGetQuota_WithQuotaService_Error(t *testing.T) {
	// The user does not exist → GetQuota responds with 500.
	userRepo := newStubUserRepoForHandler()
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)

	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{
		genService:   genService,
		quotaService: quotaService,
	}

	c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
	h.GetQuota(c)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== Generate: quota check ====================

func TestGenerate_QuotaCheckFailed(t *testing.T) {
	// Quota exceeded → 429.
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID:                    1,
		SoraStorageQuotaBytes: 1024,
		SoraStorageUsedBytes:  1025, // already over quota
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)

	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{
		genService:   genService,
		quotaService: quotaService,
	}

	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusTooManyRequests, rec.Code)
}

func TestGenerate_QuotaCheckPassed(t *testing.T) {
	// Sufficient quota → generation is allowed.
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID:                    1,
		SoraStorageQuotaBytes: 10 * 1024 * 1024,
		SoraStorageUsedBytes:  0,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)

	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{
		genService:   genService,
		quotaService: quotaService,
	}

	c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusOK, rec.Code)
}

// ==================== Stub: SettingRepository (for S3 storage tests) ====================

var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)

// stubSettingRepoForHandler is an in-memory SettingRepository keyed by setting name.
type stubSettingRepoForHandler struct {
	values map[string]string
}

// newStubSettingRepoForHandler wraps values (allocating an empty map when nil)
// so that Set/SetMultiple never write to a nil map.
func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
	if values == nil {
		values = make(map[string]string)
	}
	return &stubSettingRepoForHandler{values: values}
}

// Get returns the setting for key, or service.ErrSettingNotFound.
func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
	if v, ok := r.values[key]; ok {
		return &service.Setting{Key: key, Value: v}, nil
	}
	return nil, service.ErrSettingNotFound
}

// GetValue returns the raw value for key, or service.ErrSettingNotFound.
func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
	if v, ok := r.values[key]; ok {
		return v, nil
	}
	return "", service.ErrSettingNotFound
}

func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
	r.values[key] = value
	return nil
}

// GetMultiple returns only the requested keys that exist; missing keys are skipped.
func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
	result := make(map[string]string)
	for _, k := range keys {
		if v, ok := r.values[k]; ok {
			result[k] = v
		}
	}
	return result, nil
}

func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
	for k, v := range settings {
		r.values[k] = v
	}
	return nil
}

// GetAll returns a copy of every stored setting. Copying (rather than returning
// r.values) keeps callers from mutating the stub's state through the result.
func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
	out := make(map[string]string, len(r.values))
	for k, v := range r.values {
		out[k] = v
	}
	return out, nil
}

func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
	delete(r.values, key)
	return nil
}

// ==================== S3 /
MediaStorage 辅助函数 ==================== // newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { settingRepo := newStubSettingRepoForHandler(map[string]string{ "sora_s3_enabled": "true", "sora_s3_endpoint": endpoint, "sora_s3_region": "us-east-1", "sora_s3_bucket": "test-bucket", "sora_s3_access_key_id": "AKIATEST", "sora_s3_secret_access_key": "test-secret", "sora_s3_prefix": "sora", "sora_s3_force_path_style": "true", }) settingService := service.NewSettingService(settingRepo, &config.Config{}) return service.NewSoraS3Storage(settingService) } // newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 func newFakeSourceServer() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "video/mp4") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("fake video data for test")) })) } // newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 // mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 func newFakeS3Server(mode string) *httptest.Server { var counter atomic.Int32 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(io.Discard, r.Body) _ = r.Body.Close() switch mode { case "ok": w.Header().Set("ETag", `"test-etag"`) w.WriteHeader(http.StatusOK) case "fail": w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(`AccessDenied`)) case "fail-second": n := counter.Add(1) if n <= 1 { w.Header().Set("ETag", `"test-etag"`) w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(`AccessDenied`)) } } })) } // ==================== processGeneration 直接调用测试 ==================== func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} repo.updateErr = fmt.Errorf("db error") genService := service.NewSoraGenerationService(repo, nil, nil) 
h := &SoraClientHandler{genService: genService} // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed // 因此 ErrorMessage 为空(证明未调用 MarkFailed) require.Equal(t, "generating", repo.gens[1].Status) require.Empty(t, repo.gens[1].ErrorMessage) } func TestProcessGeneration_GatewayServiceNil(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService} // gatewayService 未设置 → MarkFailed h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) require.Equal(t, "failed", repo.gens[1].Status) require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") } // ==================== storeMediaWithDegradation: S3 路径 ==================== func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) h := &SoraClientHandler{s3Storage: s3Storage} storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, ) require.Equal(t, service.SoraStorageTypeS3, storageType) require.Len(t, s3Keys, 1) require.NotEmpty(t, s3Keys[0]) require.Len(t, storedURLs, 1) require.Equal(t, storedURL, storedURLs[0]) require.Contains(t, storedURL, fakeS3.URL) require.Contains(t, storedURL, "/test-bucket/") require.Greater(t, fileSize, int64(0)) } func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) h 
:= &SoraClientHandler{s3Storage: s3Storage} urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, ) require.Equal(t, service.SoraStorageTypeS3, storageType) require.Len(t, s3Keys, 2) require.Len(t, storedURLs, 2) require.Equal(t, storedURL, storedURLs[0]) require.Contains(t, storedURLs[0], fakeS3.URL) require.Contains(t, storedURLs[1], fakeS3.URL) require.Greater(t, fileSize, int64(0)) } func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { // 上游返回 404 → 下载失败 → S3 上传不会开始 fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) defer badSource.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) h := &SoraClientHandler{s3Storage: s3Storage} _, _, storageType, _, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, ) require.Equal(t, service.SoraStorageTypeUpstream, storageType) } func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) h := &SoraClientHandler{s3Storage: s3Storage} _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, ) // S3 失败,降级到 upstream require.Equal(t, service.SoraStorageTypeUpstream, storageType) require.Nil(t, s3Keys) } func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("fail-second") defer fakeS3.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) h := &SoraClientHandler{s3Storage: s3Storage} urls := 
[]string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, ) // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream require.Equal(t, service.SoraStorageTypeUpstream, storageType) require.Nil(t, s3Keys) } // ==================== storeMediaWithDegradation: 本地存储路径 ==================== func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: "/dev/null/invalid_dir", }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) h := &SoraClientHandler{mediaStorage: mediaStorage} _, _, storageType, _, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, ) // 本地存储失败,降级到 upstream require.Equal(t, service.SoraStorageTypeUpstream, storageType) } func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") require.NoError(t, err) defer os.RemoveAll(tmpDir) sourceServer := newFakeSourceServer() defer sourceServer.Close() cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: tmpDir, DownloadTimeoutSeconds: 5, MaxDownloadBytes: 10 * 1024 * 1024, }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) h := &SoraClientHandler{mediaStorage: mediaStorage} _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, ) require.Equal(t, service.SoraStorageTypeLocal, storageType) require.Nil(t, s3Keys) // 本地存储不返回 S3 keys } func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") require.NoError(t, err) defer os.RemoveAll(tmpDir) sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 
:= newFakeS3Server("fail") defer fakeS3.Close() s3Storage := newS3StorageForHandler(fakeS3.URL) cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: tmpDir, DownloadTimeoutSeconds: 5, MaxDownloadBytes: 10 * 1024 * 1024, }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) h := &SoraClientHandler{ s3Storage: s3Storage, mediaStorage: mediaStorage, } _, _, storageType, _, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, ) // S3 失败 → 本地存储成功 require.Equal(t, service.SoraStorageTypeLocal, storageType) } // ==================== SaveToStorage: S3 路径 ==================== func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } s3Storage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.SaveToStorage(c) require.Equal(t, http.StatusInternalServerError, rec.Code) resp := parseResponse(t, rec) require.Contains(t, resp["message"], "S3") } func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusForbidden) })) defer expiredServer.Close() fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: expiredServer.URL + "/v.mp4", } s3Storage := newS3StorageForHandler(fakeS3.URL) 
genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	// The handler must answer 410 Gone with a message mentioning expiry ("过期").
	require.Equal(t, http.StatusGone, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, fmt.Sprint(resp["message"]), "过期")
}

// TestSaveToStorage_S3EnabledUploadSuccess saves an upstream-hosted generation
// with S3 configured and expects a 200 response carrying an object key, with
// the stored record switched to S3 storage.
func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Contains(t, data["message"], "S3")
	require.NotEmpty(t, data["object_key"])
	// Verify the record now points at S3 storage.
	require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
}

// TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs checks that a
// generation with two MediaURLs yields two object keys in the response and
// two S3ObjectKeys on the stored record.
func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v1.mp4",
		MediaURLs: []string{
			sourceServer.URL + "/v1.mp4",
			sourceServer.URL + "/v2.mp4",
		},
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Len(t, data["object_keys"].([]any), 2)
	require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
	require.Len(t, repo.gens[1].S3ObjectKeys, 2)
	require.Len(t, repo.gens[1].MediaURLs, 2)
}

// TestSaveToStorage_S3EnabledUploadSuccessWithQuota checks that, with a quota
// service wired in, a successful S3 save increases the user's
// SoraStorageUsedBytes.
func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID:                    1,
		SoraStorageQuotaBytes: 100 * 1024 * 1024,
		SoraStorageUsedBytes:  0,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	// Verify the quota usage was incremented.
	require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}

// TestSaveToStorage_S3UploadSuccessMarkCompletedFails checks that a
// persistence failure after a successful S3 upload surfaces as a 500.
func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	// After the S3 upload succeeds, MarkCompleted calls repo.Update, which fails.
	repo.updateErr = fmt.Errorf("db error")
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== GetStorageStatus: S3 paths ====================

// TestGetStorageStatus_S3EnabledNotHealthy expects s3_enabled=true but
// s3_healthy=false when the S3 endpoint fails the connection test.
func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
	// S3 is enabled but TestConnection fails (the fake endpoint does not
	// answer HeadBucket).
	fakeS3 := newFakeS3Server("fail")
	defer fakeS3.Close()
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}

	c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
	h.GetStorageStatus(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Equal(t, true, data["s3_enabled"])
	require.Equal(t, false, data["s3_healthy"])
}

// TestGetStorageStatus_S3EnabledHealthy expects both s3_enabled and
// s3_healthy to be true against a working fake S3 endpoint.
func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}

	c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
	h.GetStorageStatus(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Equal(t, true, data["s3_enabled"])
	require.Equal(t, true, data["s3_healthy"])
}

// ==================== Stub: AccountRepository (used by GatewayService) ====================

var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)

// stubAccountRepoForHandler is an in-memory service.AccountRepository.
// GetByID searches accounts by ID, and every ListSchedulable* variant returns
// the whole accounts slice regardless of its filter arguments. All other
// methods are no-ops that return zero values.
type stubAccountRepoForHandler struct {
	accounts []service.Account
}

func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }

// GetByID returns a pointer into r.accounts (not a copy) for the matching ID,
// or an error when no account matches.
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
	for i := range r.accounts {
		if r.accounts[i].ID == id {
			return &r.accounts[i], nil
		}
	}
	return nil, fmt.Errorf("account not found")
}
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
	return false, nil
}
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
	return nil
}
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
	return 0, nil
}
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }

// The ListSchedulable* variants ignore their filters and return every stubbed account.
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
	return nil
}
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error {
	return nil
}
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
	return nil
}
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error {
	return nil
}
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
	return nil
}
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
	return nil
}
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
	return 0, nil
}

// ==================== Stub: SoraClient (used by SoraGatewayService) ====================

var _ service.SoraClient = (*stubSoraClientForHandler)(nil)

// stubSoraClientForHandler is a service.SoraClient stub. It always reports
// itself enabled, task creation returns fixed task IDs, and GetVideoTask
// returns the preconfigured videoStatus. Every other method returns zero
// values.
type stubSoraClientForHandler struct {
	// videoStatus is returned verbatim by GetVideoTask.
	videoStatus *service.SoraVideoTaskStatus
}

func (s *stubSoraClientForHandler) Enabled() bool { return true }
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
	return "task-image", nil
}
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
	return "task-video", nil
}

// CreateStoryboardTask reuses the "task-video" ID, so storyboard runs are
// polled through GetVideoTask like plain video tasks.
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
	return "task-video", nil
}
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
	return s.videoStatus, nil
}

// ==================== Helpers: build minimal GatewayService and SoraGatewayService ====================

// newMinimalGatewayService builds a GatewayService wired only with accountRepo
// and nil for every other dependency. It is intended for tests that exercise
// SelectAccountForModel.
//
// NOTE(review): the positional nil list must stay in sync with the
// NewGatewayService signature; any added parameter breaks this helper.
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
	return service.NewGatewayService(
		accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
	)
}

// newMinimalSoraGatewayService builds a SoraGatewayService around soraClient
// for tests that exercise Forward. The config sets a 1-second poll interval
// and a single poll attempt, presumably so Forward settles after one
// GetVideoTask call. TODO confirm against the poll loop.
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Client: config.SoraClientConfig{
				PollIntervalSeconds: 1,
				MaxPollAttempts:     1,
			},
		},
	}
	return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
}

// ==================== processGeneration: 
更多路径测试 ==================== func TestProcessGeneration_SelectAccountError(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" accountRepo := &stubAccountRepoForHandler{accounts: nil} gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) require.Equal(t, "failed", repo.gens[1].Status) require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") } func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) // 提供可用账号使 SelectAccountForModel 成功 accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) // soraGatewayService 为 nil h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) require.Equal(t, "failed", repo.gens[1].Status) require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") } func TestProcessGeneration_ForwardError(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: 
true}, }, } gatewayService := newMinimalGatewayService(accountRepo) // SoraClient 返回视频任务失败 soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "failed", ErrorMsg: "content policy violation", }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) require.Equal(t, "failed", repo.gens[1].Status) require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") } func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 repo.getByIDOverrideAfterN = 1 repo.getByIDOverrideStatus = "cancelled" genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, } soraGatewayService := newMinimalSoraGatewayService(soraClient) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) require.Equal(t, "generating", repo.gens[1].Status) } func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := 
service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) // SoraClient 返回 completed 但无 URL soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "completed", URLs: nil, // 无 URL }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) require.Equal(t, "failed", repo.gens[1].Status) require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") } func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) // 第 2 次返回 "cancelled" 状态,模拟外部取消 repo.getByIDOverrideAfterN = 1 repo.getByIDOverrideStatus = "cancelled" genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "completed", URLs: []string{"https://example.com/video.mp4"}, }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) require.Equal(t, "generating", 
repo.gens[1].Status) } func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "completed", URLs: []string{"https://example.com/video.mp4"}, }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) // 无 S3 和本地存储,降级到 upstream h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) require.Equal(t, "completed", repo.gens[1].Status) require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) require.NotEmpty(t, repo.gens[1].MediaURL) } func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "completed", URLs: []string{sourceServer.URL + "/video.mp4"}, }, } soraGatewayService := 
newMinimalSoraGatewayService(soraClient) s3Storage := newS3StorageForHandler(fakeS3.URL) userRepo := newStubUserRepoForHandler() userRepo.users[1] = &service.User{ ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, } quotaService := service.NewSoraQuotaService(userRepo, nil, nil) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, s3Storage: s3Storage, quotaService: quotaService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) require.Equal(t, "completed", repo.gens[1].Status) require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) require.NotEmpty(t, repo.gens[1].S3ObjectKeys) require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) // 验证配额已累加 require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) } func TestProcessGeneration_MarkCompletedFails(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 repo.updateCallCount = new(int32) repo.updateFailAfterN = 1 genService := service.NewSoraGenerationService(repo, nil, nil) accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) soraClient := &stubSoraClientForHandler{ videoStatus: &service.SoraVideoTaskStatus{ Status: "completed", URLs: []string{"https://example.com/video.mp4"}, }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) h := &SoraClientHandler{ genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, } h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 // 
此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 require.Equal(t, "completed", repo.gens[1].Status) } // ==================== cleanupStoredMedia 直接测试 ==================== func TestCleanupStoredMedia_S3Path(t *testing.T) { // S3 清理路径:s3Storage 为 nil 时不 panic h := &SoraClientHandler{} // 不应 panic h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) } func TestCleanupStoredMedia_LocalPath(t *testing.T) { // 本地清理路径:mediaStorage 为 nil 时不 panic h := &SoraClientHandler{} h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) } func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { // upstream 类型不清理 h := &SoraClientHandler{} h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) } func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { // 空 keys 不触发清理 h := &SoraClientHandler{} h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) } // ==================== DeleteGeneration: 本地存储清理路径 ==================== func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") require.NoError(t, err) defer os.RemoveAll(tmpDir) cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: tmpDir, }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ ID: 1, UserID: 1, Status: "completed", StorageType: service.SoraStorageTypeLocal, MediaURL: "video/test.mp4", MediaURLs: []string{"video/test.mp4"}, } genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) 
require.Equal(t, http.StatusOK, rec.Code) _, exists := repo.gens[1] require.False(t, exists) } func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { // MediaURLs 为空,使用 MediaURL 作为清理路径 tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") require.NoError(t, err) defer os.RemoveAll(tmpDir) cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ Type: "local", LocalPath: tmpDir, }, }, } mediaStorage := service.NewSoraMediaStorage(cfg) repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ ID: 1, UserID: 1, Status: "completed", StorageType: service.SoraStorageTypeLocal, MediaURL: "video/test.mp4", MediaURLs: nil, // 空 } genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusOK, rec.Code) } func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { // 非本地存储类型 → 跳过清理 repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ ID: 1, UserID: 1, Status: "completed", StorageType: service.SoraStorageTypeUpstream, MediaURL: "https://upstream.com/v.mp4", } genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService} c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusOK, rec.Code) } func TestDeleteGeneration_DeleteError(t *testing.T) { // repo.Delete 出错 repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} repo.deleteErr = fmt.Errorf("delete failed") genService := service.NewSoraGenerationService(repo, nil, nil) h := &SoraClientHandler{genService: genService} c, rec := makeGinContext("DELETE", 
"/api/v1/sora/generations/1", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} h.DeleteGeneration(c) require.Equal(t, http.StatusNotFound, rec.Code) } // ==================== fetchUpstreamModels 测试 ==================== func TestFetchUpstreamModels_NilGateway(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") h := &SoraClientHandler{} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "gatewayService 未初始化") } func TestFetchUpstreamModels_NoAccounts(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{accounts: nil} gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "选择 Sora 账号失败") } func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "不支持模型同步") } func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"base_url": "https://sora.test"}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, 
err.Error(), "api_key") } func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test"}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) } func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer ts.Close() accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "状态码 500") } func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("not json")) })) defer ts.Close() accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: 
map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "解析响应失败") } func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"data":[]}`)) })) defer ts.Close() accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "空模型列表") } func TestFetchUpstreamModels_Success(t *testing.T) { t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 验证请求头 require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) })) defer ts.Close() accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, }, } gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: 
gatewayService}
	families, err := h.fetchUpstreamModels(context.Background())
	require.NoError(t, err)
	require.NotEmpty(t, families)
}

// TestFetchUpstreamModels_UnrecognizedModels checks that fetchUpstreamModels
// returns an error when the upstream model list contains no recognizable IDs.
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
	t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
	}))
	defer ts.Close()
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	h := &SoraClientHandler{gatewayService: gatewayService}
	_, err := h.fetchUpstreamModels(context.Background())
	require.Error(t, err)
	require.Contains(t, err.Error(), "未能从上游模型列表中识别")
}

// ==================== getModelFamilies: caching ====================

// TestGetModelFamilies_CachesLocalConfig checks that the local-config fallback
// result is cached and flagged as non-upstream.
func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
	// gatewayService is nil → fetchUpstreamModels fails → falls back to the local config.
	h := &SoraClientHandler{}
	families := h.getModelFamilies(context.Background())
	require.NotEmpty(t, families)
	// The second call should hit the cache (modelCacheUpstream=false → short TTL).
	families2 := h.getModelFamilies(context.Background())
	require.Equal(t, families, families2)
	require.False(t, h.modelCacheUpstream)
}

// TestGetModelFamilies_CachesUpstreamResult checks that a successful upstream
// fetch is cached and flagged as upstream.
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
	t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
	}))
	defer ts.Close()
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	h := &SoraClientHandler{gatewayService: gatewayService}
	families := h.getModelFamilies(context.Background())
	require.NotEmpty(t, families)
	require.True(t, h.modelCacheUpstream)
	// The second call hits the cache.
	families2 := h.getModelFamilies(context.Background())
	require.Equal(t, families, families2)
}

// TestGetModelFamilies_ExpiredCacheRefreshes checks that an expired cache entry
// is replaced instead of being served again.
func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
	// Seed an already-expired cache entry (modelCacheUpstream=false → short TTL).
	h := &SoraClientHandler{
		cachedFamilies:     []service.SoraModelFamily{{ID: "old"}},
		modelCacheTime:     time.Now().Add(-10 * time.Minute), // already expired
		modelCacheUpstream: false,
	}
	// gatewayService is nil → fetchUpstreamModels fails → the cache is refreshed from the local config.
	families := h.getModelFamilies(context.Background())
	require.NotEmpty(t, families)
	// The cache was refreshed, so the stale "old" family must be gone.
	found := false
	for _, f := range families {
		if f.ID == "old" {
			found = true
			break
		}
	}
	require.False(t, found, "过期缓存应被刷新")
}

// ==================== processGeneration: groupID and ForcePlatform ====================

// TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails checks that a
// generation is marked failed when no sora account can be selected.
func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
	// groupID is nil → ForcePlatform=sora is set → no sora account is available → MarkFailed.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// Empty account list → SelectAccountForModel fails.
	accountRepo := &stubAccountRepoForHandler{accounts: nil}
	gatewayService := newMinimalGatewayService(accountRepo)
	h := &SoraClientHandler{
		genService:     genService,
		gatewayService: gatewayService,
	}
	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}

// ==================== Generate: non-QuotaExceeded error from the quota check ====================

// TestGenerate_CheckQuotaNonQuotaError checks that a generic quota-check error
// is reported as 403.
func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
	// quotaService.CheckQuota returns an error that is not a QuotaExceededError → 403.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// The user does not exist → GetByID fails → CheckQuota returns a plain error.
	userRepo := newStubUserRepoForHandler()
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
	body := `{"model":"sora2-landscape-10s","prompt":"test"}`
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
	h.Generate(c)
	require.Equal(t, http.StatusForbidden, rec.Code)
}

// ==================== Generate: CreatePending concurrency-limit error ====================

// stubSoraGenRepoWithAtomicCreate implements the soraGenerationRepoAtomicCreator
// interface on top of stubSoraGenRepo. When limitErr is set, CreatePendingWithLimit
// returns it; otherwise it delegates to the embedded stub's Create.
type stubSoraGenRepoWithAtomicCreate struct {
	stubSoraGenRepo
	limitErr error
}

// CreatePendingWithLimit returns limitErr if set, otherwise stores gen via the
// embedded stub. The caller's ctx is forwarded rather than replaced with
// context.Background(), so cancellation/deadlines propagate as they would in a
// real repository.
func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(ctx context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
	if r.limitErr != nil {
		return r.limitErr
	}
	return r.stubSoraGenRepo.Create(ctx, gen)
}

// TestGenerate_CreatePendingConcurrencyLimit checks that hitting the per-user
// concurrency limit is reported as 429.
func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
	repo := &stubSoraGenRepoWithAtomicCreate{
		stubSoraGenRepo: *newStubSoraGenRepo(),
		limitErr:        service.ErrSoraGenerationConcurrencyLimit,
	}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
	body := `{"model":"sora2-landscape-10s","prompt":"test"}`
	c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
	h.Generate(c)
	require.Equal(t, http.StatusTooManyRequests, rec.Code)
	resp := parseResponse(t, rec)
	// NOTE(review): "3" presumably is the concurrency limit echoed in the
	// error message — confirm against the handler's message text.
	require.Contains(t, resp["message"], "3")
}

// ==================== SaveToStorage: quota exceeded ====================

// TestSaveToStorage_QuotaExceeded checks that saving is rejected with 429 when
// the user's storage quota is already full.
func TestSaveToStorage_QuotaExceeded(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// The user's storage quota is already full.
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID: 1, SoraStorageQuotaBytes: 10, SoraStorageUsedBytes: 10,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusTooManyRequests, rec.Code)
}

// ==================== SaveToStorage: non-QuotaExceeded quota error ====================

// TestSaveToStorage_QuotaNonQuotaError checks that a generic quota error while
// saving is reported as 500.
func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// The user does not exist → GetByID fails → AddUsage returns a plain error.
	userRepo := newStubUserRepoForHandler()
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== SaveToStorage: all media URLs empty ====================

// TestSaveToStorage_EmptyMediaURLs checks that a generation with no media URLs
// is rejected with 400 and an "expired" message.
func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    "",
		MediaURLs:   []string{},
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusBadRequest, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, resp["message"], "已过期")
}

// ==================== SaveToStorage: S3 upload fails after earlier files were uploaded ====================

// TestSaveToStorage_MultiURL_SecondUploadFails checks that a failure while
// uploading the second of two media files is reported as 500.
func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	// NOTE(review): the "fail-second" mode presumably fails the second upload —
	// confirm in newFakeS3Server.
	fakeS3 := newFakeS3Server("fail-second")
	defer fakeS3.Close()
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v1.mp4",
		MediaURLs:   []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	// NOTE(review): only the status code is asserted; cleanup of the already
	// uploaded object is not verified here.
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== SaveToStorage: UpdateStorageForCompleted fails (with quota rollback) ====================

// TestSaveToStorage_MarkCompletedFailsWithQuotaRollback checks that a repository
// update failure after a successful upload is reported as 500.
func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	repo.updateErr = fmt.Errorf("db error")
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, SoraStorageUsedBytes: 0,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	// NOTE(review): the quota rollback itself is not asserted; consider checking
	// the user's SoraStorageUsedBytes afterwards once the stub user repo's
	// update semantics are confirmed.
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== cleanupStoredMedia: real S3 delete path ====================

// TestCleanupStoredMedia_WithS3Storage_ActualDelete exercises the S3 delete path
// against a fake S3 server that accepts requests.
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}
	require.NotPanics(t, func() {
		h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
	})
}

// TestCleanupStoredMedia_S3DeleteFails_LogOnly checks that a failing S3 delete
// does not panic (cleanupStoredMedia returns no error to inspect).
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
	fakeS3 := newFakeS3Server("fail")
	defer fakeS3.Close()
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}
	require.NotPanics(t, func() {
		h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
	})
}

// TestCleanupStoredMedia_LocalDeleteFails_LogOnly checks that deleting a missing
// local file does not panic.
func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:      "local",
				LocalPath: tmpDir,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)
	h := &SoraClientHandler{mediaStorage: mediaStorage}
	require.NotPanics(t, func() {
		h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
	})
}

// ==================== DeleteGeneration: local file deletion fails (logged only) ====================

// TestDeleteGeneration_LocalStorageDeleteFails_LogOnly checks that DeleteGeneration
// still returns 200 when the local media file cannot be deleted.
func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:      "local",
				LocalPath: tmpDir,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: service.SoraStorageTypeLocal,
		MediaURL:    "nonexistent/video.mp4",
		MediaURLs:   []string{"nonexistent/video.mp4"},
	}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
	c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.DeleteGeneration(c)
	require.Equal(t, http.StatusOK, rec.Code)
}

// ==================== CancelGeneration: conflict when the task already finished ====================

// TestCancelGeneration_AlreadyCompleted checks that cancelling a completed
// generation is reported as 409.
func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}
	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.CancelGeneration(c)
	require.Equal(t, http.StatusConflict, rec.Code)
}