将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入: - unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x - unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取 - ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取 - Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量 - Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配 - 新增约 100 个单元测试用例,所有改动函数覆盖率 >85% Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
511 lines
13 KiB
Go
511 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
tools any
|
||
expectedLen int
|
||
description string
|
||
}{
|
||
{
|
||
name: "Standard tools",
|
||
tools: []any{
|
||
map[string]any{
|
||
"name": "get_weather",
|
||
"description": "Get weather info",
|
||
"input_schema": map[string]any{"type": "object"},
|
||
},
|
||
},
|
||
expectedLen: 1,
|
||
description: "标准工具格式应该正常转换",
|
||
},
|
||
{
|
||
name: "Custom type tool (MCP format)",
|
||
tools: []any{
|
||
map[string]any{
|
||
"type": "custom",
|
||
"name": "mcp_tool",
|
||
"custom": map[string]any{
|
||
"description": "MCP tool description",
|
||
"input_schema": map[string]any{"type": "object"},
|
||
},
|
||
},
|
||
},
|
||
expectedLen: 1,
|
||
description: "Custom类型工具应该从custom字段读取",
|
||
},
|
||
{
|
||
name: "Mixed standard and custom tools",
|
||
tools: []any{
|
||
map[string]any{
|
||
"name": "standard_tool",
|
||
"description": "Standard",
|
||
"input_schema": map[string]any{"type": "object"},
|
||
},
|
||
map[string]any{
|
||
"type": "custom",
|
||
"name": "custom_tool",
|
||
"custom": map[string]any{
|
||
"description": "Custom",
|
||
"input_schema": map[string]any{"type": "object"},
|
||
},
|
||
},
|
||
},
|
||
expectedLen: 1,
|
||
description: "混合工具应该都能正确转换",
|
||
},
|
||
{
|
||
name: "Custom tool without custom field",
|
||
tools: []any{
|
||
map[string]any{
|
||
"type": "custom",
|
||
"name": "invalid_custom",
|
||
// 缺少 custom 字段
|
||
},
|
||
},
|
||
expectedLen: 0, // 应该被跳过
|
||
description: "缺少custom字段的custom工具应该被跳过",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := convertClaudeToolsToGeminiTools(tt.tools)
|
||
|
||
if tt.expectedLen == 0 {
|
||
if result != nil {
|
||
t.Errorf("%s: expected nil result, got %v", tt.description, result)
|
||
}
|
||
return
|
||
}
|
||
|
||
if result == nil {
|
||
t.Fatalf("%s: expected non-nil result", tt.description)
|
||
}
|
||
|
||
if len(result) != 1 {
|
||
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
|
||
return
|
||
}
|
||
|
||
toolDecl, ok := result[0].(map[string]any)
|
||
if !ok {
|
||
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
|
||
}
|
||
|
||
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
|
||
if !ok {
|
||
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
|
||
}
|
||
|
||
toolsArr, _ := tt.tools.([]any)
|
||
expectedFuncCount := 0
|
||
for _, tool := range toolsArr {
|
||
toolMap, _ := tool.(map[string]any)
|
||
if toolMap["name"] != "" {
|
||
// 检查是否为有效的custom工具
|
||
if toolMap["type"] == "custom" {
|
||
if toolMap["custom"] != nil {
|
||
expectedFuncCount++
|
||
}
|
||
} else {
|
||
expectedFuncCount++
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(funcDecls) != expectedFuncCount {
|
||
t.Errorf("%s: expected %d function declarations, got %d",
|
||
tt.description, expectedFuncCount, len(funcDecls))
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||
claudeReq := map[string]any{
|
||
"model": "claude-haiku-4-5-20251001",
|
||
"max_tokens": 10,
|
||
"messages": []any{
|
||
map[string]any{
|
||
"role": "user",
|
||
"content": []any{
|
||
map[string]any{"type": "text", "text": "hi"},
|
||
},
|
||
},
|
||
map[string]any{
|
||
"role": "assistant",
|
||
"content": []any{
|
||
map[string]any{"type": "text", "text": "ok"},
|
||
map[string]any{
|
||
"type": "tool_use",
|
||
"id": "toolu_123",
|
||
"name": "default_api:write_file",
|
||
"input": map[string]any{"path": "a.txt", "content": "x"},
|
||
// no signature on purpose
|
||
},
|
||
},
|
||
},
|
||
},
|
||
"tools": []any{
|
||
map[string]any{
|
||
"name": "default_api:write_file",
|
||
"description": "write file",
|
||
"input_schema": map[string]any{
|
||
"type": "object",
|
||
"properties": map[string]any{"path": map[string]any{"type": "string"}},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
b, _ := json.Marshal(claudeReq)
|
||
|
||
out, err := convertClaudeMessagesToGeminiGenerateContent(b)
|
||
if err != nil {
|
||
t.Fatalf("convert failed: %v", err)
|
||
}
|
||
s := string(out)
|
||
if !strings.Contains(s, "\"functionCall\"") {
|
||
t.Fatalf("expected functionCall in output, got: %s", s)
|
||
}
|
||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||
}
|
||
}
|
||
|
||
func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) {
|
||
geminiReq := map[string]any{
|
||
"contents": []any{
|
||
map[string]any{
|
||
"role": "user",
|
||
"parts": []any{
|
||
map[string]any{
|
||
"functionCall": map[string]any{
|
||
"name": "default_api:write_file",
|
||
"args": map[string]any{"path": "a.txt"},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
}
|
||
b, _ := json.Marshal(geminiReq)
|
||
out := ensureGeminiFunctionCallThoughtSignatures(b)
|
||
s := string(out)
|
||
if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") {
|
||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||
}
|
||
}
|
||
|
||
// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景
|
||
// 关键区别:只有 response 为 JSON 对象/数组时才解包
|
||
func TestUnwrapGeminiResponse(t *testing.T) {
|
||
// 构造 >50KB 的大型 JSON 对象
|
||
largePadding := strings.Repeat("x", 50*1024)
|
||
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
|
||
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
|
||
|
||
tests := []struct {
|
||
name string
|
||
input []byte
|
||
expected string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "正常 response 包装(JSON 对象)",
|
||
input: []byte(`{"response":{"key":"val"}}`),
|
||
expected: `{"key":"val"}`,
|
||
},
|
||
{
|
||
name: "无包装直接返回",
|
||
input: []byte(`{"key":"val"}`),
|
||
expected: `{"key":"val"}`,
|
||
},
|
||
{
|
||
name: "空 JSON",
|
||
input: []byte(`{}`),
|
||
expected: `{}`,
|
||
},
|
||
{
|
||
name: "null response 返回原始 body",
|
||
input: []byte(`{"response":null}`),
|
||
expected: `{"response":null}`,
|
||
},
|
||
{
|
||
name: "非法 JSON 返回原始 body",
|
||
input: []byte(`not json`),
|
||
expected: `not json`,
|
||
},
|
||
{
|
||
name: "response 为基础类型 string 返回原始 body",
|
||
input: []byte(`{"response":"hello"}`),
|
||
expected: `{"response":"hello"}`,
|
||
},
|
||
{
|
||
name: "嵌套 response 只解一层",
|
||
input: []byte(`{"response":{"response":{"inner":true}}}`),
|
||
expected: `{"response":{"inner":true}}`,
|
||
},
|
||
{
|
||
name: "大型 JSON >50KB",
|
||
input: largeInput,
|
||
expected: largeExpected,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, err := unwrapGeminiResponse(tt.input)
|
||
if tt.wantErr {
|
||
require.Error(t, err)
|
||
return
|
||
}
|
||
require.NoError(t, err)
|
||
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
|
||
})
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Task 8.1 — extractGeminiUsage 测试
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestExtractGeminiUsage(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
wantNil bool
|
||
wantUsage *ClaudeUsage
|
||
}{
|
||
{
|
||
name: "完整 usageMetadata",
|
||
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`,
|
||
wantNil: false,
|
||
wantUsage: &ClaudeUsage{
|
||
InputTokens: 80,
|
||
OutputTokens: 50,
|
||
CacheReadInputTokens: 20,
|
||
},
|
||
},
|
||
{
|
||
name: "缺失 cachedContentTokenCount",
|
||
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`,
|
||
wantNil: false,
|
||
wantUsage: &ClaudeUsage{
|
||
InputTokens: 100,
|
||
OutputTokens: 50,
|
||
CacheReadInputTokens: 0,
|
||
},
|
||
},
|
||
{
|
||
name: "无 usageMetadata",
|
||
input: `{"candidates":[]}`,
|
||
wantNil: true,
|
||
},
|
||
{
|
||
// gjson 对 null 返回 Exists()=true,因此函数不会返回 nil,
|
||
// 而是返回全零的 ClaudeUsage。
|
||
name: "null usageMetadata — gjson Exists 为 true",
|
||
input: `{"usageMetadata":null}`,
|
||
wantNil: false,
|
||
wantUsage: &ClaudeUsage{
|
||
InputTokens: 0,
|
||
OutputTokens: 0,
|
||
CacheReadInputTokens: 0,
|
||
},
|
||
},
|
||
{
|
||
name: "零值字段",
|
||
input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`,
|
||
wantNil: false,
|
||
wantUsage: &ClaudeUsage{
|
||
InputTokens: 0,
|
||
OutputTokens: 0,
|
||
CacheReadInputTokens: 0,
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := extractGeminiUsage([]byte(tt.input))
|
||
if tt.wantNil {
|
||
if got != nil {
|
||
t.Fatalf("期望返回 nil,实际返回 %+v", got)
|
||
}
|
||
return
|
||
}
|
||
if got == nil {
|
||
t.Fatalf("期望返回非 nil,实际返回 nil")
|
||
}
|
||
if got.InputTokens != tt.wantUsage.InputTokens {
|
||
t.Errorf("InputTokens: 期望 %d,实际 %d", tt.wantUsage.InputTokens, got.InputTokens)
|
||
}
|
||
if got.OutputTokens != tt.wantUsage.OutputTokens {
|
||
t.Errorf("OutputTokens: 期望 %d,实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens)
|
||
}
|
||
if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens {
|
||
t.Errorf("CacheReadInputTokens: 期望 %d,实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Task 8.2 — estimateGeminiCountTokens 测试
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestEstimateGeminiCountTokens(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
wantGt0 bool // 期望结果 > 0
|
||
wantExact *int // 如果非 nil,期望精确匹配
|
||
}{
|
||
{
|
||
name: "含 systemInstruction 和 contents",
|
||
input: `{
|
||
"systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]},
|
||
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
|
||
}`,
|
||
wantGt0: true,
|
||
},
|
||
{
|
||
name: "仅 contents,无 systemInstruction",
|
||
input: `{
|
||
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
|
||
}`,
|
||
wantGt0: true,
|
||
},
|
||
{
|
||
name: "空 parts",
|
||
input: `{"contents":[{"parts":[]}]}`,
|
||
wantGt0: false,
|
||
wantExact: intPtr(0),
|
||
},
|
||
{
|
||
name: "非文本 parts(inlineData)",
|
||
input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`,
|
||
wantGt0: false,
|
||
wantExact: intPtr(0),
|
||
},
|
||
{
|
||
name: "空白文本",
|
||
input: `{"contents":[{"parts":[{"text":" "}]}]}`,
|
||
wantGt0: false,
|
||
wantExact: intPtr(0),
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := estimateGeminiCountTokens([]byte(tt.input))
|
||
if tt.wantExact != nil {
|
||
if got != *tt.wantExact {
|
||
t.Errorf("期望精确值 %d,实际 %d", *tt.wantExact, got)
|
||
}
|
||
return
|
||
}
|
||
if tt.wantGt0 && got <= 0 {
|
||
t.Errorf("期望返回 > 0,实际 %d", got)
|
||
}
|
||
if !tt.wantGt0 && got != 0 {
|
||
t.Errorf("期望返回 0,实际 %d", got)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Task 8.3 — ParseGeminiRateLimitResetTime 测试
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestParseGeminiRateLimitResetTime(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
wantNil bool
|
||
approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒
|
||
}{
|
||
{
|
||
name: "正常 quotaResetDelay",
|
||
input: `{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`,
|
||
wantNil: false,
|
||
approxDelta: 13, // 向上取整 12.345 -> 13
|
||
},
|
||
{
|
||
name: "daily quota",
|
||
input: `{"error":{"message":"quota per day exceeded"}}`,
|
||
wantNil: false,
|
||
approxDelta: -1, // 不检查精确 delta,仅检查非 nil
|
||
},
|
||
{
|
||
name: "无 details 且无 regex 匹配",
|
||
input: `{"error":{"message":"rate limit"}}`,
|
||
wantNil: true,
|
||
},
|
||
{
|
||
name: "regex 回退匹配",
|
||
input: `Please retry in 30s`,
|
||
wantNil: false,
|
||
approxDelta: 30,
|
||
},
|
||
{
|
||
name: "完全无匹配",
|
||
input: `{"error":{"code":429}}`,
|
||
wantNil: true,
|
||
},
|
||
{
|
||
name: "非法 JSON 但 regex 回退仍工作",
|
||
input: `not json but Please retry in 10s`,
|
||
wantNil: false,
|
||
approxDelta: 10,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
now := time.Now().Unix()
|
||
got := ParseGeminiRateLimitResetTime([]byte(tt.input))
|
||
|
||
if tt.wantNil {
|
||
if got != nil {
|
||
t.Fatalf("期望返回 nil,实际返回 %d", *got)
|
||
}
|
||
return
|
||
}
|
||
|
||
if got == nil {
|
||
t.Fatalf("期望返回非 nil,实际返回 nil")
|
||
}
|
||
|
||
// approxDelta == -1 表示只检查非 nil,不检查具体值(如 daily quota 场景)
|
||
if tt.approxDelta == -1 {
|
||
// 仅验证返回的时间戳在合理范围内(未来的某个时间)
|
||
if *got < now {
|
||
t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got)
|
||
}
|
||
return
|
||
}
|
||
|
||
// 使用 +/-2 秒容差进行范围检查
|
||
delta := *got - now
|
||
if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 {
|
||
t.Errorf("期望 delta 约为 %d 秒(+/-2),实际 delta = %d 秒(返回值=%d, now=%d)",
|
||
tt.approxDelta, delta, *got, now)
|
||
}
|
||
})
|
||
}
|
||
}
|