Files
sub2api/backend/internal/service/gateway_tool_rewrite_test.go
keh4l 6e12578bc5 feat(gateway): port Parrot tool-name obfuscation + message cache breakpoints
Implements the remaining three parity items with Parrot cc_mimicry:

  D) Tool-name obfuscation
     - Dynamic mapping when tools.length > 5 (matches Parrot threshold).
       Fake names follow {prefix}{name[:3]}{i:02d} (e.g. 'manage_bas00').
       Go port of random.Random(hash(tuple(names))) uses fnv64a seed +
       math/rand; byte-exact reproduction is impossible (Python hash vs
       Go hash), but the two invariants that matter are preserved:
         * same input tool_names yield identical mapping (cache hit)
         * prefix pool is shuffled (names look distributed)
     - Static prefix map (sessions_ -> cc_sess_, session_ -> cc_ses_)
       applied as fallback, matching Parrot TOOL_NAME_REWRITES verbatim.
     - Server tools (web_search_20250305, computer_*, etc.) are NOT
       renamed; only type=='function' and type=='custom' tools are.
     - tool_choice.name is rewritten in sync (only when type=='tool').
     - Response side: bytes-level replace on every SSE chunk / JSON
       body at 6 injection points (standard stream/non-stream,
       passthrough stream/non-stream, chat_completions stream +
       non-stream, responses stream + non-stream). Reverse mapping
       applied longest-fake-name-first to prevent substring conflicts
       (parity with Parrot _restore_tool_names_in_chunk).
     - tool_choice is no longer unconditionally deleted in
       normalizeClaudeOAuthRequestBody — Parrot passes it through.

  E) tools[-1] cache_control breakpoint
     - Injected as {type:ephemeral, ttl:<DefaultCacheControlTTL>} when
       the last tool has no cache_control. Client-provided ttl is
       passed through unchanged (repo-wide policy).

  F) messages cache_control strategy
     - stripMessageCacheControl removes every client-provided
       messages[*].content[*].cache_control (multi-turn stability).
     - addMessageCacheBreakpoints then injects two stable breakpoints:
       (1) last message, and (2) second-to-last user turn when
       messages.length >= 4.
     - Combined with the system block breakpoint and tools[-1]
       breakpoint, this gives exactly the 4 breakpoints Anthropic
       allows per request.

Non-trivial implementation details to be aware of when rebasing:

  * Two new files, no upstream collision:
      gateway_tool_rewrite.go       (D + E algorithms)
      gateway_messages_cache.go     (F strip + breakpoints)
  * Two new feature calls bolted onto the tail of
    applyClaudeCodeOAuthMimicryToBody in gateway_service.go — rebase
    conflicts will be ~10 lines maximum.
  * Response-side injection points all wrap their existing write with
    reverseToolNamesIfPresent(c, ...), preserving original behavior
    when no mapping is stored (static prefix rollback still runs).
  * Non-stream chat/responses switched from c.JSON to
    json.Marshal + c.Data so bytes-level replace is possible.
  * Retry bodies (FilterThinkingBlocksForRetry,
    FilterSignatureSensitiveBlocksForRetry, RectifyThinkingBudget)
    only prune blocks — they preserve the already-obfuscated tool
    names, so no extra mapping re-application is needed.

Manual QA: end-to-end scenario verified with 6 tools (above threshold)
and tool_choice.type=='tool'. The obfuscation + restore roundtrip was
confirmed in the test logs; the temporary test file was then removed.

Tests (16 new):
  - buildDynamicToolMap stability + below-threshold guard
  - sanitizeToolName precedence (dynamic > static)
  - restoreToolNamesInBytes longest-first + static rollback
  - applyToolNameRewriteToBody skips server tools + syncs tool_choice
  - applyToolsLastCacheBreakpoint defaults to 5m + passes client ttl
  - stripMessageCacheControl + addMessageCacheBreakpoints in the
    1/4/string-content cases + second-to-last user turn selection
  - buildToolNameRewriteFromBody ReverseOrdered is desc-by-fake-length
  - fake name shape follows Parrot {prefix}{head3}{i:02d}
2026-04-24 23:16:32 +08:00

