perf(backend): 使用 gjson/sjson 优化热路径 JSON 处理
将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入: - unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x - unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取 - ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取 - Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量 - Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配 - 新增约 100 个单元测试用例,所有改动函数覆盖率 >85% Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -416,3 +420,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
require.Contains(t, content0["text"], "tool_use")
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
|
||||
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
|
||||
|
||||
// Task 7.1 — 类型校验边界测试
|
||||
func TestParseGatewayRequest_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantErr bool
|
||||
errSubstr string // 期望的错误信息子串(为空则不检查)
|
||||
}{
|
||||
{
|
||||
name: "model 为 int",
|
||||
body: `{"model":123}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model field type",
|
||||
},
|
||||
{
|
||||
name: "model 为 array",
|
||||
body: `{"model":[]}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model field type",
|
||||
},
|
||||
{
|
||||
name: "model 为 bool",
|
||||
body: `{"model":true}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model field type",
|
||||
},
|
||||
{
|
||||
name: "model 为 null — gjson Null 类型触发类型校验错误",
|
||||
body: `{"model":null}`,
|
||||
wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误
|
||||
errSubstr: "invalid model field type",
|
||||
},
|
||||
{
|
||||
name: "stream 为 string",
|
||||
body: `{"stream":"true"}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid stream field type",
|
||||
},
|
||||
{
|
||||
name: "stream 为 int",
|
||||
body: `{"stream":1}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid stream field type",
|
||||
},
|
||||
{
|
||||
name: "stream 为 null — gjson Null 类型触发类型校验错误",
|
||||
body: `{"stream":null}`,
|
||||
wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误
|
||||
errSubstr: "invalid stream field type",
|
||||
},
|
||||
{
|
||||
name: "model 为 object",
|
||||
body: `{"model":{}}`,
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model field type",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.errSubstr != "" {
|
||||
require.Contains(t, err.Error(), tt.errSubstr)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Task 7.2 — 可选字段缺失测试
|
||||
func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantStream bool
|
||||
wantMetadataUID string
|
||||
wantHasSystem bool
|
||||
wantThinking bool
|
||||
wantMaxTokens int
|
||||
wantMessagesNil bool
|
||||
wantMessagesLen int
|
||||
}{
|
||||
{
|
||||
name: "完全空 JSON — 所有字段零值",
|
||||
body: `{}`,
|
||||
wantModel: "",
|
||||
wantStream: false,
|
||||
wantMetadataUID: "",
|
||||
wantHasSystem: false,
|
||||
wantThinking: false,
|
||||
wantMaxTokens: 0,
|
||||
wantMessagesNil: true,
|
||||
},
|
||||
{
|
||||
name: "metadata 无 user_id",
|
||||
body: `{"model":"test"}`,
|
||||
wantModel: "test",
|
||||
wantMetadataUID: "",
|
||||
wantHasSystem: false,
|
||||
wantThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking 非 enabled(type=disabled)",
|
||||
body: `{"model":"test","thinking":{"type":"disabled"}}`,
|
||||
wantModel: "test",
|
||||
wantThinking: false,
|
||||
},
|
||||
{
|
||||
name: "thinking 字段缺失",
|
||||
body: `{"model":"test"}`,
|
||||
wantModel: "test",
|
||||
wantThinking: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tt.wantModel, parsed.Model)
|
||||
require.Equal(t, tt.wantStream, parsed.Stream)
|
||||
require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID)
|
||||
require.Equal(t, tt.wantHasSystem, parsed.HasSystem)
|
||||
require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled)
|
||||
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
|
||||
|
||||
if tt.wantMessagesNil {
|
||||
require.Nil(t, parsed.Messages)
|
||||
}
|
||||
if tt.wantMessagesLen > 0 {
|
||||
require.Len(t, parsed.Messages, tt.wantMessagesLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Task 7.3 — Gemini 协议分支测试
|
||||
// 已有测试覆盖:
|
||||
// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents
|
||||
// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents
|
||||
// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction)
|
||||
// 因此跳过。
|
||||
|
||||
// Task 7.4 — max_tokens 边界测试
|
||||
func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantMaxTokens int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常整数",
|
||||
body: `{"max_tokens":1024}`,
|
||||
wantMaxTokens: 1024,
|
||||
},
|
||||
{
|
||||
name: "浮点数(非整数)被忽略",
|
||||
body: `{"max_tokens":10.5}`,
|
||||
wantMaxTokens: 0,
|
||||
},
|
||||
{
|
||||
name: "负整数可以通过",
|
||||
body: `{"max_tokens":-1}`,
|
||||
wantMaxTokens: -1,
|
||||
},
|
||||
{
|
||||
name: "超大值不 panic",
|
||||
body: `{"max_tokens":9999999999999999}`,
|
||||
wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16
|
||||
},
|
||||
{
|
||||
name: "null 值被忽略",
|
||||
body: `{"max_tokens":null}`,
|
||||
wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Task 7.5: Benchmark 测试 ============
|
||||
|
||||
// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。
|
||||
// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。
|
||||
func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) {
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// model
|
||||
if raw, ok := req["model"]; ok {
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = s
|
||||
}
|
||||
|
||||
// stream
|
||||
if raw, ok := req["stream"]; ok {
|
||||
b, ok := raw.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = b
|
||||
}
|
||||
|
||||
// metadata.user_id
|
||||
if meta, ok := req["metadata"].(map[string]any); ok {
|
||||
if uid, ok := meta["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = uid
|
||||
}
|
||||
}
|
||||
|
||||
// thinking.type
|
||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||
parsed.ThinkingEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// max_tokens
|
||||
if raw, ok := req["max_tokens"]; ok {
|
||||
if n, ok := parseIntegralNumber(raw); ok {
|
||||
parsed.MaxTokens = n
|
||||
}
|
||||
}
|
||||
|
||||
// system / messages(按协议分支)
|
||||
switch protocol {
|
||||
case domain.PlatformGemini:
|
||||
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||
if parts, ok := sysInst["parts"].([]any); ok {
|
||||
parsed.System = parts
|
||||
}
|
||||
}
|
||||
if contents, ok := req["contents"].([]any); ok {
|
||||
parsed.Messages = contents
|
||||
}
|
||||
default:
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// buildSmallJSON 构建 ~500B 的小型测试 JSON
|
||||
func buildSmallJSON() []byte {
|
||||
return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`)
|
||||
}
|
||||
|
||||
// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages)
|
||||
func buildLargeJSON() []byte {
|
||||
var b strings.Builder
|
||||
b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`)
|
||||
|
||||
msgCount := 200
|
||||
for i := 0; i < msgCount; i++ {
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
if i%2 == 0 {
|
||||
b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i))
|
||||
} else {
|
||||
b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString(`]}`)
|
||||
return []byte(b.String())
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) {
|
||||
data := buildSmallJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = parseGatewayRequestOld(data, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_New_Small(b *testing.B) {
|
||||
data := buildSmallJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ParseGatewayRequest(data, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
|
||||
data := buildLargeJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = parseGatewayRequestOld(data, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
||||
data := buildLargeJSON()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = ParseGatewayRequest(data, "")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user