perf(backend): 使用 gjson/sjson 优化热路径 JSON 处理

将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入:
- unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x
- unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取
- ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取
- Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量
- Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配
- 新增约 100 个单元测试用例,所有改动函数覆盖率 >85%

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-10 08:59:30 +08:00
parent 29ca1290b3
commit 58912d4ac5
14 changed files with 1686 additions and 324 deletions

View File

@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
const (
@@ -981,16 +982,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin
}
// unwrapV1InternalResponse 解包 v1internal 响应
// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
result := gjson.GetBytes(body, "response")
if result.Exists() {
return []byte(result.Raw), nil
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
@@ -2516,11 +2513,11 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
// 解析 usage
if u := extractGeminiUsage(inner); u != nil {
usage = u
}
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
// Check for MALFORMED_FUNCTION_CALL
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
if cand, ok := candidates[0].(map[string]any); ok {
@@ -2676,7 +2673,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
last = parsed
// 提取 usage
if u := extractGeminiUsage(parsed); u != nil {
if u := extractGeminiUsage(inner); u != nil {
usage = u
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -889,3 +890,144 @@ func TestAntigravityClientWriter(t *testing.T) {
require.True(t, cw.Disconnected())
})
}
// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景
func TestUnwrapV1InternalResponse(t *testing.T) {
svc := &AntigravityGatewayService{}
// 构造 >50KB 的大型 JSON
largePadding := strings.Repeat("x", 50*1024)
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
tests := []struct {
name string
input []byte
expected string
wantErr bool
}{
{
name: "正常 response 包装",
input: []byte(`{"response":{"id":"123","content":"hello"}}`),
expected: `{"id":"123","content":"hello"}`,
},
{
name: "无 response 透传",
input: []byte(`{"id":"456"}`),
expected: `{"id":"456"}`,
},
{
name: "空 JSON",
input: []byte(`{}`),
expected: `{}`,
},
{
name: "response 为 null",
input: []byte(`{"response":null}`),
expected: `null`,
},
{
name: "response 为基础类型 string",
input: []byte(`{"response":"hello"}`),
expected: `"hello"`,
},
{
name: "非法 JSON",
input: []byte(`not json`),
expected: `not json`,
},
{
name: "嵌套 response 只解一层",
input: []byte(`{"response":{"response":{"inner":true}}}`),
expected: `{"response":{"inner":true}}`,
},
{
name: "大型 JSON >50KB",
input: largeInput,
expected: largeExpected,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := svc.unwrapV1InternalResponse(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
})
}
}
// --- unwrapV1InternalResponse benchmark 对照组 ---
// unwrapV1InternalResponseOld 旧实现Unmarshal+Marshal 双重开销(仅用于 benchmark 对照)
func unwrapV1InternalResponseOld(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) {
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = unwrapV1InternalResponseOld(body)
}
}
func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) {
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
svc := &AntigravityGatewayService{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = svc.unwrapV1InternalResponse(body)
}
}
// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON
func generateLargeUnwrapJSON(minSize int) []byte {
parts := make([]map[string]string, 0)
current := 0
for current < minSize {
text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1)
parts = append(parts, map[string]string{"text": text})
current += len(text) + 20 // 估算 JSON 编码开销
}
inner := map[string]any{
"candidates": []map[string]any{
{"content": map[string]any{"parts": parts}},
},
"usageMetadata": map[string]any{
"promptTokenCount": 100,
"candidatesTokenCount": 50,
},
}
outer := map[string]any{"response": inner}
b, _ := json.Marshal(outer)
return b
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/tidwall/gjson"
)
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
@@ -48,38 +49,58 @@ type ParsedRequest struct {
// protocol 指定请求协议格式domain.PlatformAnthropic / domain.PlatformGemini
// 不同协议使用不同的 system/messages 字段名。
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
// --- gjson 提取简单字段(避免完整 Unmarshal ---
// model: 需要严格类型校验,非 string 返回错误
modelResult := gjson.GetBytes(body, "model")
if modelResult.Exists() {
if modelResult.Type != gjson.String {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
parsed.Model = modelResult.String()
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
// stream: 需要严格类型校验,非 bool 返回错误
streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() {
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
parsed.Stream = streamResult.Bool()
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
// metadata.user_id: 直接路径提取,不需要严格类型校验
parsed.MetadataUserID = gjson.GetBytes(body, "metadata.user_id").String()
// thinking.type: 直接路径提取
if gjson.GetBytes(body, "thinking.type").String() == "enabled" {
parsed.ThinkingEnabled = true
}
// max_tokens: 仅接受整数值
maxTokensResult := gjson.GetBytes(body, "max_tokens")
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
f := maxTokensResult.Float()
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
f <= float64(math.MaxInt) && f >= float64(math.MinInt) {
parsed.MaxTokens = int(f)
}
}
// --- 保留 Unmarshal 用于 system/messages 提取 ---
// 这些字段需要作为 any/[]any 传递给下游消费者,无法用 gjson 替代
switch protocol {
case domain.PlatformGemini:
// Gemini 原生格式: systemInstruction.parts / contents
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
@@ -92,6 +113,10 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
// Anthropic / OpenAI 格式: system / messages
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
@@ -101,20 +126,6 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
}
}
// thinking: {type: "enabled"}
if rawThinking, ok := req["thinking"].(map[string]any); ok {
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
parsed.ThinkingEnabled = true
}
}
// max_tokens
if rawMaxTokens, exists := req["max_tokens"]; exists {
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
parsed.MaxTokens = maxTokens
}
}
return parsed, nil
}

