//go:build unit package service import ( "encoding/json" "fmt" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/stretchr/testify/require" ) func TestParseGatewayRequest(t *testing.T) { body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.True(t, parsed.Stream) require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID) require.True(t, parsed.HasSystem) require.NotNil(t, parsed.System) require.Len(t, parsed.Messages, 1) require.False(t, parsed.ThinkingEnabled) } func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.True(t, parsed.ThinkingEnabled) } func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 1, parsed.MaxTokens) } func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 require.True(t, parsed.HasSystem) require.Nil(t, parsed.System) } func TestParseGatewayRequest_InvalidModelType(t *testing.T) { body := []byte(`{"model":123}`) _, err := ParseGatewayRequest(body, "") require.Error(t, err) } func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { body := []byte(`{"stream":"true"}`) _, err := ParseGatewayRequest(body, "") require.Error(t, err) } // ============ Gemini 原生格式解析测试 ============ func TestParseGatewayRequest_GeminiContents(t *testing.T) { body := []byte(`{ "contents": [ {"role": "user", "parts": [{"text": "Hello"}]}, {"role": "model", "parts": [{"text": "Hi there"}]}, {"role": "user", "parts": [{"text": "How are you?"}]} ] }`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.Len(t, parsed.Messages, 3, "should parse contents as Messages") require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") require.Nil(t, parsed.System, "no systemInstruction means nil System") } func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { body := []byte(`{ "systemInstruction": { "parts": [{"text": "You are a helpful assistant."}] }, "contents": [ {"role": "user", "parts": [{"text": "Hello"}]} ] }`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") parts, ok := parsed.System.([]any) require.True(t, ok) require.Len(t, parts, 1) partMap, ok := parts[0].(map[string]any) require.True(t, ok) require.Equal(t, "You are a helpful assistant.", partMap["text"]) require.Len(t, parsed.Messages, 1) } func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { body := []byte(`{ "model": "gemini-2.5-pro", "contents": [{"role": "user", "parts": [{"text": "test"}]}] }`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.Equal(t, "gemini-2.5-pro", parsed.Model) require.Len(t, parsed.Messages, 1) } func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { // Gemini 格式下 system/messages 字段应被忽略 body := []byte(`{ "system": "should be ignored", "messages": [{"role": "user", "content": "ignored"}], "contents": [{"role": "user", "parts": [{"text": "real content"}]}] }`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") require.Nil(t, parsed.System, "no systemInstruction = nil System") require.Len(t, parsed.Messages, 1, "should use contents, not messages") } func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { body := []byte(`{"contents": []}`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.Empty(t, parsed.Messages) } func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { body := []byte(`{"model": "gemini-2.5-flash"}`) parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) require.NoError(t, err) require.Nil(t, parsed.Messages) require.Equal(t, "gemini-2.5-flash", parsed.Model) } func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { // Anthropic 格式下 contents/systemInstruction 字段应被忽略 body := []byte(`{ "system": "real system", "messages": [{"role": "user", "content": "real content"}], "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], "systemInstruction": {"parts": [{"text": "ignored"}]} }`) parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) require.NoError(t, err) require.True(t, parsed.HasSystem) require.Equal(t, "real system", parsed.System) require.Len(t, parsed.Messages, 1) msg, ok := parsed.Messages[0].(map[string]any) require.True(t, ok) require.Equal(t, "real content", msg["content"]) } func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return false } messages, ok := req["messages"].([]any) if !ok { return false } for _, msg := range messages { msgMap, ok := msg.(map[string]any) if !ok { continue } content, ok := msgMap["content"].([]any) if !ok { continue } for _, block := range content { blockMap, ok := block.(map[string]any) if !ok { continue } blockType, _ := blockMap["type"].(string) if blockType == "thinking" { return true } if blockType == "" { if _, hasThinking := blockMap["thinking"]; hasThinking { return true } } } } return false } tests := []struct { name string input string shouldFilter bool expectError bool }{ { name: "filters thinking blocks", input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`, shouldFilter: true, }, { name: "handles no thinking blocks", input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, shouldFilter: false, }, { name: "handles invalid JSON gracefully", input: `{invalid json`, shouldFilter: false, expectError: true, }, { name: "handles multiple messages with thinking blocks", input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`, shouldFilter: true, }, { name: "filters thinking blocks without type discriminator", input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`, shouldFilter: true, }, { name: "does not filter tool_use input fields named thinking", input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`, shouldFilter: false, }, { name: "handles empty messages array", input: `{"messages":[]}`, shouldFilter: false, }, { name: "handles missing messages field", input: `{"model":"claude-3"}`, shouldFilter: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := FilterThinkingBlocks([]byte(tt.input)) if tt.expectError { // For invalid JSON, should return original require.Equal(t, tt.input, string(result)) return } if tt.shouldFilter { require.False(t, containsThinkingBlock(result)) } else { // Ensure we don't rewrite JSON when no filtering is needed. require.Equal(t, tt.input, string(result)) } // Verify valid JSON returned (unless input was invalid) var parsed map[string]any err := json.Unmarshal(result, &parsed) require.NoError(t, err) }) } } func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) { input := []byte(`{ "model":"claude-3-5-sonnet-20241022", "thinking":{"type":"enabled","budget_tokens":1024}, "messages":[ {"role":"user","content":[{"type":"text","text":"Hi"}]}, {"role":"assistant","content":[ {"type":"thinking","thinking":"Let me think...","signature":"bad_sig"}, {"type":"text","text":"Answer"} ]} ] }`) out := FilterThinkingBlocksForRetry(input) var req map[string]any require.NoError(t, json.Unmarshal(out, &req)) _, hasThinking := req["thinking"] require.False(t, hasThinking) msgs, ok := req["messages"].([]any) require.True(t, ok) require.Len(t, msgs, 2) assistant, ok := msgs[1].(map[string]any) require.True(t, ok) content, ok := assistant["content"].([]any) require.True(t, ok) require.Len(t, content, 2) first, ok := content[0].(map[string]any) require.True(t, ok) require.Equal(t, "text", first["type"]) require.Equal(t, "Let me think...", first["text"]) } func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) { input := []byte(`{ "model":"claude-3-5-sonnet-20241022", "thinking":{"type":"enabled","budget_tokens":1024}, "messages":[ {"role":"user","content":[{"type":"text","text":"Hi"}]}, {"role":"assistant","content":[{"type":"text","text":"Prefill"}]} ] }`) out := FilterThinkingBlocksForRetry(input) var req map[string]any require.NoError(t, json.Unmarshal(out, &req)) _, hasThinking := req["thinking"] require.False(t, hasThinking) } func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) { input := []byte(`{ "thinking":{"type":"enabled","budget_tokens":1024}, "messages":[ {"role":"assistant","content":[ {"type":"redacted_thinking","data":"..."}, {"type":"text","text":"Visible"} ]} ] }`) out := FilterThinkingBlocksForRetry(input) var req map[string]any require.NoError(t, json.Unmarshal(out, &req)) _, hasThinking := req["thinking"] require.False(t, hasThinking) msgs, ok := req["messages"].([]any) require.True(t, ok) msg0, ok := msgs[0].(map[string]any) require.True(t, ok) content, ok := msg0["content"].([]any) require.True(t, ok) require.Len(t, content, 1) content0, ok := content[0].(map[string]any) require.True(t, ok) require.Equal(t, "text", content0["type"]) require.Equal(t, "Visible", content0["text"]) } func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) { input := []byte(`{ "thinking":{"type":"enabled"}, "messages":[ {"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]} ] }`) out := FilterThinkingBlocksForRetry(input) var req map[string]any require.NoError(t, json.Unmarshal(out, &req)) msgs, ok := req["messages"].([]any) require.True(t, ok) msg0, ok := msgs[0].(map[string]any) require.True(t, ok) content, ok := msg0["content"].([]any) require.True(t, ok) require.Len(t, content, 1) content0, ok := content[0].(map[string]any) require.True(t, ok) require.Equal(t, "text", content0["type"]) require.NotEmpty(t, content0["text"]) } func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { input := []byte(`{ "thinking":{"type":"enabled","budget_tokens":1024}, "messages":[ {"role":"assistant","content":[ {"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}, {"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false} ]} ] }`) out := FilterSignatureSensitiveBlocksForRetry(input) var req map[string]any require.NoError(t, json.Unmarshal(out, &req)) _, hasThinking := req["thinking"] require.False(t, hasThinking) msgs, ok := req["messages"].([]any) require.True(t, ok) msg0, ok := msgs[0].(map[string]any) require.True(t, ok) content, ok := msg0["content"].([]any) require.True(t, ok) require.Len(t, content, 2) content0, ok := content[0].(map[string]any) require.True(t, ok) content1, ok := content[1].(map[string]any) require.True(t, ok) require.Equal(t, "text", content0["type"]) require.Equal(t, "text", content1["type"]) require.Contains(t, content0["text"], "tool_use") require.Contains(t, content1["text"], "tool_result") } // ============ Group 7: ParseGatewayRequest 补充单元测试 ============ // Task 7.1 — 类型校验边界测试 func TestParseGatewayRequest_TypeValidation(t *testing.T) { tests := []struct { name string body string wantErr bool errSubstr string // 期望的错误信息子串(为空则不检查) }{ { name: "model 为 int", body: `{"model":123}`, wantErr: true, errSubstr: "invalid model field type", }, { name: "model 为 array", body: `{"model":[]}`, wantErr: true, errSubstr: "invalid model field type", }, { name: "model 为 bool", body: `{"model":true}`, wantErr: true, errSubstr: "invalid model field type", }, { name: "model 为 null — gjson Null 类型触发类型校验错误", body: `{"model":null}`, wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 errSubstr: "invalid model field type", }, { name: "stream 为 string", body: `{"stream":"true"}`, wantErr: true, errSubstr: "invalid stream field type", }, { name: "stream 为 int", body: `{"stream":1}`, wantErr: true, errSubstr: "invalid stream field type", }, { name: "stream 为 null — gjson Null 类型触发类型校验错误", body: `{"stream":null}`, wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 errSubstr: "invalid stream field type", }, { name: "model 为 object", body: `{"model":{}}`, wantErr: true, errSubstr: "invalid model field type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := ParseGatewayRequest([]byte(tt.body), "") if tt.wantErr { require.Error(t, err) if tt.errSubstr != "" { require.Contains(t, err.Error(), tt.errSubstr) } } else { require.NoError(t, err) } }) } } // Task 7.2 — 可选字段缺失测试 func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) { tests := []struct { name string body string wantModel string wantStream bool wantMetadataUID string wantHasSystem bool wantThinking bool wantMaxTokens int wantMessagesNil bool wantMessagesLen int }{ { name: "完全空 JSON — 所有字段零值", body: `{}`, wantModel: "", wantStream: false, wantMetadataUID: "", wantHasSystem: false, wantThinking: false, wantMaxTokens: 0, wantMessagesNil: true, }, { name: "metadata 无 user_id", body: `{"model":"test"}`, wantModel: "test", wantMetadataUID: "", wantHasSystem: false, wantThinking: false, }, { name: "thinking 非 enabled(type=disabled)", body: `{"model":"test","thinking":{"type":"disabled"}}`, wantModel: "test", wantThinking: false, }, { name: "thinking 字段缺失", body: `{"model":"test"}`, wantModel: "test", wantThinking: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { parsed, err := ParseGatewayRequest([]byte(tt.body), "") require.NoError(t, err) require.Equal(t, tt.wantModel, parsed.Model) require.Equal(t, tt.wantStream, parsed.Stream) require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID) require.Equal(t, tt.wantHasSystem, parsed.HasSystem) require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled) require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) if tt.wantMessagesNil { require.Nil(t, parsed.Messages) } if tt.wantMessagesLen > 0 { require.Len(t, parsed.Messages, tt.wantMessagesLen) } }) } } // Task 7.3 — Gemini 协议分支测试 // 已有测试覆盖: // - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents // - TestParseGatewayRequest_GeminiNoContents: 缺失 contents // - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction) // 因此跳过。 // Task 7.4 — max_tokens 边界测试 func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { tests := []struct { name string body string wantMaxTokens int wantErr bool }{ { name: "正常整数", body: `{"max_tokens":1024}`, wantMaxTokens: 1024, }, { name: "浮点数(非整数)被忽略", body: `{"max_tokens":10.5}`, wantMaxTokens: 0, }, { name: "负整数可以通过", body: `{"max_tokens":-1}`, wantMaxTokens: -1, }, { name: "超大值不 panic", body: `{"max_tokens":9999999999999999}`, wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16 }, { name: "null 值被忽略", body: `{"max_tokens":null}`, wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过 }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { parsed, err := ParseGatewayRequest([]byte(tt.body), "") if tt.wantErr { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) }) } } // ============ Task 7.5: Benchmark 测试 ============ // parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。 // 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。 func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) { parsed := &ParsedRequest{ Body: body, } var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return nil, err } // model if raw, ok := req["model"]; ok { s, ok := raw.(string) if !ok { return nil, fmt.Errorf("invalid model field type") } parsed.Model = s } // stream if raw, ok := req["stream"]; ok { b, ok := raw.(bool) if !ok { return nil, fmt.Errorf("invalid stream field type") } parsed.Stream = b } // metadata.user_id if meta, ok := req["metadata"].(map[string]any); ok { if uid, ok := meta["user_id"].(string); ok { parsed.MetadataUserID = uid } } // thinking.type if thinking, ok := req["thinking"].(map[string]any); ok { if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { parsed.ThinkingEnabled = true } } // max_tokens if raw, ok := req["max_tokens"]; ok { if n, ok := parseIntegralNumber(raw); ok { parsed.MaxTokens = n } } // system / messages(按协议分支) switch protocol { case domain.PlatformGemini: if sysInst, ok := req["systemInstruction"].(map[string]any); ok { if parts, ok := sysInst["parts"].([]any); ok { parsed.System = parts } } if contents, ok := req["contents"].([]any); ok { parsed.Messages = contents } default: if system, ok := req["system"]; ok { parsed.HasSystem = true parsed.System = system } if messages, ok := req["messages"].([]any); ok { parsed.Messages = messages } } return parsed, nil } // buildSmallJSON 构建 ~500B 的小型测试 JSON func buildSmallJSON() []byte { return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`) } // buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages) func buildLargeJSON() []byte { var b strings.Builder b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`) msgCount := 200 for i := 0; i < msgCount; i++ { if i > 0 { b.WriteByte(',') } if i%2 == 0 { b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i)) } else { b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i)) } } b.WriteString(`]}`) return []byte(b.String()) } func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) { data := buildSmallJSON() b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = parseGatewayRequestOld(data, "") } } func BenchmarkParseGatewayRequest_New_Small(b *testing.B) { data := buildSmallJSON() b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = ParseGatewayRequest(data, "") } } func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) { data := buildLargeJSON() b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = parseGatewayRequestOld(data, "") } } func BenchmarkParseGatewayRequest_New_Large(b *testing.B) { data := buildLargeJSON() b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = ParseGatewayRequest(data, "") } }