186 lines
8.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
// Parrot 行为tools 数量 ≤ 5 时不做动态映射。
names := []string{"bash", "edit", "read", "write", "search"}
require.Nil(t, buildDynamicToolMap(names))
}
// TestBuildDynamicToolMap_AboveThresholdIsStable verifies that two calls with
// the same six tool names return equal, non-nil maps that contain every input
// name and map each one to a different string.
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
	// Parrot invariant: the same set of tool_names maps stably within one
	// process, which is what keeps the cache hitting.
	toolNames := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}

	first := buildDynamicToolMap(toolNames)
	second := buildDynamicToolMap(toolNames)

	require.NotNil(t, first)
	require.Equal(t, first, second, "same input tool_names must yield identical mapping")
	require.Len(t, first, 6)

	for i := range toolNames {
		original := toolNames[i]
		require.Contains(t, first, original)
		// The fake name must differ from the real one.
		require.NotEqual(t, original, first[original])
	}
}
// TestSanitizeToolName_StaticPrefix verifies the output of sanitizeToolName
// when no dynamic map is supplied: the "sessions_" and "session_" prefixes are
// rewritten, and an unrelated name is returned unchanged.
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: "sessions_list", want: "cc_sess_list"},
		{input: "session_get", want: "cc_ses_get"},
		{input: "bash", want: "bash"},
	}

	for _, tc := range cases {
		require.Equal(t, tc.want, sanitizeToolName(tc.input, nil))
	}
}
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
dyn := map[string]string{"sessions_list": "analyze_ses00"}
got := sanitizeToolName("sessions_list", dyn)
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
}
// TestRestoreToolNamesInBytes_LongestFirst verifies that restoring fake names
// gives the right result when one fake name is a substring of another.
//
// When the fake name "abc_12" is a substring of a longer fake name (extremely
// rare in practice, but the algorithm must defend against it), the longer one
// has to be replaced first. The mapping here is built explicitly to exercise
// that ordering invariant.
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
	// Hand-built ReverseOrdered: longest fake name first.
	ordered := [][2]string{
		{"abc_12_ext", "bar"},
		{"abc_12", "foo"},
	}
	rewrite := &ToolNameRewrite{
		Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
		Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
	}
	rewrite.ReverseOrdered = ordered

	input := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
	got := restoreToolNamesInBytes(input, rewrite)

	require.Equal(t, `{"tool":"bar","other":"foo"}`, string(got))
}
// TestRestoreToolNamesInBytes_StaticPrefixRollback verifies that, with a nil
// rewrite, "cc_ses_" in the payload is turned back into "session_" while an
// already-real name such as "sessions_list" is left as is.
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
	payload := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)

	restored := restoreToolNamesInBytes(payload, nil)

	require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, string(restored))
}
// TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice builds a rewrite
// from a request body with two plain tools and one server tool, applies it,
// and checks that only the plain tools and tool_choice.name are renamed.
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
	request := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)

	rewrite := buildToolNameRewriteFromBody(request)
	require.NotNil(t, rewrite)
	for _, renamed := range []string{"sessions_list", "session_get"} {
		require.Contains(t, rewrite.Forward, renamed)
	}
	// web_search is a server tool, so it must not appear in the mapping.
	require.NotContains(t, rewrite.Forward, "web_search")

	rewritten := applyToolNameRewriteToBody(request, rewrite)

	expectations := []struct {
		path string
		want string
	}{
		// tools[0] and tools[1] are renamed; tools[2] (server tool) is untouched.
		{path: "tools.0.name", want: "cc_sess_list"},
		{path: "tools.1.name", want: "cc_ses_get"},
		{path: "tools.2.name", want: "web_search"},
		// tool_choice.name follows the rename; its type stays as sent.
		{path: "tool_choice.name", want: "cc_sess_list"},
		{path: "tool_choice.type", want: "tool"},
	}
	for _, exp := range expectations {
		require.Equal(t, exp.want, gjson.GetBytes(rewritten, exp.path).String())
	}
}
// TestApplyToolsLastCacheBreakpoint_InjectsDefault verifies that the last of
// two tools without cache_control receives {type: ephemeral, ttl: 5m} and that
// the first tool gets no cache_control at all.
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
	request := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)

	result := applyToolsLastCacheBreakpoint(request)

	lastType := gjson.GetBytes(result, "tools.1.cache_control.type").String()
	require.Equal(t, "ephemeral", lastType)
	lastTTL := gjson.GetBytes(result, "tools.1.cache_control.ttl").String()
	require.Equal(t, "5m", lastTTL)

	// The first tool must stay untouched.
	firstHasBreakpoint := gjson.GetBytes(result, "tools.0.cache_control").Exists()
	require.False(t, firstHasBreakpoint)
}
// TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL verifies that a
// ttl already present on the last tool's cache_control is still "1h" after
// the call.
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
	request := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)

	result := applyToolsLastCacheBreakpoint(request)

	// The client-provided ttl must survive.
	ttl := gjson.GetBytes(result, "tools.0.cache_control.ttl").String()
	require.Equal(t, "1h", ttl)
}
// TestStripMessageCacheControl verifies that a cache_control field on a
// message content block is gone after stripMessageCacheControl runs.
func TestStripMessageCacheControl(t *testing.T) {
	request := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)

	stripped := stripMessageCacheControl(request)

	stillPresent := gjson.GetBytes(stripped, "messages.0.content.0.cache_control").Exists()
	require.False(t, stillPresent)
}
// TestAddMessageCacheBreakpoints_LastMessageOnly verifies that a single-message
// request gets an ephemeral, 5m cache_control on that message's content block.
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
	request := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)

	result := addMessageCacheBreakpoints(request)

	breakpointType := gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String()
	require.Equal(t, "ephemeral", breakpointType)
	breakpointTTL := gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()
	require.Equal(t, "5m", breakpointTTL)
}
// TestAddMessageCacheBreakpoints_SecondToLastUserTurn verifies breakpoint
// placement for a four-message user/assistant/user/assistant conversation.
//
// Parrot invariant: the second breakpoint is only added when there are at
// least 4 messages, and it lands on the second-to-last user turn.
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
	request := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"q1"}]},
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
{"role":"user","content":[{"type":"text","text":"q2"}]},
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
]}`)

	result := addMessageCacheBreakpoints(request)

	// The final assistant message receives a breakpoint.
	lastType := gjson.GetBytes(result, "messages.3.content.0.cache_control.type").String()
	require.Equal(t, "ephemeral", lastType)

	// Second-to-last user turn = index 0 (the only other user message).
	earlierUserType := gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String()
	require.Equal(t, "ephemeral", earlierUserType)

	// No other message receives a breakpoint.
	for _, path := range []string{
		"messages.1.content.0.cache_control",
		"messages.2.content.0.cache_control",
	} {
		require.False(t, gjson.GetBytes(result, path).Exists())
	}
}
// TestAddMessageCacheBreakpoints_StringContentPromoted verifies that a message
// whose content is a plain string comes back with an array holding one text
// block that carries the original text and a 5m cache_control.
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
	request := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)

	result := addMessageCacheBreakpoints(request)

	// The string content is promoted to an array.
	content := gjson.GetBytes(result, "messages.0.content")
	require.True(t, content.IsArray())

	blockType := gjson.GetBytes(result, "messages.0.content.0.type").String()
	require.Equal(t, "text", blockType)
	blockText := gjson.GetBytes(result, "messages.0.content.0.text").String()
	require.Equal(t, "hi", blockText)
	blockTTL := gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()
	require.Equal(t, "5m", blockTTL)
}
// TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc verifies that
// ReverseOrdered is non-empty and that each fake name is at least as long as
// the one after it.
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
	// Six tools exceed the threshold, so dynamic mapping is triggered.
	request := []byte(`{"tools":[
{"name":"t1","input_schema":{}},
{"name":"t2","input_schema":{}},
{"name":"t3","input_schema":{}},
{"name":"t4","input_schema":{}},
{"name":"t5","input_schema":{}},
{"name":"t6","input_schema":{}}
]}`)

	rewrite := buildToolNameRewriteFromBody(request)
	require.NotNil(t, rewrite)
	require.NotEmpty(t, rewrite.ReverseOrdered)

	// Compare each adjacent pair: the earlier fake name is never shorter.
	for next := 1; next < len(rewrite.ReverseOrdered); next++ {
		earlierLen := len(rewrite.ReverseOrdered[next-1][0])
		laterLen := len(rewrite.ReverseOrdered[next][0])
		require.GreaterOrEqual(t, earlierLen, laterLen,
			"ReverseOrdered must be sorted by fake-name length descending")
	}
}
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
data := []byte("plain text without any tool names")
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
}
// TestBuildDynamicToolMap_FakeNameShape ensures every generated fake name
// follows Parrot's "{prefix}{name[:3]}{i:02d}" format.
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
	toolNames := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}

	mapping := buildDynamicToolMap(toolNames)
	require.NotNil(t, mapping)

	for _, original := range toolNames {
		fake, found := mapping[original]
		require.True(t, found)

		// Shape: lowercase prefix ending in "_", then up to three characters
		// of the real name, then two decimal digits.
		require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)

		// The first three bytes of the real name (or the whole name when it
		// is shorter) must appear inside the fake name.
		headLen := len(original)
		if headLen > 3 {
			headLen = 3
		}
		head3 := original[:headLen]
		require.True(t, strings.Contains(fake, head3), "fake %q should contain head3 %q of %q", fake, head3, original)
	}
}