View File

@@ -1,7 +1,11 @@
//go:build unit
package service
import (
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -416,3 +420,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
require.Contains(t, content0["text"], "tool_use")
require.Contains(t, content1["text"], "tool_result")
}
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
// Task 7.1 — 类型校验边界测试
func TestParseGatewayRequest_TypeValidation(t *testing.T) {
tests := []struct {
name string
body string
wantErr bool
errSubstr string // 期望的错误信息子串(为空则不检查)
}{
{
name: "model 为 int",
body: `{"model":123}`,
wantErr: true,
errSubstr: "invalid model field type",
},
{
name: "model 为 array",
body: `{"model":[]}`,
wantErr: true,
errSubstr: "invalid model field type",
},
{
name: "model 为 bool",
body: `{"model":true}`,
wantErr: true,
errSubstr: "invalid model field type",
},
{
name: "model 为 null — gjson Null 类型触发类型校验错误",
body: `{"model":null}`,
wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误
errSubstr: "invalid model field type",
},
{
name: "stream 为 string",
body: `{"stream":"true"}`,
wantErr: true,
errSubstr: "invalid stream field type",
},
{
name: "stream 为 int",
body: `{"stream":1}`,
wantErr: true,
errSubstr: "invalid stream field type",
},
{
name: "stream 为 null — gjson Null 类型触发类型校验错误",
body: `{"stream":null}`,
wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误
errSubstr: "invalid stream field type",
},
{
name: "model 为 object",
body: `{"model":{}}`,
wantErr: true,
errSubstr: "invalid model field type",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseGatewayRequest([]byte(tt.body), "")
if tt.wantErr {
require.Error(t, err)
if tt.errSubstr != "" {
require.Contains(t, err.Error(), tt.errSubstr)
}
} else {
require.NoError(t, err)
}
})
}
}
// Task 7.2 — 可选字段缺失测试
func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) {
tests := []struct {
name string
body string
wantModel string
wantStream bool
wantMetadataUID string
wantHasSystem bool
wantThinking bool
wantMaxTokens int
wantMessagesNil bool
wantMessagesLen int
}{
{
name: "完全空 JSON — 所有字段零值",
body: `{}`,
wantModel: "",
wantStream: false,
wantMetadataUID: "",
wantHasSystem: false,
wantThinking: false,
wantMaxTokens: 0,
wantMessagesNil: true,
},
{
name: "metadata 无 user_id",
body: `{"model":"test"}`,
wantModel: "test",
wantMetadataUID: "",
wantHasSystem: false,
wantThinking: false,
},
{
name: "thinking 非 enabledtype=disabled",
body: `{"model":"test","thinking":{"type":"disabled"}}`,
wantModel: "test",
wantThinking: false,
},
{
name: "thinking 字段缺失",
body: `{"model":"test"}`,
wantModel: "test",
wantThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
require.NoError(t, err)
require.Equal(t, tt.wantModel, parsed.Model)
require.Equal(t, tt.wantStream, parsed.Stream)
require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID)
require.Equal(t, tt.wantHasSystem, parsed.HasSystem)
require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled)
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
if tt.wantMessagesNil {
require.Nil(t, parsed.Messages)
}
if tt.wantMessagesLen > 0 {
require.Len(t, parsed.Messages, tt.wantMessagesLen)
}
})
}
}
// Task 7.3 — Gemini 协议分支测试
// 已有测试覆盖:
// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents
// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents
// - TestParseGatewayRequest_GeminiContents: 正常 contents无 systemInstruction
// 因此跳过。
// Task 7.4 — max_tokens 边界测试
func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) {
tests := []struct {
name string
body string
wantMaxTokens int
wantErr bool
}{
{
name: "正常整数",
body: `{"max_tokens":1024}`,
wantMaxTokens: 1024,
},
{
name: "浮点数(非整数)被忽略",
body: `{"max_tokens":10.5}`,
wantMaxTokens: 0,
},
{
name: "负整数可以通过",
body: `{"max_tokens":-1}`,
wantMaxTokens: -1,
},
{
name: "超大值不 panic",
body: `{"max_tokens":9999999999999999}`,
wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16
},
{
name: "null 值被忽略",
body: `{"max_tokens":null}`,
wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
})
}
}
// ============ Task 7.5: Benchmark 测试 ============
// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。
// 核心路径:先 Unmarshal 到 map[string]any再逐字段提取。
func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) {
parsed := &ParsedRequest{
Body: body,
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
// model
if raw, ok := req["model"]; ok {
s, ok := raw.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = s
}
// stream
if raw, ok := req["stream"]; ok {
b, ok := raw.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = b
}
// metadata.user_id
if meta, ok := req["metadata"].(map[string]any); ok {
if uid, ok := meta["user_id"].(string); ok {
parsed.MetadataUserID = uid
}
}
// thinking.type
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
parsed.ThinkingEnabled = true
}
}
// max_tokens
if raw, ok := req["max_tokens"]; ok {
if n, ok := parseIntegralNumber(raw); ok {
parsed.MaxTokens = n
}
}
// system / messages按协议分支
switch protocol {
case domain.PlatformGemini:
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
if parts, ok := sysInst["parts"].([]any); ok {
parsed.System = parts
}
}
if contents, ok := req["contents"].([]any); ok {
parsed.Messages = contents
}
default:
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
}
return parsed, nil
}
// buildSmallJSON 构建 ~500B 的小型测试 JSON
func buildSmallJSON() []byte {
return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`)
}
// buildLargeJSON 构建 ~50KB 的大型测试 JSON大量 messages
func buildLargeJSON() []byte {
var b strings.Builder
b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`)
msgCount := 200
for i := 0; i < msgCount; i++ {
if i > 0 {
b.WriteByte(',')
}
if i%2 == 0 {
b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i))
} else {
b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i))
}
}
b.WriteString(`]}`)
return []byte(b.String())
}
func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) {
data := buildSmallJSON()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = parseGatewayRequestOld(data, "")
}
}
func BenchmarkParseGatewayRequest_New_Small(b *testing.B) {
data := buildSmallJSON()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseGatewayRequest(data, "")
}
}
func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
data := buildLargeJSON()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = parseGatewayRequestOld(data, "")
}
}
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
data := buildLargeJSON()
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseGatewayRequest(data, "")
}
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
const geminiStickySessionTTL = time.Hour
@@ -929,7 +930,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
}
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
collectedBytes, _ := json.Marshal(collected)
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes)
c.JSON(http.StatusOK, claudeResp)
usage = usageObj2
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
@@ -1726,12 +1728,17 @@ func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context,
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
geminiResp, err := unwrapGeminiResponse(body)
unwrappedBody, err := unwrapGeminiResponse(body)
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel)
var geminiResp map[string]any
if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody)
c.JSON(http.StatusOK, claudeResp)
return usage, nil
@@ -1804,11 +1811,16 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
continue
}
geminiResp, err := unwrapGeminiResponse([]byte(payload))
unwrappedBytes, err := unwrapGeminiResponse([]byte(payload))
if err != nil {
continue
}
var geminiResp map[string]any
if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil {
continue
}
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
finishReason = fr
}
@@ -1935,7 +1947,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
}
}
if u := extractGeminiUsage(geminiResp); u != nil {
if u := extractGeminiUsage(unwrappedBytes); u != nil {
usage = *u
}
@@ -2026,11 +2038,7 @@ func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
if err != nil {
return raw
}
b, err := json.Marshal(inner)
if err != nil {
return raw
}
return b
return inner
}
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
@@ -2054,17 +2062,20 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
default:
var parsed map[string]any
var rawBytes []byte
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
innerBytes, err := unwrapGeminiResponse([]byte(payload))
if err == nil {
rawBytes = innerBytes
_ = json.Unmarshal(innerBytes, &parsed)
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
rawBytes = []byte(payload)
_ = json.Unmarshal(rawBytes, &parsed)
}
if parsed != nil {
last = parsed
if u := extractGeminiUsage(parsed); u != nil {
if u := extractGeminiUsage(rawBytes); u != nil {
usage = u
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
@@ -2193,53 +2204,27 @@ func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
}
func estimateGeminiCountTokens(reqBody []byte) int {
var obj map[string]any
if err := json.Unmarshal(reqBody, &obj); err != nil {
return 0
}
var texts []string
total := 0
// systemInstruction.parts[].text
if si, ok := obj["systemInstruction"].(map[string]any); ok {
if parts, ok := si["parts"].([]any); ok {
for _, p := range parts {
if pm, ok := p.(map[string]any); ok {
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool {
if t := strings.TrimSpace(part.Get("text").String()); t != "" {
total += estimateTokensForText(t)
}
}
return true
})
// contents[].parts[].text
if contents, ok := obj["contents"].([]any); ok {
for _, c := range contents {
cm, ok := c.(map[string]any)
if !ok {
continue
gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool {
content.Get("parts").ForEach(func(_, part gjson.Result) bool {
if t := strings.TrimSpace(part.Get("text").String()); t != "" {
total += estimateTokensForText(t)
}
parts, ok := cm["parts"].([]any)
if !ok {
continue
}
for _, p := range parts {
pm, ok := p.(map[string]any)
if !ok {
continue
}
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
texts = append(texts, t)
}
}
}
}
return true
})
return true
})
total := 0
for _, t := range texts {
total += estimateTokensForText(t)
}
if total < 0 {
return 0
}
@@ -2293,10 +2278,11 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
var parsed map[string]any
if isOAuth {
parsed, err = unwrapGeminiResponse(respBody)
if err == nil && parsed != nil {
respBody, _ = json.Marshal(parsed)
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
if uwErr == nil {
respBody = unwrappedBody
}
_ = json.Unmarshal(respBody, &parsed)
} else {
_ = json.Unmarshal(respBody, &parsed)
}
@@ -2309,10 +2295,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
c.Data(resp.StatusCode, contentType, respBody)
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
return u, nil
}
if u := extractGeminiUsage(respBody); u != nil {
return u, nil
}
return &ClaudeUsage{}, nil
}
@@ -2365,23 +2349,19 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
var rawToWrite string
rawToWrite = payload
var parsed map[string]any
var rawBytes []byte
if isOAuth {
inner, err := unwrapGeminiResponse([]byte(payload))
if err == nil && inner != nil {
parsed = inner
if b, err := json.Marshal(inner); err == nil {
rawToWrite = string(b)
}
innerBytes, err := unwrapGeminiResponse([]byte(payload))
if err == nil {
rawToWrite = string(innerBytes)
rawBytes = innerBytes
}
} else {
_ = json.Unmarshal([]byte(payload), &parsed)
rawBytes = []byte(payload)
}
if parsed != nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
if u := extractGeminiUsage(rawBytes); u != nil {
usage = u
}
if firstTokenMs == nil {
@@ -2484,19 +2464,18 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}, nil
}
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
var outer map[string]any
if err := json.Unmarshal(raw, &outer); err != nil {
return nil, err
// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段
// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal
func unwrapGeminiResponse(raw []byte) ([]byte, error) {
result := gjson.GetBytes(raw, "response")
if result.Exists() && result.Type == gjson.JSON {
return []byte(result.Raw), nil
}
if resp, ok := outer["response"].(map[string]any); ok && resp != nil {
return resp, nil
}
return outer, nil
return raw, nil
}
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
usage := extractGeminiUsage(geminiResp)
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) {
usage := extractGeminiUsage(rawData)
if usage == nil {
usage = &ClaudeUsage{}
}
@@ -2560,14 +2539,14 @@ func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel strin
return resp, usage
}
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
usageMeta, ok := geminiResp["usageMetadata"].(map[string]any)
if !ok || usageMeta == nil {
func extractGeminiUsage(data []byte) *ClaudeUsage {
usage := gjson.GetBytes(data, "usageMetadata")
if !usage.Exists() {
return nil
}
prompt, _ := asInt(usageMeta["promptTokenCount"])
cand, _ := asInt(usageMeta["candidatesTokenCount"])
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
prompt := int(usage.Get("promptTokenCount").Int())
cand := int(usage.Get("candidatesTokenCount").Int())
cached := int(usage.Get("cachedContentTokenCount").Int())
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
return &ClaudeUsage{
@@ -2646,39 +2625,35 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
// Try to parse metadata.quotaResetDelay like "12.345s"
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err == nil {
if errObj, ok := parsed["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok {
if looksLikeGeminiDailyQuota(msg) {
if ts := nextGeminiDailyResetUnix(); ts != nil {
return ts
}
}
}
if details, ok := errObj["details"].([]any); ok {
for _, d := range details {
dm, ok := d.(map[string]any)
if !ok {
continue
}
if meta, ok := dm["metadata"].(map[string]any); ok {
if v, ok := meta["quotaResetDelay"].(string); ok {
if dur, err := time.ParseDuration(v); err == nil {
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
// which can affect scheduling decisions around thresholds (like 10s).
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
return &ts
}
}
}
}
}
// 第一阶段gjson 结构化提取
errMsg := gjson.GetBytes(body, "error.message").String()
if looksLikeGeminiDailyQuota(errMsg) {
if ts := nextGeminiDailyResetUnix(); ts != nil {
return ts
}
}
// Match "Please retry in Xs"
// 遍历 error.details 查找 quotaResetDelay
var found *int64
gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool {
v := detail.Get("metadata.quotaResetDelay").String()
if v == "" {
return true
}
if dur, err := time.ParseDuration(v); err == nil {
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
// which can affect scheduling decisions around thresholds (like 10s).
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
found = &ts
return false
}
return true
})
if found != nil {
return found
}
// 第二阶段regex 回退匹配 "Please retry in Xs"
matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {

View File

@@ -2,8 +2,12 @@ package service
import (
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
@@ -203,3 +207,304 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
}
}
// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景
// 关键区别:只有 response 为 JSON 对象/数组时才解包
func TestUnwrapGeminiResponse(t *testing.T) {
// 构造 >50KB 的大型 JSON 对象
largePadding := strings.Repeat("x", 50*1024)
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
tests := []struct {
name string
input []byte
expected string
wantErr bool
}{
{
name: "正常 response 包装JSON 对象)",
input: []byte(`{"response":{"key":"val"}}`),
expected: `{"key":"val"}`,
},
{
name: "无包装直接返回",
input: []byte(`{"key":"val"}`),
expected: `{"key":"val"}`,
},
{
name: "空 JSON",
input: []byte(`{}`),
expected: `{}`,
},
{
name: "null response 返回原始 body",
input: []byte(`{"response":null}`),
expected: `{"response":null}`,
},
{
name: "非法 JSON 返回原始 body",
input: []byte(`not json`),
expected: `not json`,
},
{
name: "response 为基础类型 string 返回原始 body",
input: []byte(`{"response":"hello"}`),
expected: `{"response":"hello"}`,
},
{
name: "嵌套 response 只解一层",
input: []byte(`{"response":{"response":{"inner":true}}}`),
expected: `{"response":{"inner":true}}`,
},
{
name: "大型 JSON >50KB",
input: largeInput,
expected: largeExpected,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := unwrapGeminiResponse(tt.input)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
})
}
}
// ---------------------------------------------------------------------------
// Task 8.1 — extractGeminiUsage 测试
// ---------------------------------------------------------------------------
func TestExtractGeminiUsage(t *testing.T) {
tests := []struct {
name string
input string
wantNil bool
wantUsage *ClaudeUsage
}{
{
name: "完整 usageMetadata",
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`,
wantNil: false,
wantUsage: &ClaudeUsage{
InputTokens: 80,
OutputTokens: 50,
CacheReadInputTokens: 20,
},
},
{
name: "缺失 cachedContentTokenCount",
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`,
wantNil: false,
wantUsage: &ClaudeUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadInputTokens: 0,
},
},
{
name: "无 usageMetadata",
input: `{"candidates":[]}`,
wantNil: true,
},
{
// gjson 对 null 返回 Exists()=true因此函数不会返回 nil
// 而是返回全零的 ClaudeUsage。
name: "null usageMetadata — gjson Exists 为 true",
input: `{"usageMetadata":null}`,
wantNil: false,
wantUsage: &ClaudeUsage{
InputTokens: 0,
OutputTokens: 0,
CacheReadInputTokens: 0,
},
},
{
name: "零值字段",
input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`,
wantNil: false,
wantUsage: &ClaudeUsage{
InputTokens: 0,
OutputTokens: 0,
CacheReadInputTokens: 0,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractGeminiUsage([]byte(tt.input))
if tt.wantNil {
if got != nil {
t.Fatalf("期望返回 nil实际返回 %+v", got)
}
return
}
if got == nil {
t.Fatalf("期望返回非 nil实际返回 nil")
}
if got.InputTokens != tt.wantUsage.InputTokens {
t.Errorf("InputTokens: 期望 %d实际 %d", tt.wantUsage.InputTokens, got.InputTokens)
}
if got.OutputTokens != tt.wantUsage.OutputTokens {
t.Errorf("OutputTokens: 期望 %d实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens)
}
if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens {
t.Errorf("CacheReadInputTokens: 期望 %d实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens)
}
})
}
}
// ---------------------------------------------------------------------------
// Task 8.2 — estimateGeminiCountTokens 测试
// ---------------------------------------------------------------------------
func TestEstimateGeminiCountTokens(t *testing.T) {
tests := []struct {
name string
input string
wantGt0 bool // 期望结果 > 0
wantExact *int // 如果非 nil期望精确匹配
}{
{
name: "含 systemInstruction 和 contents",
input: `{
"systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]},
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
}`,
wantGt0: true,
},
{
name: "仅 contents无 systemInstruction",
input: `{
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
}`,
wantGt0: true,
},
{
name: "空 parts",
input: `{"contents":[{"parts":[]}]}`,
wantGt0: false,
wantExact: intPtr(0),
},
{
name: "非文本 partsinlineData",
input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`,
wantGt0: false,
wantExact: intPtr(0),
},
{
name: "空白文本",
input: `{"contents":[{"parts":[{"text":" "}]}]}`,
wantGt0: false,
wantExact: intPtr(0),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateGeminiCountTokens([]byte(tt.input))
if tt.wantExact != nil {
if got != *tt.wantExact {
t.Errorf("期望精确值 %d实际 %d", *tt.wantExact, got)
}
return
}
if tt.wantGt0 && got <= 0 {
t.Errorf("期望返回 > 0实际 %d", got)
}
if !tt.wantGt0 && got != 0 {
t.Errorf("期望返回 0实际 %d", got)
}
})
}
}
// ---------------------------------------------------------------------------
// Task 8.3 — ParseGeminiRateLimitResetTime 测试
// ---------------------------------------------------------------------------
func TestParseGeminiRateLimitResetTime(t *testing.T) {
tests := []struct {
name string
input string
wantNil bool
approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒
}{
{
name: "正常 quotaResetDelay",
input: `{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`,
wantNil: false,
approxDelta: 13, // 向上取整 12.345 -> 13
},
{
name: "daily quota",
input: `{"error":{"message":"quota per day exceeded"}}`,
wantNil: false,
approxDelta: -1, // 不检查精确 delta仅检查非 nil
},
{
name: "无 details 且无 regex 匹配",
input: `{"error":{"message":"rate limit"}}`,
wantNil: true,
},
{
name: "regex 回退匹配",
input: `Please retry in 30s`,
wantNil: false,
approxDelta: 30,
},
{
name: "完全无匹配",
input: `{"error":{"code":429}}`,
wantNil: true,
},
{
name: "非法 JSON 但 regex 回退仍工作",
input: `not json but Please retry in 10s`,
wantNil: false,
approxDelta: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
now := time.Now().Unix()
got := ParseGeminiRateLimitResetTime([]byte(tt.input))
if tt.wantNil {
if got != nil {
t.Fatalf("期望返回 nil实际返回 %d", *got)
}
return
}
if got == nil {
t.Fatalf("期望返回非 nil实际返回 nil")
}
// approxDelta == -1 表示只检查非 nil不检查具体值如 daily quota 场景)
if tt.approxDelta == -1 {
// 仅验证返回的时间戳在合理范围内(未来的某个时间)
if *got < now {
t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got)
}
return
}
// 使用 +/-2 秒容差进行范围检查
delta := *got - now
if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 {
t.Errorf("期望 delta 约为 %d 秒(+/-2实际 delta = %d 秒(返回值=%d, now=%d",
tt.approxDelta, delta, *got, now)
}
})
}
}

View File

@@ -230,7 +230,7 @@ func NewOpenAIGatewayService(
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
@@ -239,10 +239,8 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" {
return ""

View File

@@ -129,17 +129,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
svc := &OpenAIGatewayService{}
bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`)
// 1) session_id header wins
c.Request.Header.Set("session_id", "sess-123")
c.Request.Header.Set("conversation_id", "conv-456")
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h1 := svc.GenerateSessionHash(c, bodyWithKey)
if h1 == "" {
t.Fatalf("expected non-empty hash")
}
// 2) conversation_id used when session_id absent
c.Request.Header.Del("session_id")
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h2 := svc.GenerateSessionHash(c, bodyWithKey)
if h2 == "" {
t.Fatalf("expected non-empty hash")
}
@@ -149,7 +151,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
// 3) prompt_cache_key used when both headers absent
c.Request.Header.Del("conversation_id")
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
h3 := svc.GenerateSessionHash(c, bodyWithKey)
if h3 == "" {
t.Fatalf("expected non-empty hash")
}
@@ -158,7 +160,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
// 4) empty when no signals
h4 := svc.GenerateSessionHash(c, map[string]any{})
h4 := svc.GenerateSessionHash(c, []byte(`{}`))
if h4 != "" {
t.Fatalf("expected empty hash when no signals")
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"golang.org/x/crypto/sha3"
)
@@ -219,12 +220,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
if err != nil {
return "", err
}
var payload map[string]any
if err := json.Unmarshal(respBody, &payload); err != nil {
return "", fmt.Errorf("parse upload response: %w", err)
}
id, _ := payload["id"].(string)
if strings.TrimSpace(id) == "" {
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if id == "" {
return "", errors.New("upload response missing id")
}
return id, nil
@@ -274,12 +271,8 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
taskID, _ := resp["id"].(string)
if strings.TrimSpace(taskID) == "" {
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("image task response missing id")
}
return taskID, nil
@@ -347,12 +340,8 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return "", err
}
taskID, _ := resp["id"].(string)
if strings.TrimSpace(taskID) == "" {
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if taskID == "" {
return "", errors.New("video task response missing id")
}
return taskID, nil
@@ -393,41 +382,30 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac
if err != nil {
return nil, false, err
}
var resp map[string]any
if err := json.Unmarshal(respBody, &resp); err != nil {
return nil, false, err
}
taskResponses, _ := resp["task_responses"].([]any)
for _, item := range taskResponses {
taskResp, ok := item.(map[string]any)
if !ok {
continue
var found *SoraImageTaskStatus
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
if item.Get("id").String() != taskID {
return true // continue
}
if id, _ := taskResp["id"].(string); id == taskID {
status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"]))
progress := 0.0
if v, ok := taskResp["progress_pct"].(float64); ok {
progress = v
status := strings.TrimSpace(item.Get("status").String())
progress := item.Get("progress_pct").Float()
var urls []string
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
urls = append(urls, u)
}
urls := []string{}
if generations, ok := taskResp["generations"].([]any); ok {
for _, genItem := range generations {
gen, ok := genItem.(map[string]any)
if !ok {
continue
}
if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" {
urls = append(urls, urlStr)
}
}
}
return &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}, true, nil
return true
})
found = &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}
return false // break
})
if found != nil {
return found, true, nil
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
}
@@ -463,27 +441,28 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if err != nil {
return nil, err
}
var pending any
if err := json.Unmarshal(respBody, &pending); err == nil {
if list, ok := pending.([]any); ok {
for _, item := range list {
task, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := task["id"].(string); id == taskID {
progress := 0
if v, ok := task["progress_pct"].(float64); ok {
progress = int(v * 100)
}
status := strings.TrimSpace(fmt.Sprintf("%v", task["status"]))
return &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}, nil
}
// 搜索 pending 列表JSON 数组)
pendingResult := gjson.ParseBytes(respBody)
if pendingResult.IsArray() {
var pendingFound *SoraVideoTaskStatus
pendingResult.ForEach(func(_, task gjson.Result) bool {
if task.Get("id").String() != taskID {
return true
}
progress := 0
if v := task.Get("progress_pct"); v.Exists() {
progress = int(v.Float() * 100)
}
status := strings.TrimSpace(task.Get("status").String())
pendingFound = &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}
return false
})
if pendingFound != nil {
return pendingFound, nil
}
}
@@ -491,44 +470,42 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if err != nil {
return nil, err
}
var draftsResp map[string]any
if err := json.Unmarshal(respBody, &draftsResp); err != nil {
return nil, err
}
items, _ := draftsResp["items"].([]any)
for _, item := range items {
draft, ok := item.(map[string]any)
if !ok {
continue
var draftFound *SoraVideoTaskStatus
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
if draft.Get("task_id").String() != taskID {
return true
}
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
}
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
if urlStr == "" {
urlStr = strings.TrimSpace(draft.Get("url").String())
}
if id, _ := draft["task_id"].(string); id == taskID {
kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"]))
reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"]))
if reason == "" {
reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"]))
}
urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"]))
if urlStr == "" {
urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"]))
}
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
return &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
}, nil
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
return &SoraVideoTaskStatus{
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
}, nil
}
}
return false
})
if draftFound != nil {
return draftFound, nil
}
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil

View File

@@ -0,0 +1,515 @@
//go:build unit
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
if id == "" {
return "", assert.AnError // 占位错误,表示 "missing id"
}
return id, nil
}
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
var found *SoraImageTaskStatus
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
if item.Get("id").String() != taskID {
return true // continue
}
status := strings.TrimSpace(item.Get("status").String())
progress := item.Get("progress_pct").Float()
var urls []string
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
urls = append(urls, u)
}
return true
})
found = &SoraImageTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
URLs: urls,
}
return false // break
})
if found != nil {
return found, true
}
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
}
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
pendingResult := gjson.ParseBytes(respBody)
if !pendingResult.IsArray() {
return nil, false
}
var pendingFound *SoraVideoTaskStatus
pendingResult.ForEach(func(_, task gjson.Result) bool {
if task.Get("id").String() != taskID {
return true
}
progress := 0
if v := task.Get("progress_pct"); v.Exists() {
progress = int(v.Float() * 100)
}
status := strings.TrimSpace(task.Get("status").String())
pendingFound = &SoraVideoTaskStatus{
ID: taskID,
Status: status,
ProgressPct: progress,
}
return false
})
if pendingFound != nil {
return pendingFound, true
}
return nil, false
}
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
var draftFound *SoraVideoTaskStatus
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
if draft.Get("task_id").String() != taskID {
return true
}
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
}
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
if urlStr == "" {
urlStr = strings.TrimSpace(draft.Get("url").String())
}
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
msg := reason
if msg == "" {
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "failed",
ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
URLs: []string{urlStr},
}
}
return false
})
if draftFound != nil {
return draftFound, true
}
return nil, false
}
// ===================== Test 1: TestSoraParseUploadResponse =====================
func TestSoraParseUploadResponse(t *testing.T) {
tests := []struct {
name string
body string
wantID string
wantErr bool
}{
{
name: "正常 id",
body: `{"id":"file-abc123","status":"uploaded"}`,
wantID: "file-abc123",
},
{
name: "空 id",
body: `{"id":"","status":"uploaded"}`,
wantErr: true,
},
{
name: "无 id 字段",
body: `{"status":"uploaded"}`,
wantErr: true,
},
{
name: "id 全为空白",
body: `{"id":" ","status":"uploaded"}`,
wantErr: true,
},
{
name: "id 前后有空白",
body: `{"id":" file-trimmed ","status":"uploaded"}`,
wantID: "file-trimmed",
},
{
name: "空 JSON 对象",
body: `{}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
if tt.wantErr {
require.Error(t, err, "应返回错误")
return
}
require.NoError(t, err)
require.Equal(t, tt.wantID, id)
})
}
}
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
func TestSoraParseCreateTaskResponse(t *testing.T) {
tests := []struct {
name string
body string
wantID string
wantErr bool
}{
{
name: "正常任务 id",
body: `{"id":"task-123"}`,
wantID: "task-123",
},
{
name: "缺失 id",
body: `{"status":"created"}`,
wantErr: true,
},
{
name: "空 id",
body: `{"id":" "}`,
wantErr: true,
},
{
name: "id 为数字gjson 转字符串)",
body: `{"id":123}`,
wantID: "123",
},
{
name: "id 含特殊字符",
body: `{"id":"task-abc-def-456-ghi"}`,
wantID: "task-abc-def-456-ghi",
},
{
name: "额外字段不影响解析",
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
wantID: "task-999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
if tt.wantErr {
require.Error(t, err, "应返回错误")
return
}
require.NoError(t, err)
require.Equal(t, tt.wantID, id)
})
}
}
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
func TestSoraParseFetchRecentImageTask(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantProgress float64
wantURLs []string
}{
{
name: "匹配已完成任务",
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
taskID: "task-1",
wantFound: true,
wantStatus: "completed",
wantProgress: 1.0,
wantURLs: []string{"https://example.com/img.png"},
},
{
name: "匹配处理中任务",
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
taskID: "task-2",
wantFound: true,
wantStatus: "processing",
wantProgress: 0.5,
wantURLs: nil,
},
{
name: "无匹配任务",
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "空 task_responses",
body: `{"task_responses":[]}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "缺少 task_responses 字段",
body: `{"other":"data"}`,
taskID: "task-1",
wantFound: false,
wantStatus: "processing",
},
{
name: "多个任务中精准匹配",
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
taskID: "task-b",
wantFound: true,
wantStatus: "processing",
wantProgress: 0.3,
wantURLs: nil,
},
{
name: "多个 generations",
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
taskID: "task-m",
wantFound: true,
wantStatus: "completed",
wantProgress: 1.0,
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
if tt.wantFound {
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
require.Equal(t, tt.wantURLs, status.URLs)
}
})
}
}
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
func TestSoraParseGetVideoTaskPending(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantProgress int
}{
{
name: "匹配 pending 任务",
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
taskID: "task-1",
wantFound: true,
wantStatus: "processing",
wantProgress: 50,
},
{
name: "进度为 0",
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
taskID: "task-2",
wantFound: true,
wantStatus: "queued",
wantProgress: 0,
},
{
name: "进度为 1100%",
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
taskID: "task-3",
wantFound: true,
wantStatus: "completing",
wantProgress: 100,
},
{
name: "空数组",
body: `[]`,
taskID: "task-1",
wantFound: false,
},
{
name: "无匹配 id",
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
taskID: "task-1",
wantFound: false,
},
{
name: "多个任务精准匹配",
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
taskID: "task-c",
wantFound: true,
wantStatus: "processing",
wantProgress: 80,
},
{
name: "非数组 JSON",
body: `{"id":"task-1","status":"processing"}`,
taskID: "task-1",
wantFound: false,
},
{
name: "无 progress_pct 字段",
body: `[{"id":"task-4","status":"pending"}]`,
taskID: "task-4",
wantFound: true,
wantStatus: "pending",
wantProgress: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
if tt.wantFound {
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
require.Equal(t, tt.wantProgress, status.ProgressPct)
}
})
}
}
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
tests := []struct {
name string
body string
taskID string
wantFound bool
wantStatus string
wantURLs []string
wantErr string
}{
{
name: "正常完成的视频",
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
taskID: "task-1",
wantFound: true,
wantStatus: "completed",
wantURLs: []string{"https://example.com/video.mp4"},
},
{
name: "使用 url 字段回退",
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
taskID: "task-2",
wantFound: true,
wantStatus: "completed",
wantURLs: []string{"https://example.com/fallback.mp4"},
},
{
name: "内容违规",
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
taskID: "task-3",
wantFound: true,
wantStatus: "failed",
wantErr: "Content policy violation",
},
{
name: "内容违规 - markdown_reason_str 回退",
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
taskID: "task-4",
wantFound: true,
wantStatus: "failed",
wantErr: "Markdown reason",
},
{
name: "内容违规 - 无 reason 使用默认消息",
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
taskID: "task-5",
wantFound: true,
wantStatus: "failed",
wantErr: "Content violates guardrails",
},
{
name: "有 reason_str 但非 violation kind仍判定失败",
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
taskID: "task-6",
wantFound: true,
wantStatus: "failed",
wantErr: "Some error occurred",
},
{
name: "空 URL 判定为失败",
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
taskID: "task-7",
wantFound: true,
wantStatus: "failed",
wantErr: "Content violates guardrails",
},
{
name: "无匹配 task_id",
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
taskID: "task-1",
wantFound: false,
},
{
name: "空 items",
body: `{"items":[]}`,
taskID: "task-1",
wantFound: false,
},
{
name: "缺少 items 字段",
body: `{"other":"data"}`,
taskID: "task-1",
wantFound: false,
},
{
name: "多个 items 精准匹配",
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
taskID: "task-b",
wantFound: true,
wantStatus: "failed",
wantErr: "Bad content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
require.Equal(t, tt.wantFound, found, "found 不匹配")
if !tt.wantFound {
return
}
require.NotNil(t, status)
require.Equal(t, tt.taskID, status.ID)
require.Equal(t, tt.wantStatus, status.Status)
if tt.wantErr != "" {
require.Equal(t, tt.wantErr, status.ErrorMsg)
}
if tt.wantURLs != nil {
require.Equal(t, tt.wantURLs, status.URLs)
}
})
}
}