3136 lines
116 KiB
Go
3136 lines
116 KiB
Go
//go:build unit
|
||
|
||
package handler
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"os"
|
||
"strings"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
func init() {
|
||
gin.SetMode(gin.TestMode)
|
||
}
|
||
|
||
// ==================== Stub: SoraGenerationRepository ====================
|
||
|
||
var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
|
||
|
||
type stubSoraGenRepo struct {
|
||
gens map[int64]*service.SoraGeneration
|
||
nextID int64
|
||
createErr error
|
||
getErr error
|
||
updateErr error
|
||
deleteErr error
|
||
listErr error
|
||
countErr error
|
||
countValue int64
|
||
|
||
// 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
|
||
updateCallCount *int32
|
||
updateFailAfterN int32
|
||
|
||
// 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
|
||
getByIDCallCount int32
|
||
getByIDOverrideAfterN int32 // 0 = 不覆盖
|
||
getByIDOverrideStatus string
|
||
}
|
||
|
||
func newStubSoraGenRepo() *stubSoraGenRepo {
|
||
return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
|
||
}
|
||
|
||
func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
|
||
if r.createErr != nil {
|
||
return r.createErr
|
||
}
|
||
gen.ID = r.nextID
|
||
r.nextID++
|
||
r.gens[gen.ID] = gen
|
||
return nil
|
||
}
|
||
func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
|
||
if r.getErr != nil {
|
||
return nil, r.getErr
|
||
}
|
||
gen, ok := r.gens[id]
|
||
if !ok {
|
||
return nil, fmt.Errorf("not found")
|
||
}
|
||
// 条件性状态覆盖:模拟外部取消等场景
|
||
if r.getByIDOverrideAfterN > 0 {
|
||
n := atomic.AddInt32(&r.getByIDCallCount, 1)
|
||
if n > r.getByIDOverrideAfterN {
|
||
cp := *gen
|
||
cp.Status = r.getByIDOverrideStatus
|
||
return &cp, nil
|
||
}
|
||
}
|
||
return gen, nil
|
||
}
|
||
func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
|
||
// 条件性失败:前 N 次成功,之后失败
|
||
if r.updateCallCount != nil {
|
||
n := atomic.AddInt32(r.updateCallCount, 1)
|
||
if n > r.updateFailAfterN {
|
||
return fmt.Errorf("conditional update error (call #%d)", n)
|
||
}
|
||
}
|
||
if r.updateErr != nil {
|
||
return r.updateErr
|
||
}
|
||
r.gens[gen.ID] = gen
|
||
return nil
|
||
}
|
||
func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
|
||
if r.deleteErr != nil {
|
||
return r.deleteErr
|
||
}
|
||
delete(r.gens, id)
|
||
return nil
|
||
}
|
||
func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
|
||
if r.listErr != nil {
|
||
return nil, 0, r.listErr
|
||
}
|
||
var result []*service.SoraGeneration
|
||
for _, gen := range r.gens {
|
||
if gen.UserID != params.UserID {
|
||
continue
|
||
}
|
||
result = append(result, gen)
|
||
}
|
||
return result, int64(len(result)), nil
|
||
}
|
||
func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
|
||
if r.countErr != nil {
|
||
return 0, r.countErr
|
||
}
|
||
return r.countValue, nil
|
||
}
|
||
|
||
// ==================== 辅助函数 ====================
|
||
|
||
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
return &SoraClientHandler{genService: genService}
|
||
}
|
||
|
||
func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
|
||
rec := httptest.NewRecorder()
|
||
c, _ := gin.CreateTestContext(rec)
|
||
if body != "" {
|
||
c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
} else {
|
||
c.Request = httptest.NewRequest(method, path, nil)
|
||
}
|
||
if userID > 0 {
|
||
c.Set("user_id", userID)
|
||
}
|
||
return c, rec
|
||
}
|
||
|
||
func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
|
||
t.Helper()
|
||
var resp map[string]any
|
||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||
return resp
|
||
}
|
||
|
||
// ==================== 纯函数测试: buildAsyncRequestBody ====================
|
||
|
||
func TestBuildAsyncRequestBody(t *testing.T) {
|
||
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(body, &parsed))
|
||
require.Equal(t, "sora2-landscape-10s", parsed["model"])
|
||
require.Equal(t, false, parsed["stream"])
|
||
|
||
msgs := parsed["messages"].([]any)
|
||
require.Len(t, msgs, 1)
|
||
msg := msgs[0].(map[string]any)
|
||
require.Equal(t, "user", msg["role"])
|
||
require.Equal(t, "一只猫在跳舞", msg["content"])
|
||
}
|
||
|
||
func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
|
||
body := buildAsyncRequestBody("gpt-image", "", "", 1)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(body, &parsed))
|
||
require.Equal(t, "gpt-image", parsed["model"])
|
||
msgs := parsed["messages"].([]any)
|
||
msg := msgs[0].(map[string]any)
|
||
require.Equal(t, "", msg["content"])
|
||
}
|
||
|
||
func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
|
||
body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(body, &parsed))
|
||
require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
|
||
}
|
||
|
||
func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
|
||
body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(body, &parsed))
|
||
require.Equal(t, float64(3), parsed["video_count"])
|
||
}
|
||
|
||
func TestNormalizeVideoCount(t *testing.T) {
|
||
require.Equal(t, 1, normalizeVideoCount("video", 0))
|
||
require.Equal(t, 2, normalizeVideoCount("video", 2))
|
||
require.Equal(t, 3, normalizeVideoCount("video", 5))
|
||
require.Equal(t, 1, normalizeVideoCount("image", 3))
|
||
}
|
||
|
||
// ==================== 纯函数测试: parseMediaURLsFromBody ====================
|
||
|
||
func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
|
||
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
|
||
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
|
||
urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
|
||
require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody(nil))
|
||
require.Nil(t, parseMediaURLsFromBody([]byte{}))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
|
||
body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
|
||
urls := parseMediaURLsFromBody([]byte(body))
|
||
require.Len(t, urls, 2)
|
||
require.Equal(t, "https://multi.com/a.mp4", urls[0])
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
|
||
urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
|
||
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
|
||
// media_urls 不是 string 数组
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
|
||
}
|
||
|
||
func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
|
||
require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
|
||
}
|
||
|
||
// ==================== 纯函数测试: extractMediaURLsFromResult ====================
|
||
|
||
func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
|
||
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
|
||
recorder := httptest.NewRecorder()
|
||
url, urls := extractMediaURLsFromResult(result, recorder)
|
||
require.Equal(t, "https://oauth.com/video.mp4", url)
|
||
require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
|
||
}
|
||
|
||
func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
|
||
result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
|
||
recorder := httptest.NewRecorder()
|
||
_, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
|
||
url, urls := extractMediaURLsFromResult(result, recorder)
|
||
require.Equal(t, "https://body.com/1.mp4", url)
|
||
require.Len(t, urls, 2)
|
||
}
|
||
|
||
func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
|
||
recorder := httptest.NewRecorder()
|
||
_, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
|
||
url, urls := extractMediaURLsFromResult(nil, recorder)
|
||
require.Equal(t, "https://upstream.com/video.mp4", url)
|
||
require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
|
||
}
|
||
|
||
func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
|
||
recorder := httptest.NewRecorder()
|
||
url, urls := extractMediaURLsFromResult(nil, recorder)
|
||
require.Empty(t, url)
|
||
require.Nil(t, urls)
|
||
}
|
||
|
||
func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
|
||
result := &service.ForwardResult{MediaURL: ""}
|
||
recorder := httptest.NewRecorder()
|
||
url, urls := extractMediaURLsFromResult(result, recorder)
|
||
require.Empty(t, url)
|
||
require.Nil(t, urls)
|
||
}
|
||
|
||
// ==================== getUserIDFromContext ====================
|
||
|
||
func TestGetUserIDFromContext_Int64(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set("user_id", int64(42))
|
||
require.Equal(t, int64(42), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
|
||
require.Equal(t, int64(777), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_Float64(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set("user_id", float64(99))
|
||
require.Equal(t, int64(99), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_String(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set("user_id", "123")
|
||
require.Equal(t, int64(123), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set("userID", int64(55))
|
||
require.Equal(t, int64(55), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_NoID(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
require.Equal(t, int64(0), getUserIDFromContext(c))
|
||
}
|
||
|
||
func TestGetUserIDFromContext_InvalidString(t *testing.T) {
|
||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||
c.Set("user_id", "not-a-number")
|
||
require.Equal(t, int64(0), getUserIDFromContext(c))
|
||
}
|
||
|
||
// ==================== Handler: Generate ====================
|
||
|
||
func TestGenerate_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_BadRequest_MissingModel(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_TooManyRequests(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.countValue = 3
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_CountError(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.countErr = fmt.Errorf("db error")
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_Success(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.NotZero(t, data["generation_id"])
|
||
require.Equal(t, "pending", data["status"])
|
||
}
|
||
|
||
func TestGenerate_DefaultMediaType(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Equal(t, "video", repo.gens[1].MediaType)
|
||
}
|
||
|
||
func TestGenerate_ImageMediaType(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Equal(t, "image", repo.gens[1].MediaType)
|
||
}
|
||
|
||
func TestGenerate_CreatePendingError(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.createErr = fmt.Errorf("create failed")
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_APIKeyInContext(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
c.Set("api_key_id", int64(42))
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.NotNil(t, repo.gens[1].APIKeyID)
|
||
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_NoAPIKeyInContext(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Nil(t, repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_ConcurrencyBoundary(t *testing.T) {
|
||
// activeCount == 2 应该允许
|
||
repo := newStubSoraGenRepo()
|
||
repo.countValue = 2
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
// ==================== Handler: ListGenerations ====================
|
||
|
||
func TestListGenerations_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
|
||
h.ListGenerations(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestListGenerations_Success(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
|
||
repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
|
||
repo.nextID = 3
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
|
||
h.ListGenerations(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
items := data["data"].([]any)
|
||
require.Len(t, items, 2)
|
||
require.Equal(t, float64(2), data["total"])
|
||
}
|
||
|
||
func TestListGenerations_ListError(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.listErr = fmt.Errorf("db error")
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
|
||
h.ListGenerations(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
func TestListGenerations_DefaultPagination(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
h := newTestSoraClientHandler(repo)
|
||
// 不传分页参数,应默认 page=1 page_size=20
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
|
||
h.ListGenerations(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, float64(1), data["page"])
|
||
}
|
||
|
||
// ==================== Handler: GetGeneration ====================
|
||
|
||
func TestGetGeneration_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.GetGeneration(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestGetGeneration_InvalidID(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "abc"}}
|
||
h.GetGeneration(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestGetGeneration_NotFound(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "999"}}
|
||
h.GetGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestGetGeneration_WrongUser(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.GetGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestGetGeneration_Success(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.GetGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, float64(1), data["id"])
|
||
}
|
||
|
||
// ==================== Handler: DeleteGeneration ====================
|
||
|
||
func TestDeleteGeneration_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_InvalidID(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "abc"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_NotFound(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "999"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_WrongUser(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_Success(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
_, exists := repo.gens[1]
|
||
require.False(t, exists)
|
||
}
|
||
|
||
// ==================== Handler: CancelGeneration ====================
|
||
|
||
func TestCancelGeneration_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_InvalidID(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "abc"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_NotFound(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "999"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_WrongUser(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_Pending(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Equal(t, "cancelled", repo.gens[1].Status)
|
||
}
|
||
|
||
func TestCancelGeneration_Generating(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Equal(t, "cancelled", repo.gens[1].Status)
|
||
}
|
||
|
||
func TestCancelGeneration_Completed(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusConflict, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_Failed(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusConflict, rec.Code)
|
||
}
|
||
|
||
func TestCancelGeneration_Cancelled(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusConflict, rec.Code)
|
||
}
|
||
|
||
// ==================== Handler: GetQuota ====================
|
||
|
||
func TestGetQuota_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
|
||
h.GetQuota(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestGetQuota_NilQuotaService(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
|
||
h.GetQuota(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, "unlimited", data["source"])
|
||
}
|
||
|
||
// ==================== Handler: GetModels ====================
|
||
|
||
func TestGetModels(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
|
||
h.GetModels(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].([]any)
|
||
require.Len(t, data, 4)
|
||
// 验证类型分布
|
||
videoCount, imageCount := 0, 0
|
||
for _, item := range data {
|
||
m := item.(map[string]any)
|
||
if m["type"] == "video" {
|
||
videoCount++
|
||
} else if m["type"] == "image" {
|
||
imageCount++
|
||
}
|
||
}
|
||
require.Equal(t, 3, videoCount)
|
||
require.Equal(t, 1, imageCount)
|
||
}
|
||
|
||
// ==================== Handler: GetStorageStatus ====================
|
||
|
||
func TestGetStorageStatus_NilS3(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||
h.GetStorageStatus(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, false, data["s3_enabled"])
|
||
require.Equal(t, false, data["s3_healthy"])
|
||
require.Equal(t, false, data["local_enabled"])
|
||
}
|
||
|
||
func TestGetStorageStatus_LocalEnabled(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
h := &SoraClientHandler{mediaStorage: mediaStorage}
|
||
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||
h.GetStorageStatus(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, false, data["s3_enabled"])
|
||
require.Equal(t, false, data["s3_healthy"])
|
||
require.Equal(t, true, data["local_enabled"])
|
||
}
|
||
|
||
// ==================== Handler: SaveToStorage ====================
|
||
|
||
func TestSaveToStorage_Unauthorized(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||
}
|
||
|
||
func TestSaveToStorage_InvalidID(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "abc"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestSaveToStorage_NotFound(t *testing.T) {
|
||
h := newTestSoraClientHandler(newStubSoraGenRepo())
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "999"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
func TestSaveToStorage_NotUpstream(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
func TestSaveToStorage_S3Nil(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
|
||
}
|
||
|
||
func TestSaveToStorage_WrongUser(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
|
||
h := newTestSoraClientHandler(repo)
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
// ==================== storeMediaWithDegradation — nil guard 路径 ====================
|
||
|
||
func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
|
||
h := &SoraClientHandler{}
|
||
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
require.Equal(t, "https://upstream.com/v.mp4", url)
|
||
require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
|
||
require.Nil(t, keys)
|
||
require.Equal(t, int64(0), size)
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
|
||
h := &SoraClientHandler{}
|
||
url, urls, storageType, keys, size := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
require.Equal(t, "https://a.com/1.mp4", url)
|
||
require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
|
||
require.Nil(t, keys)
|
||
require.Equal(t, int64(0), size)
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
|
||
h := &SoraClientHandler{}
|
||
url, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
require.Equal(t, "https://upstream.com/v.mp4", url)
|
||
}
|
||
|
||
// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
|
||
|
||
var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
|
||
|
||
type stubUserRepoForHandler struct {
|
||
users map[int64]*service.User
|
||
updateErr error
|
||
}
|
||
|
||
func newStubUserRepoForHandler() *stubUserRepoForHandler {
|
||
return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
|
||
}
|
||
|
||
func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
|
||
if u, ok := r.users[id]; ok {
|
||
return u, nil
|
||
}
|
||
return nil, fmt.Errorf("user not found")
|
||
}
|
||
func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
|
||
if r.updateErr != nil {
|
||
return r.updateErr
|
||
}
|
||
r.users[user.ID] = user
|
||
return nil
|
||
}
|
||
func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
|
||
func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||
func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||
func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
|
||
func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||
func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
|
||
return false, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
|
||
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
|
||
|
||
// ==================== NewSoraClientHandler ====================
|
||
|
||
func TestNewSoraClientHandler(t *testing.T) {
|
||
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
|
||
require.NotNil(t, h)
|
||
}
|
||
|
||
func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
|
||
h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
|
||
require.NotNil(t, h)
|
||
require.Nil(t, h.apiKeyService)
|
||
}
|
||
|
||
// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
|
||
|
||
var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
|
||
|
||
type stubAPIKeyRepoForHandler struct {
|
||
keys map[int64]*service.APIKey
|
||
getErr error
|
||
}
|
||
|
||
func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
|
||
return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
|
||
}
|
||
|
||
func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
|
||
if r.getErr != nil {
|
||
return nil, r.getErr
|
||
}
|
||
if k, ok := r.keys[id]; ok {
|
||
return k, nil
|
||
}
|
||
return nil, fmt.Errorf("api key not found: %d", id)
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
|
||
func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
|
||
return "", 0, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
|
||
func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||
func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
|
||
return false, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||
return nil
|
||
}
|
||
|
||
// newTestAPIKeyService 创建测试用的 APIKeyService
|
||
func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
|
||
return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
|
||
}
|
||
|
||
// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
|
||
|
||
func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
|
||
// 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
groupID := int64(5)
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyActive,
|
||
GroupID: &groupID,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.NotZero(t, data["generation_id"])
|
||
|
||
// 验证 api_key_id 已关联到生成记录
|
||
gen := repo.gens[1]
|
||
require.NotNil(t, gen.APIKeyID)
|
||
require.Equal(t, int64(42), *gen.APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
|
||
// 前端传递不存在的 api_key_id → 400
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
|
||
// 前端传递别人的 api_key_id → 403
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 999, // 属于 user 999
|
||
Status: service.StatusAPIKeyActive,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
|
||
// 前端传递已禁用的 api_key_id → 403
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyDisabled,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
|
||
// 前端传递配额耗尽的 api_key_id → 403
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyQuotaExhausted,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
|
||
// 前端传递已过期的 api_key_id → 403
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyExpired,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
|
||
// apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
h := &SoraClientHandler{genService: genService} // apiKeyService = nil
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
// apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
|
||
require.Nil(t, repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
|
||
// api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyActive,
|
||
GroupID: nil, // 无分组
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.NotNil(t, repo.gens[1].APIKeyID)
|
||
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
|
||
// 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
require.Nil(t, repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
|
||
// 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
groupID := int64(10)
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyRepo.keys[42] = &service.APIKey{
|
||
ID: 42,
|
||
UserID: 1,
|
||
Status: service.StatusAPIKeyActive,
|
||
GroupID: &groupID,
|
||
}
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
|
||
c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
// 应使用 body 中的 api_key_id=42,而不是 context 中的 99
|
||
require.NotNil(t, repo.gens[1].APIKeyID)
|
||
require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
|
||
// 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
c.Set("api_key_id", int64(99))
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
// 应使用 context 中的 api_key_id=99
|
||
require.NotNil(t, repo.gens[1].APIKeyID)
|
||
require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
|
||
}
|
||
|
||
func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
|
||
// JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
apiKeyRepo := newStubAPIKeyRepoForHandler()
|
||
apiKeyService := newTestAPIKeyService(apiKeyRepo)
|
||
|
||
h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
|
||
// JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
|
||
// api_key_id=0 不存在 → 400
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate",
|
||
`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
}
|
||
|
||
// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
|
||
|
||
func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
|
||
// groupID 不为 nil → 不设置 ForcePlatform
|
||
// gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
gid := int64(5)
|
||
h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
|
||
}
|
||
|
||
func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
|
||
// groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
|
||
}
|
||
|
||
func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
|
||
// 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
// 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
|
||
require.Equal(t, "cancelled", repo.gens[1].Status)
|
||
}
|
||
|
||
// ==================== GenerateRequest JSON 解析 ====================
|
||
|
||
func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
|
||
// 验证 api_key_id 在 JSON 中正确解析为 *int64
|
||
var req GenerateRequest
|
||
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, req.APIKeyID)
|
||
require.Equal(t, int64(42), *req.APIKeyID)
|
||
}
|
||
|
||
func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
|
||
// 不传 api_key_id → 解析后为 nil
|
||
var req GenerateRequest
|
||
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
|
||
require.NoError(t, err)
|
||
require.Nil(t, req.APIKeyID)
|
||
}
|
||
|
||
func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
|
||
// api_key_id: null → 解析后为 nil
|
||
var req GenerateRequest
|
||
err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
|
||
require.NoError(t, err)
|
||
require.Nil(t, req.APIKeyID)
|
||
}
|
||
|
||
func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
|
||
// 全字段解析
|
||
var req GenerateRequest
|
||
err := json.Unmarshal([]byte(`{
|
||
"model":"sora2-landscape-10s",
|
||
"prompt":"test prompt",
|
||
"media_type":"video",
|
||
"video_count":2,
|
||
"image_input":"data:image/png;base64,abc",
|
||
"api_key_id":100
|
||
}`), &req)
|
||
require.NoError(t, err)
|
||
require.Equal(t, "sora2-landscape-10s", req.Model)
|
||
require.Equal(t, "test prompt", req.Prompt)
|
||
require.Equal(t, "video", req.MediaType)
|
||
require.Equal(t, 2, req.VideoCount)
|
||
require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
|
||
require.NotNil(t, req.APIKeyID)
|
||
require.Equal(t, int64(100), *req.APIKeyID)
|
||
}
|
||
|
||
func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
|
||
// api_key_id 为 nil 时 JSON 序列化应省略
|
||
req := GenerateRequest{Model: "sora2", Prompt: "test"}
|
||
b, err := json.Marshal(req)
|
||
require.NoError(t, err)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||
_, hasAPIKeyID := parsed["api_key_id"]
|
||
require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
|
||
}
|
||
|
||
func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
|
||
// api_key_id 不为 nil 时 JSON 序列化应包含
|
||
id := int64(42)
|
||
req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
|
||
b, err := json.Marshal(req)
|
||
require.NoError(t, err)
|
||
var parsed map[string]any
|
||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||
require.Equal(t, float64(42), parsed["api_key_id"])
|
||
}
|
||
|
||
// ==================== GetQuota: 有配额服务 ====================
|
||
|
||
func TestGetQuota_WithQuotaService_Success(t *testing.T) {
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 10 * 1024 * 1024,
|
||
SoraStorageUsedBytes: 3 * 1024 * 1024,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
quotaService: quotaService,
|
||
}
|
||
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
|
||
h.GetQuota(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, "user", data["source"])
|
||
require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
|
||
require.Equal(t, float64(3*1024*1024), data["used_bytes"])
|
||
}
|
||
|
||
func TestGetQuota_WithQuotaService_Error(t *testing.T) {
|
||
// 用户不存在时 GetQuota 返回错误
|
||
userRepo := newStubUserRepoForHandler()
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
quotaService: quotaService,
|
||
}
|
||
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
|
||
h.GetQuota(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
// ==================== Generate: 配额检查 ====================
|
||
|
||
func TestGenerate_QuotaCheckFailed(t *testing.T) {
|
||
// 配额超限时返回 429
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 1024,
|
||
SoraStorageUsedBytes: 1025, // 已超限
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
quotaService: quotaService,
|
||
}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||
}
|
||
|
||
func TestGenerate_QuotaCheckPassed(t *testing.T) {
|
||
// 配额充足时允许生成
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 10 * 1024 * 1024,
|
||
SoraStorageUsedBytes: 0,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
quotaService: quotaService,
|
||
}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
|
||
|
||
var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
|
||
|
||
type stubSettingRepoForHandler struct {
|
||
values map[string]string
|
||
}
|
||
|
||
func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
|
||
if values == nil {
|
||
values = make(map[string]string)
|
||
}
|
||
return &stubSettingRepoForHandler{values: values}
|
||
}
|
||
|
||
func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
|
||
if v, ok := r.values[key]; ok {
|
||
return &service.Setting{Key: key, Value: v}, nil
|
||
}
|
||
return nil, service.ErrSettingNotFound
|
||
}
|
||
func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
|
||
if v, ok := r.values[key]; ok {
|
||
return v, nil
|
||
}
|
||
return "", service.ErrSettingNotFound
|
||
}
|
||
func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
|
||
r.values[key] = value
|
||
return nil
|
||
}
|
||
func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||
result := make(map[string]string)
|
||
for _, k := range keys {
|
||
if v, ok := r.values[k]; ok {
|
||
result[k] = v
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
|
||
for k, v := range settings {
|
||
r.values[k] = v
|
||
}
|
||
return nil
|
||
}
|
||
func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
|
||
return r.values, nil
|
||
}
|
||
func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
|
||
delete(r.values, key)
|
||
return nil
|
||
}
|
||
|
||
// ==================== S3 / MediaStorage 辅助函数 ====================
|
||
|
||
// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
|
||
func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
|
||
settingRepo := newStubSettingRepoForHandler(map[string]string{
|
||
"sora_s3_enabled": "true",
|
||
"sora_s3_endpoint": endpoint,
|
||
"sora_s3_region": "us-east-1",
|
||
"sora_s3_bucket": "test-bucket",
|
||
"sora_s3_access_key_id": "AKIATEST",
|
||
"sora_s3_secret_access_key": "test-secret",
|
||
"sora_s3_prefix": "sora",
|
||
"sora_s3_force_path_style": "true",
|
||
})
|
||
settingService := service.NewSettingService(settingRepo, &config.Config{})
|
||
return service.NewSoraS3Storage(settingService)
|
||
}
|
||
|
||
// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
|
||
func newFakeSourceServer() *httptest.Server {
|
||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "video/mp4")
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte("fake video data for test"))
|
||
}))
|
||
}
|
||
|
||
// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
|
||
// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
|
||
func newFakeS3Server(mode string) *httptest.Server {
|
||
var counter atomic.Int32
|
||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
_, _ = io.Copy(io.Discard, r.Body)
|
||
_ = r.Body.Close()
|
||
|
||
switch mode {
|
||
case "ok":
|
||
w.Header().Set("ETag", `"test-etag"`)
|
||
w.WriteHeader(http.StatusOK)
|
||
case "fail":
|
||
w.WriteHeader(http.StatusForbidden)
|
||
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
|
||
case "fail-second":
|
||
n := counter.Add(1)
|
||
if n <= 1 {
|
||
w.Header().Set("ETag", `"test-etag"`)
|
||
w.WriteHeader(http.StatusOK)
|
||
} else {
|
||
w.WriteHeader(http.StatusForbidden)
|
||
_, _ = w.Write([]byte(`<?xml version="1.0"?><Error><Code>AccessDenied</Code></Error>`))
|
||
}
|
||
}
|
||
}))
|
||
}
|
||
|
||
// ==================== processGeneration 直接调用测试 ====================
|
||
|
||
func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
repo.updateErr = fmt.Errorf("db error")
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
// 直接调用(非 goroutine),MarkGenerating 失败 → 早退
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
// MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
|
||
// repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
|
||
// 因此 ErrorMessage 为空(证明未调用 MarkFailed)
|
||
require.Equal(t, "generating", repo.gens[1].Status)
|
||
require.Empty(t, repo.gens[1].ErrorMessage)
|
||
}
|
||
|
||
func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
// gatewayService 未设置 → MarkFailed
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
|
||
}
|
||
|
||
// ==================== storeMediaWithDegradation: S3 路径 ====================
|
||
|
||
func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeS3, storageType)
|
||
require.Len(t, s3Keys, 1)
|
||
require.NotEmpty(t, s3Keys[0])
|
||
require.Len(t, storedURLs, 1)
|
||
require.Equal(t, storedURL, storedURLs[0])
|
||
require.Contains(t, storedURL, fakeS3.URL)
|
||
require.Contains(t, storedURL, "/test-bucket/")
|
||
require.Greater(t, fileSize, int64(0))
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeS3, storageType)
|
||
require.Len(t, s3Keys, 2)
|
||
require.Len(t, storedURLs, 2)
|
||
require.Equal(t, storedURL, storedURLs[0])
|
||
require.Contains(t, storedURLs[0], fakeS3.URL)
|
||
require.Contains(t, storedURLs[1], fakeS3.URL)
|
||
require.Greater(t, fileSize, int64(0))
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
|
||
// 上游返回 404 → 下载失败 → S3 上传不会开始
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusNotFound)
|
||
}))
|
||
defer badSource.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("fail")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||
)
|
||
// S3 失败,降级到 upstream
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
require.Nil(t, s3Keys)
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("fail-second")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
|
||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
|
||
)
|
||
// 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
require.Nil(t, s3Keys)
|
||
}
|
||
|
||
// ==================== storeMediaWithDegradation: 本地存储路径 ====================
|
||
|
||
func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
|
||
// 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: "/dev/null/invalid_dir",
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
h := &SoraClientHandler{mediaStorage: mediaStorage}
|
||
|
||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
|
||
)
|
||
// 本地存储失败,降级到 upstream
|
||
require.Equal(t, service.SoraStorageTypeUpstream, storageType)
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
DownloadTimeoutSeconds: 5,
|
||
MaxDownloadBytes: 10 * 1024 * 1024,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
h := &SoraClientHandler{mediaStorage: mediaStorage}
|
||
|
||
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||
)
|
||
require.Equal(t, service.SoraStorageTypeLocal, storageType)
|
||
require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
|
||
}
|
||
|
||
func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("fail")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
DownloadTimeoutSeconds: 5,
|
||
MaxDownloadBytes: 10 * 1024 * 1024,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
h := &SoraClientHandler{
|
||
s3Storage: s3Storage,
|
||
mediaStorage: mediaStorage,
|
||
}
|
||
|
||
_, _, storageType, _, _ := h.storeMediaWithDegradation(
|
||
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
|
||
)
|
||
// S3 失败 → 本地存储成功
|
||
require.Equal(t, service.SoraStorageTypeLocal, storageType)
|
||
}
|
||
|
||
// ==================== SaveToStorage: S3 路径 ====================
|
||
|
||
func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("fail")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, resp["message"], "S3")
|
||
}
|
||
|
||
func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
|
||
expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||
w.WriteHeader(http.StatusForbidden)
|
||
}))
|
||
defer expiredServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: expiredServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusGone, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, fmt.Sprint(resp["message"]), "过期")
|
||
}
|
||
|
||
func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Contains(t, data["message"], "S3")
|
||
require.NotEmpty(t, data["object_key"])
|
||
// 验证记录已更新为 S3 存储
|
||
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
||
}
|
||
|
||
func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v1.mp4",
|
||
MediaURLs: []string{
|
||
sourceServer.URL + "/v1.mp4",
|
||
sourceServer.URL + "/v2.mp4",
|
||
},
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Len(t, data["object_keys"].([]any), 2)
|
||
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
||
require.Len(t, repo.gens[1].S3ObjectKeys, 2)
|
||
require.Len(t, repo.gens[1].MediaURLs, 2)
|
||
}
|
||
|
||
func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 100 * 1024 * 1024,
|
||
SoraStorageUsedBytes: 0,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
// 验证配额已累加
|
||
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
|
||
}
|
||
|
||
func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
|
||
repo.updateErr = fmt.Errorf("db error")
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
// ==================== GetStorageStatus: S3 路径 ====================
|
||
|
||
func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
|
||
// S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
|
||
fakeS3 := newFakeS3Server("fail")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||
h.GetStorageStatus(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, true, data["s3_enabled"])
|
||
require.Equal(t, false, data["s3_healthy"])
|
||
}
|
||
|
||
func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
|
||
h.GetStorageStatus(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
data := resp["data"].(map[string]any)
|
||
require.Equal(t, true, data["s3_enabled"])
|
||
require.Equal(t, true, data["s3_healthy"])
|
||
}
|
||
|
||
// ==================== Stub: AccountRepository (用于 GatewayService) ====================
|
||
|
||
var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
|
||
|
||
type stubAccountRepoForHandler struct {
|
||
accounts []service.Account
|
||
}
|
||
|
||
func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
|
||
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
|
||
for i := range r.accounts {
|
||
if r.accounts[i].ID == id {
|
||
return &r.accounts[i], nil
|
||
}
|
||
}
|
||
return nil, fmt.Errorf("account not found")
|
||
}
|
||
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
|
||
return false, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
|
||
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||
return nil, nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
|
||
return nil, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
|
||
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
|
||
return r.accounts, nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
|
||
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
|
||
return nil
|
||
}
|
||
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
|
||
return 0, nil
|
||
}
|
||
|
||
// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
|
||
|
||
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
|
||
|
||
type stubSoraClientForHandler struct {
|
||
videoStatus *service.SoraVideoTaskStatus
|
||
}
|
||
|
||
func (s *stubSoraClientForHandler) Enabled() bool { return true }
|
||
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
|
||
return "task-image", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
|
||
return "task-video", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
|
||
return "task-video", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
|
||
return nil, nil
|
||
}
|
||
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
|
||
return nil, nil
|
||
}
|
||
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
|
||
return nil
|
||
}
|
||
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
|
||
return nil
|
||
}
|
||
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
|
||
return nil
|
||
}
|
||
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
|
||
return "", nil
|
||
}
|
||
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
|
||
return nil, nil
|
||
}
|
||
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
|
||
return s.videoStatus, nil
|
||
}
|
||
|
||
// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
|
||
|
||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||
return service.NewGatewayService(
|
||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||
)
|
||
}
|
||
|
||
// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
|
||
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Client: config.SoraClientConfig{
|
||
PollIntervalSeconds: 1,
|
||
MaxPollAttempts: 1,
|
||
},
|
||
},
|
||
}
|
||
return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||
}
|
||
|
||
// ==================== processGeneration: 更多路径测试 ====================
|
||
|
||
func TestProcessGeneration_SelectAccountError(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
// accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
|
||
accountRepo := &stubAccountRepoForHandler{accounts: nil}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
|
||
}
|
||
|
||
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
// 提供可用账号使 SelectAccountForModel 成功
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
// soraGatewayService 为 nil
|
||
h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
|
||
}
|
||
|
||
func TestProcessGeneration_ForwardError(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
// SoraClient 返回视频任务失败
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "failed",
|
||
ErrorMsg: "content policy violation",
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
|
||
}
|
||
|
||
func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
// MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
|
||
// 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
|
||
repo.getByIDOverrideAfterN = 1
|
||
repo.getByIDOverrideStatus = "cancelled"
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
// Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
|
||
require.Equal(t, "generating", repo.gens[1].Status)
|
||
}
|
||
|
||
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
// SoraClient 返回 completed 但无 URL
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "completed",
|
||
URLs: nil, // 无 URL
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
|
||
}
|
||
|
||
func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
// MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
|
||
// 第 2 次返回 "cancelled" 状态,模拟外部取消
|
||
repo.getByIDOverrideAfterN = 1
|
||
repo.getByIDOverrideStatus = "cancelled"
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "completed",
|
||
URLs: []string{"https://example.com/video.mp4"},
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
// Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
|
||
require.Equal(t, "generating", repo.gens[1].Status)
|
||
}
|
||
|
||
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "completed",
|
||
URLs: []string{"https://example.com/video.mp4"},
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
// 无 S3 和本地存储,降级到 upstream
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
|
||
require.Equal(t, "completed", repo.gens[1].Status)
|
||
require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
|
||
require.NotEmpty(t, repo.gens[1].MediaURL)
|
||
}
|
||
|
||
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "completed",
|
||
URLs: []string{sourceServer.URL + "/video.mp4"},
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
s3Storage: s3Storage,
|
||
quotaService: quotaService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
|
||
require.Equal(t, "completed", repo.gens[1].Status)
|
||
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
|
||
require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
|
||
require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
|
||
// 验证配额已累加
|
||
require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
|
||
}
|
||
|
||
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
|
||
repo.updateCallCount = new(int32)
|
||
repo.updateFailAfterN = 1
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
soraClient := &stubSoraClientForHandler{
|
||
videoStatus: &service.SoraVideoTaskStatus{
|
||
Status: "completed",
|
||
URLs: []string{"https://example.com/video.mp4"},
|
||
},
|
||
}
|
||
soraGatewayService := newMinimalSoraGatewayService(soraClient)
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
soraGatewayService: soraGatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
|
||
// MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
|
||
// 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
|
||
// 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
|
||
require.Equal(t, "completed", repo.gens[1].Status)
|
||
}
|
||
|
||
// ==================== cleanupStoredMedia 直接测试 ====================
|
||
|
||
func TestCleanupStoredMedia_S3Path(t *testing.T) {
|
||
// S3 清理路径:s3Storage 为 nil 时不 panic
|
||
h := &SoraClientHandler{}
|
||
// 不应 panic
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||
}
|
||
|
||
func TestCleanupStoredMedia_LocalPath(t *testing.T) {
|
||
// 本地清理路径:mediaStorage 为 nil 时不 panic
|
||
h := &SoraClientHandler{}
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
|
||
}
|
||
|
||
func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
|
||
// upstream 类型不清理
|
||
h := &SoraClientHandler{}
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
|
||
}
|
||
|
||
func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
|
||
// 空 keys 不触发清理
|
||
h := &SoraClientHandler{}
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
|
||
}
|
||
|
||
// ==================== DeleteGeneration: 本地存储清理路径 ====================
|
||
|
||
func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1,
|
||
UserID: 1,
|
||
Status: "completed",
|
||
StorageType: service.SoraStorageTypeLocal,
|
||
MediaURL: "video/test.mp4",
|
||
MediaURLs: []string{"video/test.mp4"},
|
||
}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
|
||
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
_, exists := repo.gens[1]
|
||
require.False(t, exists)
|
||
}
|
||
|
||
func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
|
||
// MediaURLs 为空,使用 MediaURL 作为清理路径
|
||
tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1,
|
||
UserID: 1,
|
||
Status: "completed",
|
||
StorageType: service.SoraStorageTypeLocal,
|
||
MediaURL: "video/test.mp4",
|
||
MediaURLs: nil, // 空
|
||
}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
|
||
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
|
||
// 非本地存储类型 → 跳过清理
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1,
|
||
UserID: 1,
|
||
Status: "completed",
|
||
StorageType: service.SoraStorageTypeUpstream,
|
||
MediaURL: "https://upstream.com/v.mp4",
|
||
}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
func TestDeleteGeneration_DeleteError(t *testing.T) {
|
||
// repo.Delete 出错
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
|
||
repo.deleteErr = fmt.Errorf("delete failed")
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||
}
|
||
|
||
// ==================== fetchUpstreamModels 测试 ====================
|
||
|
||
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
h := &SoraClientHandler{}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "gatewayService 未初始化")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
accountRepo := &stubAccountRepoForHandler{accounts: nil}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "选择 Sora 账号失败")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "不支持模型同步")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"base_url": "https://sora.test"}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "api_key")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
|
||
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test"}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
}
|
||
|
||
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "状态码 500")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte("not json"))
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "解析响应失败")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte(`{"data":[]}`))
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "空模型列表")
|
||
}
|
||
|
||
func TestFetchUpstreamModels_Success(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 验证请求头
|
||
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
|
||
require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
families, err := h.fetchUpstreamModels(context.Background())
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, families)
|
||
}
|
||
|
||
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
_, err := h.fetchUpstreamModels(context.Background())
|
||
require.Error(t, err)
|
||
require.Contains(t, err.Error(), "未能从上游模型列表中识别")
|
||
}
|
||
|
||
// ==================== getModelFamilies 缓存测试 ====================
|
||
|
||
func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
|
||
// gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
|
||
h := &SoraClientHandler{}
|
||
families := h.getModelFamilies(context.Background())
|
||
require.NotEmpty(t, families)
|
||
|
||
// 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
|
||
families2 := h.getModelFamilies(context.Background())
|
||
require.Equal(t, families, families2)
|
||
require.False(t, h.modelCacheUpstream)
|
||
}
|
||
|
||
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
|
||
t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
|
||
}))
|
||
defer ts.Close()
|
||
|
||
accountRepo := &stubAccountRepoForHandler{
|
||
accounts: []service.Account{
|
||
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
|
||
Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
|
||
},
|
||
}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
h := &SoraClientHandler{gatewayService: gatewayService}
|
||
|
||
families := h.getModelFamilies(context.Background())
|
||
require.NotEmpty(t, families)
|
||
require.True(t, h.modelCacheUpstream)
|
||
|
||
// 第二次调用命中缓存
|
||
families2 := h.getModelFamilies(context.Background())
|
||
require.Equal(t, families, families2)
|
||
}
|
||
|
||
func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
|
||
// 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
|
||
h := &SoraClientHandler{
|
||
cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
|
||
modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
|
||
modelCacheUpstream: false,
|
||
}
|
||
// gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
|
||
families := h.getModelFamilies(context.Background())
|
||
require.NotEmpty(t, families)
|
||
// 缓存已刷新,不再是 "old"
|
||
found := false
|
||
for _, f := range families {
|
||
if f.ID == "old" {
|
||
found = true
|
||
}
|
||
}
|
||
require.False(t, found, "过期缓存应被刷新")
|
||
}
|
||
|
||
// ==================== processGeneration: groupID 与 ForcePlatform ====================
|
||
|
||
func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
|
||
// groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
// 空账号列表 → SelectAccountForModel 失败
|
||
accountRepo := &stubAccountRepoForHandler{accounts: nil}
|
||
gatewayService := newMinimalGatewayService(accountRepo)
|
||
|
||
h := &SoraClientHandler{
|
||
genService: genService,
|
||
gatewayService: gatewayService,
|
||
}
|
||
|
||
h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
|
||
require.Equal(t, "failed", repo.gens[1].Status)
|
||
require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
|
||
}
|
||
|
||
// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
|
||
|
||
func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
|
||
// quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
|
||
repo := newStubSoraGenRepo()
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
// 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
|
||
userRepo := newStubUserRepoForHandler()
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
|
||
h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
|
||
|
||
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||
}
|
||
|
||
// ==================== Generate: CreatePending 并发限制错误 ====================
|
||
|
||
// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
|
||
type stubSoraGenRepoWithAtomicCreate struct {
|
||
stubSoraGenRepo
|
||
limitErr error
|
||
}
|
||
|
||
func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
|
||
if r.limitErr != nil {
|
||
return r.limitErr
|
||
}
|
||
return r.stubSoraGenRepo.Create(context.Background(), gen)
|
||
}
|
||
|
||
func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
|
||
repo := &stubSoraGenRepoWithAtomicCreate{
|
||
stubSoraGenRepo: *newStubSoraGenRepo(),
|
||
limitErr: service.ErrSoraGenerationConcurrencyLimit,
|
||
}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
|
||
|
||
body := `{"model":"sora2-landscape-10s","prompt":"test"}`
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
|
||
h.Generate(c)
|
||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, resp["message"], "3")
|
||
}
|
||
|
||
// ==================== SaveToStorage: 配额超限 ====================
|
||
|
||
func TestSaveToStorage_QuotaExceeded(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
// 用户配额已满
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 10,
|
||
SoraStorageUsedBytes: 10,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
||
}
|
||
|
||
// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
|
||
|
||
func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
|
||
userRepo := newStubUserRepoForHandler()
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
// ==================== SaveToStorage: MediaURLs 全为空 ====================
|
||
|
||
func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: "",
|
||
MediaURLs: []string{},
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||
resp := parseResponse(t, rec)
|
||
require.Contains(t, resp["message"], "已过期")
|
||
}
|
||
|
||
// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
|
||
|
||
func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("fail-second")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v1.mp4",
|
||
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
|
||
}
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
|
||
|
||
func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
|
||
sourceServer := newFakeSourceServer()
|
||
defer sourceServer.Close()
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: "upstream",
|
||
MediaURL: sourceServer.URL + "/v.mp4",
|
||
}
|
||
repo.updateErr = fmt.Errorf("db error")
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
|
||
userRepo := newStubUserRepoForHandler()
|
||
userRepo.users[1] = &service.User{
|
||
ID: 1,
|
||
SoraStorageQuotaBytes: 100 * 1024 * 1024,
|
||
SoraStorageUsedBytes: 0,
|
||
}
|
||
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.SaveToStorage(c)
|
||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||
}
|
||
|
||
// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
|
||
|
||
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
|
||
fakeS3 := newFakeS3Server("ok")
|
||
defer fakeS3.Close()
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
|
||
}
|
||
|
||
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
|
||
fakeS3 := newFakeS3Server("fail")
|
||
defer fakeS3.Close()
|
||
s3Storage := newS3StorageForHandler(fakeS3.URL)
|
||
h := &SoraClientHandler{s3Storage: s3Storage}
|
||
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
|
||
}
|
||
|
||
func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
h := &SoraClientHandler{mediaStorage: mediaStorage}
|
||
|
||
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
|
||
}
|
||
|
||
// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
|
||
|
||
func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
|
||
tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(tmpDir)
|
||
|
||
cfg := &config.Config{
|
||
Sora: config.SoraConfig{
|
||
Storage: config.SoraStorageConfig{
|
||
Type: "local",
|
||
LocalPath: tmpDir,
|
||
},
|
||
},
|
||
}
|
||
mediaStorage := service.NewSoraMediaStorage(cfg)
|
||
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{
|
||
ID: 1, UserID: 1, Status: "completed",
|
||
StorageType: service.SoraStorageTypeLocal,
|
||
MediaURL: "nonexistent/video.mp4",
|
||
MediaURLs: []string{"nonexistent/video.mp4"},
|
||
}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
|
||
|
||
c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.DeleteGeneration(c)
|
||
require.Equal(t, http.StatusOK, rec.Code)
|
||
}
|
||
|
||
// ==================== CancelGeneration: 任务已结束冲突 ====================
|
||
|
||
func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
|
||
repo := newStubSoraGenRepo()
|
||
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
|
||
genService := service.NewSoraGenerationService(repo, nil, nil)
|
||
h := &SoraClientHandler{genService: genService}
|
||
|
||
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
|
||
c.Params = gin.Params{{Key: "id", Value: "1"}}
|
||
h.CancelGeneration(c)
|
||
require.Equal(t, http.StatusConflict, rec.Code)
|
||
}
|