将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入: - unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x - unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取 - ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取 - Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量 - Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配 - 新增约 100 个单元测试用例,所有改动函数覆盖率 >85% Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
516 lines
14 KiB
Go
516 lines
14 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
|
||
|
||
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
|
||
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
|
||
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
|
||
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||
if id == "" {
|
||
return "", assert.AnError // 占位错误,表示 "missing id"
|
||
}
|
||
return id, nil
|
||
}
|
||
|
||
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
|
||
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
|
||
var found *SoraImageTaskStatus
|
||
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
|
||
if item.Get("id").String() != taskID {
|
||
return true // continue
|
||
}
|
||
status := strings.TrimSpace(item.Get("status").String())
|
||
progress := item.Get("progress_pct").Float()
|
||
var urls []string
|
||
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
|
||
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
|
||
urls = append(urls, u)
|
||
}
|
||
return true
|
||
})
|
||
found = &SoraImageTaskStatus{
|
||
ID: taskID,
|
||
Status: status,
|
||
ProgressPct: progress,
|
||
URLs: urls,
|
||
}
|
||
return false // break
|
||
})
|
||
if found != nil {
|
||
return found, true
|
||
}
|
||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
|
||
}
|
||
|
||
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
|
||
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||
pendingResult := gjson.ParseBytes(respBody)
|
||
if !pendingResult.IsArray() {
|
||
return nil, false
|
||
}
|
||
var pendingFound *SoraVideoTaskStatus
|
||
pendingResult.ForEach(func(_, task gjson.Result) bool {
|
||
if task.Get("id").String() != taskID {
|
||
return true
|
||
}
|
||
progress := 0
|
||
if v := task.Get("progress_pct"); v.Exists() {
|
||
progress = int(v.Float() * 100)
|
||
}
|
||
status := strings.TrimSpace(task.Get("status").String())
|
||
pendingFound = &SoraVideoTaskStatus{
|
||
ID: taskID,
|
||
Status: status,
|
||
ProgressPct: progress,
|
||
}
|
||
return false
|
||
})
|
||
if pendingFound != nil {
|
||
return pendingFound, true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
|
||
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||
var draftFound *SoraVideoTaskStatus
|
||
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
|
||
if draft.Get("task_id").String() != taskID {
|
||
return true
|
||
}
|
||
kind := strings.TrimSpace(draft.Get("kind").String())
|
||
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||
if reason == "" {
|
||
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
|
||
}
|
||
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
|
||
if urlStr == "" {
|
||
urlStr = strings.TrimSpace(draft.Get("url").String())
|
||
}
|
||
|
||
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||
msg := reason
|
||
if msg == "" {
|
||
msg = "Content violates guardrails"
|
||
}
|
||
draftFound = &SoraVideoTaskStatus{
|
||
ID: taskID,
|
||
Status: "failed",
|
||
ErrorMsg: msg,
|
||
}
|
||
} else {
|
||
draftFound = &SoraVideoTaskStatus{
|
||
ID: taskID,
|
||
Status: "completed",
|
||
URLs: []string{urlStr},
|
||
}
|
||
}
|
||
return false
|
||
})
|
||
if draftFound != nil {
|
||
return draftFound, true
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// ===================== Test 1: TestSoraParseUploadResponse =====================
|
||
|
||
func TestSoraParseUploadResponse(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
wantID string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "正常 id",
|
||
body: `{"id":"file-abc123","status":"uploaded"}`,
|
||
wantID: "file-abc123",
|
||
},
|
||
{
|
||
name: "空 id",
|
||
body: `{"id":"","status":"uploaded"}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "无 id 字段",
|
||
body: `{"status":"uploaded"}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "id 全为空白",
|
||
body: `{"id":" ","status":"uploaded"}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "id 前后有空白",
|
||
body: `{"id":" file-trimmed ","status":"uploaded"}`,
|
||
wantID: "file-trimmed",
|
||
},
|
||
{
|
||
name: "空 JSON 对象",
|
||
body: `{}`,
|
||
wantErr: true,
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||
if tt.wantErr {
|
||
require.Error(t, err, "应返回错误")
|
||
return
|
||
}
|
||
require.NoError(t, err)
|
||
require.Equal(t, tt.wantID, id)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
|
||
|
||
func TestSoraParseCreateTaskResponse(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
wantID string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "正常任务 id",
|
||
body: `{"id":"task-123"}`,
|
||
wantID: "task-123",
|
||
},
|
||
{
|
||
name: "缺失 id",
|
||
body: `{"status":"created"}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "空 id",
|
||
body: `{"id":" "}`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "id 为数字(gjson 转字符串)",
|
||
body: `{"id":123}`,
|
||
wantID: "123",
|
||
},
|
||
{
|
||
name: "id 含特殊字符",
|
||
body: `{"id":"task-abc-def-456-ghi"}`,
|
||
wantID: "task-abc-def-456-ghi",
|
||
},
|
||
{
|
||
name: "额外字段不影响解析",
|
||
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
|
||
wantID: "task-999",
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||
if tt.wantErr {
|
||
require.Error(t, err, "应返回错误")
|
||
return
|
||
}
|
||
require.NoError(t, err)
|
||
require.Equal(t, tt.wantID, id)
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
|
||
|
||
func TestSoraParseFetchRecentImageTask(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
taskID string
|
||
wantFound bool
|
||
wantStatus string
|
||
wantProgress float64
|
||
wantURLs []string
|
||
}{
|
||
{
|
||
name: "匹配已完成任务",
|
||
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
|
||
taskID: "task-1",
|
||
wantFound: true,
|
||
wantStatus: "completed",
|
||
wantProgress: 1.0,
|
||
wantURLs: []string{"https://example.com/img.png"},
|
||
},
|
||
{
|
||
name: "匹配处理中任务",
|
||
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
|
||
taskID: "task-2",
|
||
wantFound: true,
|
||
wantStatus: "processing",
|
||
wantProgress: 0.5,
|
||
wantURLs: nil,
|
||
},
|
||
{
|
||
name: "无匹配任务",
|
||
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
wantStatus: "processing",
|
||
},
|
||
{
|
||
name: "空 task_responses",
|
||
body: `{"task_responses":[]}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
wantStatus: "processing",
|
||
},
|
||
{
|
||
name: "缺少 task_responses 字段",
|
||
body: `{"other":"data"}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
wantStatus: "processing",
|
||
},
|
||
{
|
||
name: "多个任务中精准匹配",
|
||
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
|
||
taskID: "task-b",
|
||
wantFound: true,
|
||
wantStatus: "processing",
|
||
wantProgress: 0.3,
|
||
wantURLs: nil,
|
||
},
|
||
{
|
||
name: "多个 generations",
|
||
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
|
||
taskID: "task-m",
|
||
wantFound: true,
|
||
wantStatus: "completed",
|
||
wantProgress: 1.0,
|
||
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
|
||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||
require.NotNil(t, status)
|
||
require.Equal(t, tt.taskID, status.ID)
|
||
require.Equal(t, tt.wantStatus, status.Status)
|
||
if tt.wantFound {
|
||
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
|
||
require.Equal(t, tt.wantURLs, status.URLs)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
|
||
|
||
func TestSoraParseGetVideoTaskPending(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
taskID string
|
||
wantFound bool
|
||
wantStatus string
|
||
wantProgress int
|
||
}{
|
||
{
|
||
name: "匹配 pending 任务",
|
||
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
|
||
taskID: "task-1",
|
||
wantFound: true,
|
||
wantStatus: "processing",
|
||
wantProgress: 50,
|
||
},
|
||
{
|
||
name: "进度为 0",
|
||
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
|
||
taskID: "task-2",
|
||
wantFound: true,
|
||
wantStatus: "queued",
|
||
wantProgress: 0,
|
||
},
|
||
{
|
||
name: "进度为 1(100%)",
|
||
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
|
||
taskID: "task-3",
|
||
wantFound: true,
|
||
wantStatus: "completing",
|
||
wantProgress: 100,
|
||
},
|
||
{
|
||
name: "空数组",
|
||
body: `[]`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "无匹配 id",
|
||
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "多个任务精准匹配",
|
||
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
|
||
taskID: "task-c",
|
||
wantFound: true,
|
||
wantStatus: "processing",
|
||
wantProgress: 80,
|
||
},
|
||
{
|
||
name: "非数组 JSON",
|
||
body: `{"id":"task-1","status":"processing"}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "无 progress_pct 字段",
|
||
body: `[{"id":"task-4","status":"pending"}]`,
|
||
taskID: "task-4",
|
||
wantFound: true,
|
||
wantStatus: "pending",
|
||
wantProgress: 0,
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
|
||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||
if tt.wantFound {
|
||
require.NotNil(t, status)
|
||
require.Equal(t, tt.taskID, status.ID)
|
||
require.Equal(t, tt.wantStatus, status.Status)
|
||
require.Equal(t, tt.wantProgress, status.ProgressPct)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
|
||
|
||
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
taskID string
|
||
wantFound bool
|
||
wantStatus string
|
||
wantURLs []string
|
||
wantErr string
|
||
}{
|
||
{
|
||
name: "正常完成的视频",
|
||
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||
taskID: "task-1",
|
||
wantFound: true,
|
||
wantStatus: "completed",
|
||
wantURLs: []string{"https://example.com/video.mp4"},
|
||
},
|
||
{
|
||
name: "使用 url 字段回退",
|
||
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
|
||
taskID: "task-2",
|
||
wantFound: true,
|
||
wantStatus: "completed",
|
||
wantURLs: []string{"https://example.com/fallback.mp4"},
|
||
},
|
||
{
|
||
name: "内容违规",
|
||
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
|
||
taskID: "task-3",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Content policy violation",
|
||
},
|
||
{
|
||
name: "内容违规 - markdown_reason_str 回退",
|
||
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
|
||
taskID: "task-4",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Markdown reason",
|
||
},
|
||
{
|
||
name: "内容违规 - 无 reason 使用默认消息",
|
||
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
|
||
taskID: "task-5",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Content violates guardrails",
|
||
},
|
||
{
|
||
name: "有 reason_str 但非 violation kind(仍判定失败)",
|
||
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
|
||
taskID: "task-6",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Some error occurred",
|
||
},
|
||
{
|
||
name: "空 URL 判定为失败",
|
||
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
|
||
taskID: "task-7",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Content violates guardrails",
|
||
},
|
||
{
|
||
name: "无匹配 task_id",
|
||
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "空 items",
|
||
body: `{"items":[]}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "缺少 items 字段",
|
||
body: `{"other":"data"}`,
|
||
taskID: "task-1",
|
||
wantFound: false,
|
||
},
|
||
{
|
||
name: "多个 items 精准匹配",
|
||
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
|
||
taskID: "task-b",
|
||
wantFound: true,
|
||
wantStatus: "failed",
|
||
wantErr: "Bad content",
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
|
||
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||
if !tt.wantFound {
|
||
return
|
||
}
|
||
require.NotNil(t, status)
|
||
require.Equal(t, tt.taskID, status.ID)
|
||
require.Equal(t, tt.wantStatus, status.Status)
|
||
if tt.wantErr != "" {
|
||
require.Equal(t, tt.wantErr, status.ErrorMsg)
|
||
}
|
||
if tt.wantURLs != nil {
|
||
require.Equal(t, tt.wantURLs, status.URLs)
|
||
}
|
||
})
|
||
}
|
||
}
|