//go:build unit
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ==================== Stub: SoraGenerationRepository ====================
var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
type stubSoraGenRepo struct {
gens map[int64]*service.SoraGeneration
nextID int64
createErr error
getErr error
updateErr error
deleteErr error
listErr error
countErr error
countValue int64
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
updateCallCount *int32
updateFailAfterN int32
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
getByIDCallCount int32
getByIDOverrideAfterN int32 // 0 = 不覆盖
getByIDOverrideStatus string
}
func newStubSoraGenRepo() *stubSoraGenRepo {
return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
}
func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
if r.createErr != nil {
return r.createErr
}
gen.ID = r.nextID
r.nextID++
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
if r.getErr != nil {
return nil, r.getErr
}
gen, ok := r.gens[id]
if !ok {
return nil, fmt.Errorf("not found")
}
// 条件性状态覆盖:模拟外部取消等场景
if r.getByIDOverrideAfterN > 0 {
n := atomic.AddInt32(&r.getByIDCallCount, 1)
if n > r.getByIDOverrideAfterN {
cp := *gen
cp.Status = r.getByIDOverrideStatus
return &cp, nil
}
}
return gen, nil
}
func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
// 条件性失败:前 N 次成功,之后失败
if r.updateCallCount != nil {
n := atomic.AddInt32(r.updateCallCount, 1)
if n > r.updateFailAfterN {
return fmt.Errorf("conditional update error (call #%d)", n)
}
}
if r.updateErr != nil {
return r.updateErr
}
r.gens[gen.ID] = gen
return nil
}
func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
if r.deleteErr != nil {
return r.deleteErr
}
delete(r.gens, id)
return nil
}
func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
if r.listErr != nil {
return nil, 0, r.listErr
}
var result []*service.SoraGeneration
for _, gen := range r.gens {
if gen.UserID != params.UserID {
continue
}
result = append(result, gen)
}
return result, int64(len(result)), nil
}
func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
return r.countValue, nil
}
// ==================== 辅助函数 ====================
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
genService := service.NewSoraGenerationService(repo, nil, nil)
return &SoraClientHandler{genService: genService}
}
func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
if body != "" {
c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
} else {
c.Request = httptest.NewRequest(method, path, nil)
}
if userID > 0 {
c.Set("user_id", userID)
}
return c, rec
}
func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
t.Helper()
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
return resp
}
// ==================== 纯函数测试: buildAsyncRequestBody ====================
func TestBuildAsyncRequestBody(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "sora2-landscape-10s", parsed["model"])
require.Equal(t, false, parsed["stream"])
msgs := parsed["messages"].([]any)
require.Len(t, msgs, 1)
msg := msgs[0].(map[string]any)
require.Equal(t, "user", msg["role"])
require.Equal(t, "一只猫在跳舞", msg["content"])
}
func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "", "", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "gpt-image", parsed["model"])
msgs := parsed["messages"].([]any)
msg := msgs[0].(map[string]any)
require.Equal(t, "", msg["content"])
}
func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
}
func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
var parsed map[string]any
require.NoError(t, json.Unmarshal(body, &parsed))
require.Equal(t, float64(3), parsed["video_count"])
}
func TestNormalizeVideoCount(t *testing.T) {
require.Equal(t, 1, normalizeVideoCount("video", 0))
require.Equal(t, 2, normalizeVideoCount("video", 2))
require.Equal(t, 3, normalizeVideoCount("video", 5))
require.Equal(t, 1, normalizeVideoCount("image", 3))
}
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
}
func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody(nil))
require.Nil(t, parseMediaURLsFromBody([]byte{}))
}
func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
}
func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
}
func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
}
func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
urls := parseMediaURLsFromBody([]byte(body))
require.Len(t, urls, 2)
require.Equal(t, "https://multi.com/a.mp4", urls[0])
}
func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
}
func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
}
func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
// media_urls 不是 string 数组
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
}
func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
}
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://oauth.com/video.mp4", url)
require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
url, urls := extractMediaURLsFromResult(result, recorder)
require.Equal(t, "https://body.com/1.mp4", url)
require.Len(t, urls, 2)
}
func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
recorder := httptest.NewRecorder()
_, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Equal(t, "https://upstream.com/video.mp4", url)
require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
}
func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(nil, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
result := &service.ForwardResult{MediaURL: ""}
recorder := httptest.NewRecorder()
url, urls := extractMediaURLsFromResult(result, recorder)
require.Empty(t, url)
require.Nil(t, urls)
}
// ==================== getUserIDFromContext ====================
func TestGetUserIDFromContext_Int64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", int64(42))
require.Equal(t, int64(42), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
require.Equal(t, int64(777), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_Float64(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", float64(99))
require.Equal(t, int64(99), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_String(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "123")
require.Equal(t, int64(123), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("userID", int64(55))
require.Equal(t, int64(55), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_NoID(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
require.Equal(t, int64(0), getUserIDFromContext(c))
}
func TestGetUserIDFromContext_InvalidString(t *testing.T) {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("user_id", "not-a-number")
require.Equal(t, int64(0), getUserIDFromContext(c))
}
// ==================== Handler: Generate ====================
func TestGenerate_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
h.Generate(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGenerate_BadRequest_MissingModel(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGenerate_TooManyRequests(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countValue = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_CountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.countErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_Success(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
require.Equal(t, "pending", data["status"])
}
func TestGenerate_DefaultMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "video", repo.gens[1].MediaType)
}
func TestGenerate_ImageMediaType(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "image", repo.gens[1].MediaType)
}
func TestGenerate_CreatePendingError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.createErr = fmt.Errorf("create failed")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGenerate_APIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(42))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyInContext(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_ConcurrencyBoundary(t *testing.T) {
// activeCount == 2 应该允许
repo := newStubSoraGenRepo()
repo.countValue = 2
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Handler: ListGenerations ====================
func TestListGenerations_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
h.ListGenerations(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestListGenerations_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
repo.nextID = 3
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
items := data["data"].([]any)
require.Len(t, items, 2)
require.Equal(t, float64(2), data["total"])
}
func TestListGenerations_ListError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.listErr = fmt.Errorf("db error")
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestListGenerations_DefaultPagination(t *testing.T) {
repo := newStubSoraGenRepo()
h := newTestSoraClientHandler(repo)
// 不传分页参数,应默认 page=1 page_size=20
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
h.ListGenerations(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["page"])
}
// ==================== Handler: GetGeneration ====================
func TestGetGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.GetGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestGetGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestGetGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.GetGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, float64(1), data["id"])
}
// ==================== Handler: DeleteGeneration ====================
func TestDeleteGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestDeleteGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDeleteGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestDeleteGeneration_Success(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
// ==================== Handler: CancelGeneration ====================
func TestCancelGeneration_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestCancelGeneration_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestCancelGeneration_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestCancelGeneration_Pending(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Generating(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "cancelled", repo.gens[1].Status)
}
func TestCancelGeneration_Completed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Failed(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestCancelGeneration_Cancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}
// ==================== Handler: GetQuota ====================
func TestGetQuota_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
h.GetQuota(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestGetQuota_NilQuotaService(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "unlimited", data["source"])
}
// ==================== Handler: GetModels ====================
func TestGetModels(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
h.GetModels(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].([]any)
require.Len(t, data, 4)
// 验证类型分布
videoCount, imageCount := 0, 0
for _, item := range data {
m := item.(map[string]any)
if m["type"] == "video" {
videoCount++
} else if m["type"] == "image" {
imageCount++
}
}
require.Equal(t, 3, videoCount)
require.Equal(t, 1, imageCount)
}
// ==================== Handler: GetStorageStatus ====================
func TestGetStorageStatus_NilS3(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, false, data["local_enabled"])
}
func TestGetStorageStatus_LocalEnabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, false, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
require.Equal(t, true, data["local_enabled"])
}
// ==================== Handler: SaveToStorage ====================
func TestSaveToStorage_Unauthorized(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestSaveToStorage_InvalidID(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "abc"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_NotFound(t *testing.T) {
h := newTestSoraClientHandler(newStubSoraGenRepo())
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "999"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestSaveToStorage_NotUpstream(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSaveToStorage_S3Nil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
}
func TestSaveToStorage_WrongUser(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
h := newTestSoraClientHandler(repo)
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
h := &SoraClientHandler{}
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://a.com/1.mp4", url)
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
require.Nil(t, keys)
require.Equal(t, int64(0), size)
}
func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
h := &SoraClientHandler{}
url, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Equal(t, "https://upstream.com/v.mp4", url)
}
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
type stubUserRepoForHandler struct {
users map[int64]*service.User
updateErr error
}
func newStubUserRepoForHandler() *stubUserRepoForHandler {
return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
}
func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
if u, ok := r.users[id]; ok {
return u, nil
}
return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
if r.updateErr != nil {
return r.updateErr
}
r.users[user.ID] = user
return nil
}
func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
return nil, nil
}
func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== NewSoraClientHandler ====================
func TestNewSoraClientHandler(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
}
func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
require.NotNil(t, h)
require.Nil(t, h.apiKeyService)
}
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
type stubAPIKeyRepoForHandler struct {
keys map[int64]*service.APIKey
getErr error
}
func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
}
func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
if r.getErr != nil {
return nil, r.getErr
}
if k, ok := r.keys[id]; ok {
return k, nil
}
return nil, fmt.Errorf("api key not found: %d", id)
}
func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
return "", 0, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
return false, nil
}
func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
return nil, nil
}
func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
return 0, nil
}
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
return nil
}
// newTestAPIKeyService 创建测试用的 APIKeyService
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
}
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(5)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.NotZero(t, data["generation_id"])
// 验证 api_key_id 已关联到生成记录
gen := repo.gens[1]
require.NotNil(t, gen.APIKeyID)
require.Equal(t, int64(42), *gen.APIKeyID)
}
func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
// 前端传递不存在的 api_key_id → 400
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
}
func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
// 前端传递别人的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 999, // 属于 user 999
Status: service.StatusAPIKeyActive,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
}
func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
// 前端传递已禁用的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyDisabled,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
}
func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
// 前端传递配额耗尽的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyQuotaExhausted,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
// 前端传递已过期的 api_key_id → 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyExpired,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService} // apiKeyService = nil
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: nil, // 无分组
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Nil(t, repo.gens[1].APIKeyID)
}
func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
groupID := int64(10)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyRepo.keys[42] = &service.APIKey{
ID: 42,
UserID: 1,
Status: service.StatusAPIKeyActive,
GroupID: &groupID,
}
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
}
func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
c.Set("api_key_id", int64(99))
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
// 应使用 context 中的 api_key_id=99
require.NotNil(t, repo.gens[1].APIKeyID)
require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
}
func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
apiKeyRepo := newStubAPIKeyRepoForHandler()
apiKeyService := newTestAPIKeyService(apiKeyRepo)
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
// api_key_id=0 不存在 → 400
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
h.Generate(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
// groupID 不为 nil → 不设置 ForcePlatform
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
gid := int64(5)
h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
require.Equal(t, "cancelled", repo.gens[1].Status)
}
// ==================== GenerateRequest JSON 解析 ====================
func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
// 验证 api_key_id 在 JSON 中正确解析为 *int64
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
require.NoError(t, err)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(42), *req.APIKeyID)
}
func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
// 不传 api_key_id → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
// api_key_id: null → 解析后为 nil
var req GenerateRequest
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
require.NoError(t, err)
require.Nil(t, req.APIKeyID)
}
func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
// 全字段解析
var req GenerateRequest
err := json.Unmarshal([]byte(`{
"model":"sora2-landscape-10s",
"prompt":"test prompt",
"media_type":"video",
"video_count":2,
"image_input":"data:image/png;base64,abc",
"api_key_id":100
}`), &req)
require.NoError(t, err)
require.Equal(t, "sora2-landscape-10s", req.Model)
require.Equal(t, "test prompt", req.Prompt)
require.Equal(t, "video", req.MediaType)
require.Equal(t, 2, req.VideoCount)
require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
require.NotNil(t, req.APIKeyID)
require.Equal(t, int64(100), *req.APIKeyID)
}
func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
// api_key_id 为 nil 时 JSON 序列化应省略
req := GenerateRequest{Model: "sora2", Prompt: "test"}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
_, hasAPIKeyID := parsed["api_key_id"]
require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
}
func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
// api_key_id 不为 nil 时 JSON 序列化应包含
id := int64(42)
req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
b, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(b, &parsed))
require.Equal(t, float64(42), parsed["api_key_id"])
}
// ==================== GetQuota: 有配额服务 ====================
func TestGetQuota_WithQuotaService_Success(t *testing.T) {
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 3 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
h.GetQuota(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, "user", data["source"])
require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
require.Equal(t, float64(3*1024*1024), data["used_bytes"])
}
func TestGetQuota_WithQuotaService_Error(t *testing.T) {
// 用户不存在时 GetQuota 返回错误
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
h.GetQuota(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== Generate: 配额检查 ====================
func TestGenerate_QuotaCheckFailed(t *testing.T) {
// 配额超限时返回 429
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 1024,
SoraStorageUsedBytes: 1025, // 已超限
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
func TestGenerate_QuotaCheckPassed(t *testing.T) {
// 配额充足时允许生成
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{
genService: genService,
quotaService: quotaService,
}
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
h.Generate(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
type stubSettingRepoForHandler struct {
values map[string]string
}
func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
if values == nil {
values = make(map[string]string)
}
return &stubSettingRepoForHandler{values: values}
}
func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
if v, ok := r.values[key]; ok {
return &service.Setting{Key: key, Value: v}, nil
}
return nil, service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
if v, ok := r.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
r.values[key] = value
return nil
}
func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
result := make(map[string]string)
for _, k := range keys {
if v, ok := r.values[k]; ok {
result[k] = v
}
}
return result, nil
}
func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
for k, v := range settings {
r.values[k] = v
}
return nil
}
func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
return r.values, nil
}
func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
delete(r.values, key)
return nil
}
// ==================== S3 / MediaStorage 辅助函数 ====================
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
settingRepo := newStubSettingRepoForHandler(map[string]string{
"sora_s3_enabled": "true",
"sora_s3_endpoint": endpoint,
"sora_s3_region": "us-east-1",
"sora_s3_bucket": "test-bucket",
"sora_s3_access_key_id": "AKIATEST",
"sora_s3_secret_access_key": "test-secret",
"sora_s3_prefix": "sora",
"sora_s3_force_path_style": "true",
})
settingService := service.NewSettingService(settingRepo, &config.Config{})
return service.NewSoraS3Storage(settingService)
}
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
func newFakeSourceServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "video/mp4")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("fake video data for test"))
}))
}
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
func newFakeS3Server(mode string) *httptest.Server {
var counter atomic.Int32
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_ = r.Body.Close()
switch mode {
case "ok":
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
case "fail":
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`AccessDenied`))
case "fail-second":
n := counter.Add(1)
if n <= 1 {
w.Header().Set("ETag", `"test-etag"`)
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`AccessDenied`))
}
}
}))
}
// ==================== processGeneration 直接调用测试 ====================
func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
repo.updateErr = fmt.Errorf("db error")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
require.Equal(t, "generating", repo.gens[1].Status)
require.Empty(t, repo.gens[1].ErrorMessage)
}
func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
// gatewayService 未设置 → MarkFailed
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}
// ==================== storeMediaWithDegradation: S3 路径 ====================
func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 1)
require.NotEmpty(t, s3Keys[0])
require.Len(t, storedURLs, 1)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURL, fakeS3.URL)
require.Contains(t, storedURL, "/test-bucket/")
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
require.Equal(t, service.SoraStorageTypeS3, storageType)
require.Len(t, s3Keys, 2)
require.Len(t, storedURLs, 2)
require.Equal(t, storedURL, storedURLs[0])
require.Contains(t, storedURLs[0], fakeS3.URL)
require.Contains(t, storedURLs[1], fakeS3.URL)
require.Greater(t, fileSize, int64(0))
}
func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
// 上游返回 404 → 下载失败 → S3 上传不会开始
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer badSource.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
)
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
require.Nil(t, s3Keys)
}
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: "/dev/null/invalid_dir",
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
)
// 本地存储失败,降级到 upstream
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}
func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
require.Equal(t, service.SoraStorageTypeLocal, storageType)
require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
}
func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
DownloadTimeoutSeconds: 5,
MaxDownloadBytes: 10 * 1024 * 1024,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{
s3Storage: s3Storage,
mediaStorage: mediaStorage,
}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
)
// S3 失败 → 本地存储成功
require.Equal(t, service.SoraStorageTypeLocal, storageType)
}
// ==================== SaveToStorage: S3 路径 ====================
func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "S3")
}
func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer expiredServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: expiredServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusGone, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, fmt.Sprint(resp["message"]), "过期")
}
func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Contains(t, data["message"], "S3")
require.NotEmpty(t, data["object_key"])
// 验证记录已更新为 S3 存储
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
}
func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{
sourceServer.URL + "/v1.mp4",
sourceServer.URL + "/v2.mp4",
},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Len(t, data["object_keys"].([]any), 2)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.Len(t, repo.gens[1].S3ObjectKeys, 2)
require.Len(t, repo.gens[1].MediaURLs, 2)
}
func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== GetStorageStatus: S3 路径 ====================
func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, false, data["s3_healthy"])
}
func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Equal(t, true, data["s3_enabled"])
require.Equal(t, true, data["s3_healthy"])
}
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
type stubAccountRepoForHandler struct {
accounts []service.Account
}
func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, fmt.Errorf("account not found")
}
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
return false, nil
}
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
return nil, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
return nil
}
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
return 0, nil
}
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
return nil
}
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
return nil
}
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
return nil
}
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
return nil
}
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
return 0, nil
}
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
type stubSoraClientForHandler struct {
videoStatus *service.SoraVideoTaskStatus
}
func (s *stubSoraClientForHandler) Enabled() bool { return true }
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
return "task-image", nil
}
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
return "task-video", nil
}
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
return nil
}
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
return "", nil
}
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
return nil, nil
}
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
return s.videoStatus, nil
}
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
cfg := &config.Config{
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
PollIntervalSeconds: 1,
MaxPollAttempts: 1,
},
},
}
return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
}
// ==================== processGeneration: 更多路径测试 ====================
func TestProcessGeneration_SelectAccountError(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 提供可用账号使 SelectAccountForModel 成功
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// soraGatewayService 为 nil
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
}
func TestProcessGeneration_ForwardError(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回视频任务失败
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "failed",
ErrorMsg: "content policy violation",
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
}
func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
// SoraClient 返回 completed 但无 URL
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: nil, // 无 URL
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
}
func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
// 第 2 次返回 "cancelled" 状态,模拟外部取消
repo.getByIDOverrideAfterN = 1
repo.getByIDOverrideStatus = "cancelled"
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
require.Equal(t, "generating", repo.gens[1].Status)
}
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
// 无 S3 和本地存储,降级到 upstream
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].MediaURL)
}
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{sourceServer.URL + "/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
s3Storage := newS3StorageForHandler(fakeS3.URL)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
s3Storage: s3Storage,
quotaService: quotaService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
require.Equal(t, "completed", repo.gens[1].Status)
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
// 验证配额已累加
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
repo.updateCallCount = new(int32)
repo.updateFailAfterN = 1
genService := service.NewSoraGenerationService(repo, nil, nil)
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
soraClient := &stubSoraClientForHandler{
videoStatus: &service.SoraVideoTaskStatus{
Status: "completed",
URLs: []string{"https://example.com/video.mp4"},
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
require.Equal(t, "completed", repo.gens[1].Status)
}
// ==================== cleanupStoredMedia 直接测试 ====================
func TestCleanupStoredMedia_S3Path(t *testing.T) {
// S3 清理路径:s3Storage 为 nil 时不 panic
h := &SoraClientHandler{}
// 不应 panic
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalPath(t *testing.T) {
// 本地清理路径:mediaStorage 为 nil 时不 panic
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
}
func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
// upstream 类型不清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
}
func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
// 空 keys 不触发清理
h := &SoraClientHandler{}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
}
// ==================== DeleteGeneration: 本地存储清理路径 ====================
func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: []string{"video/test.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
_, exists := repo.gens[1]
require.False(t, exists)
}
func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
// MediaURLs 为空,使用 MediaURL 作为清理路径
tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "video/test.mp4",
MediaURLs: nil, // 空
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
// 非本地存储类型 → 跳过清理
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1,
UserID: 1,
Status: "completed",
StorageType: service.SoraStorageTypeUpstream,
MediaURL: "https://upstream.com/v.mp4",
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestDeleteGeneration_DeleteError(t *testing.T) {
// repo.Delete 出错
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
repo.deleteErr = fmt.Errorf("delete failed")
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusNotFound, rec.Code)
}
// ==================== fetchUpstreamModels 测试 ====================
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
h := &SoraClientHandler{}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "gatewayService 未初始化")
}
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "选择 Sora 账号失败")
}
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "不支持模型同步")
}
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"base_url": "https://sora.test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "api_key")
}
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test"}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
}
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "状态码 500")
}
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not json"))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "解析响应失败")
}
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "空模型列表")
}
func TestFetchUpstreamModels_Success(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求头
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families, err := h.fetchUpstreamModels(context.Background())
require.NoError(t, err)
require.NotEmpty(t, families)
}
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), "未能从上游模型列表中识别")
}
// ==================== getModelFamilies 缓存测试 ====================
func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
h := &SoraClientHandler{}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
require.False(t, h.modelCacheUpstream)
}
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
}))
defer ts.Close()
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
},
}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
require.True(t, h.modelCacheUpstream)
// 第二次调用命中缓存
families2 := h.getModelFamilies(context.Background())
require.Equal(t, families, families2)
}
func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
h := &SoraClientHandler{
cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
modelCacheUpstream: false,
}
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
families := h.getModelFamilies(context.Background())
require.NotEmpty(t, families)
// 缓存已刷新,不再是 "old"
found := false
for _, f := range families {
if f.ID == "old" {
found = true
}
}
require.False(t, found, "过期缓存应被刷新")
}
// ==================== processGeneration: groupID 与 ForcePlatform ====================
func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
// 空账号列表 → SelectAccountForModel 失败
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{
genService: genService,
gatewayService: gatewayService,
}
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
require.Equal(t, "failed", repo.gens[1].Status)
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
repo := newStubSoraGenRepo()
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusForbidden, rec.Code)
}
// ==================== Generate: CreatePending 并发限制错误 ====================
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
type stubSoraGenRepoWithAtomicCreate struct {
stubSoraGenRepo
limitErr error
}
func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
if r.limitErr != nil {
return r.limitErr
}
return r.stubSoraGenRepo.Create(context.Background(), gen)
}
func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
repo := &stubSoraGenRepoWithAtomicCreate{
stubSoraGenRepo: *newStubSoraGenRepo(),
limitErr: service.ErrSoraGenerationConcurrencyLimit,
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
h.Generate(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "3")
}
// ==================== SaveToStorage: 配额超限 ====================
func TestSaveToStorage_QuotaExceeded(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户配额已满
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 10,
SoraStorageUsedBytes: 10,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusTooManyRequests, rec.Code)
}
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: MediaURLs 全为空 ====================
func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: "",
MediaURLs: []string{},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
resp := parseResponse(t, rec)
require.Contains(t, resp["message"], "已过期")
}
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
ID: 1,
SoraStorageQuotaBytes: 100 * 1024 * 1024,
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.SaveToStorage(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
}
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{mediaStorage: mediaStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
}
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
Type: "local",
LocalPath: tmpDir,
},
},
}
mediaStorage := service.NewSoraMediaStorage(cfg)
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{
ID: 1, UserID: 1, Status: "completed",
StorageType: service.SoraStorageTypeLocal,
MediaURL: "nonexistent/video.mp4",
MediaURLs: []string{"nonexistent/video.mp4"},
}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.DeleteGeneration(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// ==================== CancelGeneration: 任务已结束冲突 ====================
func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
h.CancelGeneration(c)
require.Equal(t, http.StatusConflict, rec.Code)
}