Files
sub2api-ht/backend/internal/service/gateway_tool_rewrite_test.go
2026-05-11 17:27:04 +08:00

233 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
// Parrot 行为tools 数量 ≤ 5 时不做动态映射。
names := []string{"bash", "edit", "read", "write", "search"}
require.Nil(t, buildDynamicToolMap(names))
}
func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
// Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
a := buildDynamicToolMap(names)
b := buildDynamicToolMap(names)
require.NotNil(t, a)
require.Equal(t, a, b, "same input tool_names must yield identical mapping")
require.Len(t, a, 6)
for _, name := range names {
require.Contains(t, a, name)
require.NotEqual(t, name, a[name])
}
}
func TestSanitizeToolName_StaticPrefix(t *testing.T) {
require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
require.Equal(t, "bash", sanitizeToolName("bash", nil))
}
func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
dyn := map[string]string{"sessions_list": "analyze_ses00"}
got := sanitizeToolName("sessions_list", dyn)
require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
}
func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
// 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
// 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
rw := &ToolNameRewrite{
Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
}
// 手工构造 ReverseOrdered长的在前
rw.ReverseOrdered = [][2]string{
{"abc_12_ext", "bar"},
{"abc_12", "foo"},
}
data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
restored := string(restoreToolNamesInBytes(data, rw))
require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
}
func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
got := string(restoreToolNamesInBytes(data, nil))
require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
}
func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Contains(t, rw.Forward, "sessions_list")
require.Contains(t, rw.Forward, "session_get")
// web_search 是 server tool不参与工具名改写
require.NotContains(t, rw.Forward, "web_search")
out := applyToolNameRewriteToBody(body, rw)
// tools[0].name 和 tools[1].name 被改写tools[2].name 保持不变
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
// tool_choice.name 被同步改写
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
}
func TestApplyToolNameRewriteToBody_RenamesToolUseInMessages(t *testing.T) {
// sessions_list 通过静态前缀规则改写为 cc_sess_list
// web_search 是 server tooltype != ""),不参与工具名改写
// messages 中的 tool_use.name 必须同步改写,才能和 tools[] 保持一致
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"assistant","content":[{"type":"tool_use","id":"tu_01","name":"sessions_list","input":{}},{"type":"text","text":"thinking"}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu_01","content":"ok"}]}]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Equal(t, "cc_sess_list", rw.Forward["sessions_list"])
out := applyToolNameRewriteToBody(body, rw)
// tools[0].name 被改写
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
// tools[1].name 是 server tool保持不变
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.1.name").String())
// messages[1].content[0].name 是 tool_use必须同步改写以匹配 tools[]
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "messages.1.content.0.name").String())
// messages[1].content[1] 是 text保持不变
require.Equal(t, "thinking", gjson.GetBytes(out, "messages.1.content.1.text").String())
// messages[2].content[0] 是 tool_result不包含 name 字段,保持不变
require.Equal(t, "ok", gjson.GetBytes(out, "messages.2.content.0.content").String())
}
func TestApplyToolNameRewriteToBody_RenamesToolUseWithDynamicMapping(t *testing.T) {
body := []byte(`{"tools":[{"name":"alpha_search","input_schema":{}},{"name":"beta_lookup","input_schema":{}},{"name":"gamma_fetch","input_schema":{}},{"name":"delta_update","input_schema":{}},{"name":"epsilon_parse","input_schema":{}},{"name":"zeta_render","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"gamma_fetch"},"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"tu_dyn","name":"gamma_fetch","input":{}},{"type":"tool_use","id":"tu_srv","name":"web_search","input":{}},{"type":"text","text":"done"}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu_dyn","content":"ok"}]}]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Len(t, rw.Forward, 6)
fakeGamma := rw.Forward["gamma_fetch"]
require.NotEmpty(t, fakeGamma)
require.NotEqual(t, "gamma_fetch", fakeGamma)
require.NotContains(t, rw.Forward, "web_search")
out := applyToolNameRewriteToBody(body, rw)
// 动态映射会改写 tools[]、tool_choice 和历史 tool_use 中的同一个工具名
require.Equal(t, fakeGamma, gjson.GetBytes(out, "tools.2.name").String())
require.Equal(t, fakeGamma, gjson.GetBytes(out, "tool_choice.name").String())
require.Equal(t, fakeGamma, gjson.GetBytes(out, "messages.0.content.0.name").String())
// server tool 不参与动态映射,历史 tool_use 中同名引用也保持不变
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.6.name").String())
require.Equal(t, "web_search", gjson.GetBytes(out, "messages.0.content.1.name").String())
// tool_result 依靠 tool_use_id 关联,不需要 name 字段
require.Equal(t, "ok", gjson.GetBytes(out, "messages.1.content.0.content").String())
}
func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
out := applyToolsLastCacheBreakpoint(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
// First tool untouched
require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
}
func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
out := applyToolsLastCacheBreakpoint(body)
// User-provided ttl must be preserved.
require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
}
func TestStripMessageCacheControl(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
out := stripMessageCacheControl(body)
require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := addMessageCacheBreakpoints(body)
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
// Parrot 不变量messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
body := []byte(`{"messages":[
{"role":"user","content":[{"type":"text","text":"q1"}]},
{"role":"assistant","content":[{"type":"text","text":"a1"}]},
{"role":"user","content":[{"type":"text","text":"q2"}]},
{"role":"assistant","content":[{"type":"text","text":"a2"}]}
]}`)
out := addMessageCacheBreakpoints(body)
// 最后一条 assistant 被打断点
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
// 倒数第二个 user turn = index 0唯一另一个 user
require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
// 其他不打断点
require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
}
func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
out := addMessageCacheBreakpoints(body)
// content 升级成数组
require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
}
func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
// 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
body := []byte(`{"tools":[
{"name":"t1","input_schema":{}},
{"name":"t2","input_schema":{}},
{"name":"t3","input_schema":{}},
{"name":"t4","input_schema":{}},
{"name":"t5","input_schema":{}},
{"name":"t6","input_schema":{}}
]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.NotEmpty(t, rw.ReverseOrdered)
for i := 1; i < len(rw.ReverseOrdered); i++ {
require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
"ReverseOrdered must be sorted by fake-name length descending")
}
}
func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
data := []byte("plain text without any tool names")
require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
}
// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
m := buildDynamicToolMap(names)
require.NotNil(t, m)
for _, name := range names {
fake, ok := m[name]
require.True(t, ok)
// fake = prefix + head3 + "%02d"
// ends with two decimal digits
require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
head := name
if len(head) > 3 {
head = head[:3]
}
require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
}
}