diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index 8248d26c..c531667e 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -313,7 +313,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } - c.JSON(http.StatusOK, ccResp) + // Marshal then bytes-replace so tool name mapping is reversed at byte level + // (parity with Parrot non-stream flow that marshals → restore → emit). + if respBytes, err := json.Marshal(ccResp); err == nil { + respBytes = reverseToolNamesIfPresent(c, respBytes) + c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes) + } else { + c.JSON(http.StatusOK, ccResp) + } return &ForwardResult{ RequestID: requestID, @@ -384,7 +391,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic( if err != nil { return false } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { + // Reverse tool name mapping: fake → real, per-chunk bytes.Replace. + // c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。 + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + if _, err := fmt.Fprint(c.Writer, out); err != nil { return true // client disconnected } return false diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index 1ecad7d3..647193d6 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -332,7 +332,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse( if s.responseHeaderFilter != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } - c.JSON(http.StatusOK, responsesResp) + if respBytes, err := json.Marshal(responsesResp); err == nil { + respBytes = reverseToolNamesIfPresent(c, respBytes) + c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes) + } else { + c.JSON(http.StatusOK, responsesResp) + } return &ForwardResult{ RequestID: requestID, @@ -420,7 +425,8 @@ func (s *GatewayService) handleResponsesStreamingResponse( ) continue } - if _, err := fmt.Fprint(c.Writer, sse); err != nil { + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + if _, err := fmt.Fprint(c.Writer, out); err != nil { logger.L().Info("forward_as_responses stream: client disconnected", zap.String("request_id", requestID), ) @@ -440,7 +446,8 @@ func (s *GatewayService) handleResponsesStreamingResponse( if err != nil { continue } - fmt.Fprint(c.Writer, sse) //nolint:errcheck + out := string(reverseToolNamesIfPresent(c, []byte(sse))) + fmt.Fprint(c.Writer, out) //nolint:errcheck } c.Writer.Flush() } diff --git a/backend/internal/service/gateway_messages_cache.go b/backend/internal/service/gateway_messages_cache.go new file mode 100644 index 00000000..cb5384ba --- /dev/null +++ b/backend/internal/service/gateway_messages_cache.go @@ -0,0 +1,141 @@ +package service + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。 +// 与 Parrot _strip_message_cache_control 语义一致。 +// +// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在 +// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条 +// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。 +// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。 +func stripMessageCacheControl(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return body + } + msgIdx := -1 + messages.ForEach(func(_, msg gjson.Result) bool { + msgIdx++ + content := msg.Get("content") + if !content.IsArray() { + return true + } + blockIdx := -1 + content.ForEach(func(_, block gjson.Result) bool { + blockIdx++ + if !block.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx) + if next, err := sjson.DeleteBytes(body, path); err == nil { + body = next + } + return true + }) + return true + }) + return body +} + +// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点: +// 1. 最后一条 message +// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message +// +// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点 +// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。 +// +// cache_control ttl 策略: +// - 若目标 block 已有 cache_control.ttl → 不覆盖 +// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL} +// +// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。 +func addMessageCacheBreakpoints(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() { + return body + } + arr := messages.Array() + if len(arr) == 0 { + return body + } + + body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1]) + + if len(arr) >= 4 { + userCount := 0 + for i := len(arr) - 1; i >= 0; i-- { + if arr[i].Get("role").String() != "user" { + continue + } + userCount++ + if userCount == 2 { + body = injectCacheControlOnLastContentBlock(body, i, &arr[i]) + break + } + } + } + + return body +} + +// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx] +// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组 +// (对齐 Parrot _inject_cache_on_msg 的行为)。 +// +// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。 +func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte { + content := msg.Get("content") + + if content.Type == gjson.String { + text := content.String() + blockRaw := fmt.Sprintf( + `[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`, + mustJSONString(text), claude.DefaultCacheControlTTL, + ) + if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil { + body = next + } + return body + } + + if !content.IsArray() { + return body + } + contentArr := content.Array() + if len(contentArr) == 0 { + return body + } + lastBlockIdx := len(contentArr) - 1 + lastBlock := contentArr[lastBlockIdx] + + if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" { + return body + } + + pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx) + existingCC := lastBlock.Get("cache_control") + if existingCC.Exists() { + if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil { + body = next + } + return body + } + raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL) + if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil { + body = next + } + return body +} + +// mustJSONString 把一个 Go string 序列化为合法 JSON string(含引号), +// 用于 sjson.SetRawBytes 场景下手工拼 JSON。 +func mustJSONString(s string) string { + return fmt.Sprintf("%q", s) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index c5c196a0..598de146 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1110,10 +1110,17 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu } } - if gjson.GetBytes(out, "tool_choice").Exists() { - if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { - out = next - modified = true + // tool_choice:与 Parrot 对齐,不再无条件删除。 + // - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由 + // applyToolNameRewriteToBody 同步映射为假名 + // - 其他形态(auto/any/none)原样透传 + // 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除 + if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 { + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } } } @@ -1214,6 +1221,25 @@ func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody( } body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts) + + // Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点 + // 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束: + // 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性) + // 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn) + // 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1] + // 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。 + body = stripMessageCacheControl(body) + body = addMessageCacheBreakpoints(body) + + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + if c != nil { + c.Set(toolNameRewriteKey, rw) + } + } else { + body = applyToolsLastCacheBreakpoint(body) + } + return body } @@ -5099,7 +5125,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( } if !clientDisconnected { - if _, err := io.WriteString(w, line); err != nil { + restored := string(reverseToolNamesIfPresent(c, []byte(line))) + if _, err := io.WriteString(w, restored); err != nil { clientDisconnected = true logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) } else if _, err := io.WriteString(w, "\n"); err != nil { @@ -5269,6 +5296,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( if contentType == "" { contentType = "application/json" } + body = reverseToolNamesIfPresent(c, body) c.Data(resp.StatusCode, contentType, body) return usage, nil } @@ -7013,7 +7041,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http for _, block := range outputBlocks { if !clientDisconnected { - if _, werr := fmt.Fprint(w, block); werr != nil { + restored := reverseToolNamesIfPresent(c, []byte(block)) + if _, werr := fmt.Fprint(w, string(restored)); werr != nil { clientDisconnected = true logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") break @@ -7355,6 +7384,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + body = reverseToolNamesIfPresent(c, body) + // 写入响应 c.Data(resp.StatusCode, contentType, body) diff --git a/backend/internal/service/gateway_tool_rewrite.go b/backend/internal/service/gateway_tool_rewrite.go new file mode 100644 index 00000000..c76cab62 --- /dev/null +++ b/backend/internal/service/gateway_tool_rewrite.go @@ -0,0 +1,313 @@ +package service + +import ( + "fmt" + "hash/fnv" + "math/rand" + "sort" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。 +// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。 +const toolNameRewriteKey = "claude_tool_name_rewrite" + +// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py +// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。 +var staticToolNameRewrites = map[string]string{ + "sessions_": "cc_sess_", + "session_": "cc_ses_", +} + +// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。 +// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。 +var fakeToolNamePrefixes = []string{ + "analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_", + "process_", "query_", "render_", "resolve_", "sync_", "update_", + "validate_", "convert_", "extract_", "manage_", "monitor_", "parse_", + "review_", "search_", "transform_", "handle_", "invoke_", "notify_", +} + +// dynamicToolMapThreshold 与 Parrot 一致:tools 数量超过 5 才启用动态映射。 +// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。 +const dynamicToolMapThreshold = 5 + +// ToolNameRewrite 是单次请求内的工具名混淆映射。 +// - Forward: real → fake,请求阶段在 body 上应用。 +// - Reverse: fake → real,响应阶段对每个 chunk 做 bytes.Replace 还原。 +// +// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的 +// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的 +// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。 +type ToolNameRewrite struct { + Forward map[string]string + Reverse map[string]string + ReverseOrdered [][2]string +} + +// buildDynamicToolMap 构造 tools 的动态假名映射。 +// +// 与 Parrot _build_dynamic_tool_map 语义等价: +// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil(不做动态映射,走静态 fallback) +// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中) +// +// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池; +// Go 无法字节级复刻 Python hash,但"稳定性"和"前缀池打散"两个不变量都保留: +// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。 +// 字节级不同不影响上游判定(Anthropic 不会验证我们的随机种子算法)。 +func buildDynamicToolMap(toolNames []string) map[string]string { + if len(toolNames) <= dynamicToolMapThreshold { + return nil + } + h := fnv.New64a() + for i, n := range toolNames { + if i > 0 { + _, _ = h.Write([]byte{0}) + } + _, _ = h.Write([]byte(n)) + } + rng := rand.New(rand.NewSource(int64(h.Sum64()))) + + available := make([]string, len(fakeToolNamePrefixes)) + copy(available, fakeToolNamePrefixes) + rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] }) + + mapping := make(map[string]string, len(toolNames)) + for i, name := range toolNames { + prefix := available[i%len(available)] + headLen := 3 + if len(name) < 3 { + headLen = len(name) + } + fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i) + mapping[name] = fake + } + return mapping +} + +// sanitizeToolName 把真名转成假名。 +// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。 +func sanitizeToolName(name string, dynamic map[string]string) string { + if dynamic != nil { + if fake, ok := dynamic[name]; ok { + return fake + } + } + for prefix, replacement := range staticToolNameRewrites { + if strings.HasPrefix(name, prefix) { + return replacement + name[len(prefix):] + } + } + return name +} + +// shouldMimicToolName 指示某个 tool 是否需要重命名。 +// server tool(type != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分, +// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。 +func shouldMimicToolName(toolType string) bool { + if toolType == "" || toolType == "function" || toolType == "custom" { + return true + } + return false +} + +// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name,构造 ToolNameRewrite +// 并返回它。若不需要混淆(tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。 +// +// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。 +func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return nil + } + + mimicableNames := make([]string, 0) + toolsArr := tools.Array() + for _, t := range toolsArr { + if !shouldMimicToolName(t.Get("type").String()) { + continue + } + name := t.Get("name").String() + if name == "" { + continue + } + mimicableNames = append(mimicableNames, name) + } + + dynamic := buildDynamicToolMap(mimicableNames) + + rw := &ToolNameRewrite{ + Forward: make(map[string]string), + Reverse: make(map[string]string), + } + for _, name := range mimicableNames { + fake := sanitizeToolName(name, dynamic) + if fake == name { + continue + } + rw.Forward[name] = fake + rw.Reverse[fake] = name + } + if len(rw.Forward) == 0 { + return nil + } + + rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse)) + for fake, real := range rw.Reverse { + rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real}) + } + sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool { + return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0]) + }) + + return rw +} + +// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上: +// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool) +// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐, +// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL) +// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool") +// +// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致; +// 响应侧 bytes.Replace 会连带还原它们。 +func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte { + if rw == nil || len(rw.Forward) == 0 { + body = applyToolsLastCacheBreakpoint(body) + return body + } + + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + idx := -1 + tools.ForEach(func(_, t gjson.Result) bool { + idx++ + if !shouldMimicToolName(t.Get("type").String()) { + return true + } + name := t.Get("name").String() + if name == "" { + return true + } + fake, ok := rw.Forward[name] + if !ok { + return true + } + if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil { + body = next + } + return true + }) + } + + if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" { + name := tc.Get("name").String() + if fake, ok := rw.Forward[name]; ok { + if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil { + body = next + } + } + } + + body = applyToolsLastCacheBreakpoint(body) + return body +} + +// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control +// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}` +// 行为,但 ttl 按本仓规则: +// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖 +// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL} +// +// 纯副作用函数,tools 不存在或为空数组时 no-op。 +func applyToolsLastCacheBreakpoint(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return body + } + arr := tools.Array() + if len(arr) == 0 { + return body + } + lastIdx := len(arr) - 1 + existingCC := arr[lastIdx].Get("cache_control") + + if existingCC.Exists() && existingCC.Get("ttl").String() != "" { + return body + } + + if existingCC.Exists() { + if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil { + body = next + } + return body + } + + raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL) + if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil { + body = next + } + return body +} + +// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。 +// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace,防止子串冲突 +// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。 +// 再做静态前缀还原(cc_sess_ → sessions_ / cc_ses_ → session_)。 +// +// rw 可为 nil;nil 时仍会做静态前缀还原。 +func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte { + if rw != nil { + for _, pair := range rw.ReverseOrdered { + fake, real := pair[0], pair[1] + if fake == "" || fake == real { + continue + } + data = replaceAllBytes(data, fake, real) + } + } + for prefix, replacement := range staticToolNameRewrites { + data = replaceAllBytes(data, replacement, prefix) + } + return data +} + +// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。 +func replaceAllBytes(data []byte, from, to string) []byte { + if len(data) == 0 || from == to || !strings.Contains(string(data), from) { + return data + } + return []byte(strings.ReplaceAll(string(data), from, to)) +} + +// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。 +// 找不到(c==nil 或 key 不存在或类型不对)时返回 nil;调用方必须能处理 nil。 +func toolNameRewriteFromContext(c interface { + Get(string) (any, bool) +}) *ToolNameRewrite { + if c == nil { + return nil + } + raw, ok := c.Get(toolNameRewriteKey) + if !ok || raw == nil { + return nil + } + rw, _ := raw.(*ToolNameRewrite) + return rw +} + +// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping +// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。 +func reverseToolNamesIfPresent(c interface { + Get(string) (any, bool) +}, chunk []byte) []byte { + rw := toolNameRewriteFromContext(c) + if rw == nil && len(staticToolNameRewrites) == 0 { + return chunk + } + return restoreToolNamesInBytes(chunk, rw) +} diff --git a/backend/internal/service/gateway_tool_rewrite_test.go b/backend/internal/service/gateway_tool_rewrite_test.go new file mode 100644 index 00000000..8f0e3939 --- /dev/null +++ b/backend/internal/service/gateway_tool_rewrite_test.go @@ -0,0 +1,185 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) { + // Parrot 行为:tools 数量 ≤ 5 时不做动态映射。 + names := []string{"bash", "edit", "read", "write", "search"} + require.Nil(t, buildDynamicToolMap(names)) +} + +func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) { + // Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。 + names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"} + a := buildDynamicToolMap(names) + b := buildDynamicToolMap(names) + require.NotNil(t, a) + require.Equal(t, a, b, "same input tool_names must yield identical mapping") + require.Len(t, a, 6) + for _, name := range names { + require.Contains(t, a, name) + require.NotEqual(t, name, a[name]) + } +} + +func TestSanitizeToolName_StaticPrefix(t *testing.T) { + require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil)) + require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil)) + require.Equal(t, "bash", sanitizeToolName("bash", nil)) +} + +func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) { + dyn := map[string]string{"sessions_list": "analyze_ses00"} + got := sanitizeToolName("sessions_list", dyn) + require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix") +} + +func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) { + // 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时, + // 长的必须先替换。本测试用显式构造的映射来验证排序不变量。 + rw := &ToolNameRewrite{ + Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"}, + Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"}, + } + // 手工构造 ReverseOrdered:长的在前 + rw.ReverseOrdered = [][2]string{ + {"abc_12_ext", "bar"}, + {"abc_12", "foo"}, + } + data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`) + restored := string(restoreToolNamesInBytes(data, rw)) + require.Equal(t, `{"tool":"bar","other":"foo"}`, restored) +} + +func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) { + data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`) + got := string(restoreToolNamesInBytes(data, nil)) + require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got) +} + +func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) { + body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`) + rw := buildToolNameRewriteFromBody(body) + require.NotNil(t, rw) + require.Contains(t, rw.Forward, "sessions_list") + require.Contains(t, rw.Forward, "session_get") + // web_search is a server tool, not rewritten + require.NotContains(t, rw.Forward, "web_search") + + out := applyToolNameRewriteToBody(body, rw) + + // tools[0].name and tools[1].name rewritten; tools[2].name untouched + require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String()) + require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String()) + require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String()) + + // tool_choice.name rewritten + require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String()) + require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String()) +} + +func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) { + body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`) + out := applyToolsLastCacheBreakpoint(body) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String()) + // First tool untouched + require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists()) +} + +func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) { + body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`) + out := applyToolsLastCacheBreakpoint(body) + // User-provided ttl must be preserved. + require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String()) +} + +func TestStripMessageCacheControl(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`) + out := stripMessageCacheControl(body) + require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists()) +} + +func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + out := addMessageCacheBreakpoints(body) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) { + // Parrot 不变量:messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。 + body := []byte(`{"messages":[ + {"role":"user","content":[{"type":"text","text":"q1"}]}, + {"role":"assistant","content":[{"type":"text","text":"a1"}]}, + {"role":"user","content":[{"type":"text","text":"q2"}]}, + {"role":"assistant","content":[{"type":"text","text":"a2"}]} + ]}`) + out := addMessageCacheBreakpoints(body) + // 最后一条 assistant 被打断点 + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String()) + // 倒数第二个 user turn = index 0(唯一另一个 user) + require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String()) + // 其他不打断点 + require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists()) + require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists()) +} + +func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`) + out := addMessageCacheBreakpoints(body) + // content 升级成数组 + require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray()) + require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String()) + require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String()) + require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) { + // 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列 + body := []byte(`{"tools":[ + {"name":"t1","input_schema":{}}, + {"name":"t2","input_schema":{}}, + {"name":"t3","input_schema":{}}, + {"name":"t4","input_schema":{}}, + {"name":"t5","input_schema":{}}, + {"name":"t6","input_schema":{}} + ]}`) + rw := buildToolNameRewriteFromBody(body) + require.NotNil(t, rw) + require.NotEmpty(t, rw.ReverseOrdered) + for i := 1; i < len(rw.ReverseOrdered); i++ { + require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]), + "ReverseOrdered must be sorted by fake-name length descending") + } +} + +func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) { + data := []byte("plain text without any tool names") + require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil))) +} + +// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}". +func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) { + names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"} + m := buildDynamicToolMap(names) + require.NotNil(t, m) + for _, name := range names { + fake, ok := m[name] + require.True(t, ok) + // fake = prefix + head3 + "%02d" + // ends with two decimal digits + require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake) + head := name + if len(head) > 3 { + head = head[:3] + } + require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name) + } +}