fix(gateway): 移除 PR #316 引入的工具名转换逻辑

移除响应阶段的工具名/schema/description 转换逻辑,修复第三方工具调用时
工具名被错误转换的问题(如 Task → task)。

移除内容:
- 工具名相关正则变量(toolPrefixRe, toolNameBoundaryRe 等)
- openCodeToolOverrides 和 claudeToolNameOverrides 映射表
- 工具名转换函数(normalizeToolNameForClaude, normalizeToolNameForOpenCode 等)
- 响应体工具名替换函数(replaceToolNamesInText, replaceToolNamesInResponseBody 等)
- 参数名转换函数(normalizeParamNameForOpenCode, rewriteParamKeysInValue)
- 工具描述清理函数(sanitizeToolDescription)
- 输入 schema 转换函数(normalizeToolInputSchema)
- 模型 ID 正则替换函数(replaceModelIDInText)

保留内容:
- 系统提示词清理(sanitizeSystemText)
- Claude Code 指纹 headers 处理
- 模型 ID 映射(通过 JSON 对象操作)
This commit is contained in:
shaw
2026-02-06 16:09:58 +08:00
parent 4809fa4f19
commit d182ef0391
2 changed files with 22 additions and 541 deletions

View File

@@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}

View File

@@ -207,40 +207,6 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
"read": "Read",
"edit": "Edit",
"write": "Write",
"task": "Task",
"glob": "Glob",
"grep": "Grep",
"webfetch": "WebFetch",
"websearch": "WebSearch",
"todowrite": "TodoWrite",
"question": "AskUserQuestion",
}
openCodeToolOverrides = map[string]string{
"Bash": "bash",
"Read": "read",
"Edit": "edit",
"Write": "write",
"Task": "task",
"Glob": "glob",
"Grep": "grep",
"WebFetch": "webfetch",
"WebSearch": "websearch",
"TodoWrite": "todowrite",
"AskUserQuestion": "question",
}
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@@ -616,71 +582,6 @@ type claudeOAuthNormalizeOptions struct {
stripSystemCacheControl bool
}
func stripToolPrefix(value string) string {
if value == "" {
return value
}
return toolPrefixRe.ReplaceAllString(value, "")
}
func toSnakeCase(value string) string {
if value == "" {
return value
}
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
output = strings.Trim(output, "_")
return strings.ToLower(output)
}
func normalizeToolNameForClaude(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
// 只对已知的工具名进行映射,未知工具名保持原样
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
return stripped
}
if cache != nil && mapped != stripped {
cache[mapped] = stripped
}
return mapped
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
// 优先从请求时建立的映射中查找
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
// 已知工具名的硬编码映射
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
return stripped
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
if cache != nil {
if mapped, ok := cache[name]; ok {
return mapped
}
}
return name
}
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
@@ -699,55 +600,6 @@ func sanitizeSystemText(text string) string {
return text
}
func sanitizeToolDescription(description string) string {
if description == "" {
return description
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return description
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
schema, ok := inputSchema.(map[string]any)
if !ok {
return
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
return
}
newProperties := make(map[string]any, len(properties))
for key, value := range properties {
snakeKey := toSnakeCase(key)
newProperties[snakeKey] = value
if snakeKey != key && cache != nil {
cache[snakeKey] = key
}
}
schema["properties"] = newProperties
if required, ok := schema["required"].([]any); ok {
newRequired := make([]any, 0, len(required))
for _, item := range required {
name, ok := item.(string)
if !ok {
newRequired = append(newRequired, item)
continue
}
snakeName := toSnakeCase(name)
newRequired = append(newRequired, snakeName)
if snakeName != name && cache != nil {
cache[snakeName] = name
}
}
schema["required"] = newRequired
}
}
func stripCacheControlFromSystemBlocks(system any) bool {
blocks, ok := system.([]any)
if !ok {
@@ -768,24 +620,17 @@ func stripCacheControlFromSystemBlocks(system any) bool {
return changed
}
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
if len(body) == 0 {
return body, modelID, nil
return body, modelID
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var reqRaw map[string]json.RawMessage
if err := json.Unmarshal(body, &reqRaw); err != nil {
return body, modelID, nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
// 解析为 map[string]any 用于修改字段
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
return body, modelID
}
toolNameMap := make(map[string]string)
modified := false
if system, ok := req["system"]; ok {
@@ -827,115 +672,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
if rawTools, exists := req["tools"]; exists {
switch tools := rawTools.(type) {
case []any:
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
if name, ok := toolMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
modified = true
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
modified = true
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
modified = true
}
tools[idx] = toolMap
}
req["tools"] = tools
case map[string]any:
normalizedTools := make(map[string]any, len(tools))
for name, value := range tools {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized == "" {
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
toolMap["name"] = normalized
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
}
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
modified = true
}
} else {
// 确保 tools 字段存在(即使为空数组)
if _, exists := req["tools"]; !exists {
req["tools"] = []any{}
modified = true
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified := false
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
// 检查此消息是否包含 thinking 块
hasThinking := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
hasThinking = true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if hasThinking {
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
continue
}
if name, ok := blockMap["name"].(string); ok {
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
messagesModified = true
}
}
}
}
}
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
@@ -964,38 +706,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
modified = true
}
if !modified && !messagesModified {
return body, modelID, toolNameMap
if !modified {
return body, modelID
}
// 如果 messages 没有被修改,保留原始 messages 字节
if !messagesModified {
// 序列化非 messages 字段
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
// 替换回原始的 messages
var newReq map[string]json.RawMessage
if err := json.Unmarshal(newBody, &newReq); err != nil {
return newBody, modelID, toolNameMap
}
if origMessages, ok := reqRaw["messages"]; ok {
newReq["messages"] = origMessages
}
finalBody, err := json.Marshal(newReq)
if err != nil {
return newBody, modelID, toolNameMap
}
return finalBody, modelID, toolNameMap
}
// messages 被修改了,需要完整序列化
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
return body, modelID
}
return newBody, modelID, toolNameMap
return newBody, modelID
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
@@ -2960,7 +2679,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
reqModel := parsed.Model
reqStream := parsed.Stream
originalModel := reqModel
var toolNameMap map[string]string
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
@@ -2984,7 +2702,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
}
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
@@ -3371,7 +3089,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -3384,7 +3102,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
@@ -3998,7 +3716,7 @@ type streamingResult struct {
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -4094,33 +3812,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
pendingEventLines := make([]string, 0, 4)
var toolInputBuffers map[int]string
if mimicClaudeCode {
toolInputBuffers = make(map[int]string)
}
transformToolInputJSON := func(raw string) string {
if !mimicClaudeCode {
return raw
}
raw = strings.TrimSpace(raw)
if raw == "" {
return raw
}
var parsed any
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
return replaceToolNamesInText(raw, toolNameMap)
}
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
if changed {
if bytes, err := json.Marshal(rewritten); err == nil {
return string(bytes)
}
}
return raw
}
processSSEEvent := func(lines []string) ([]string, string, error) {
if len(lines) == 0 {
@@ -4159,16 +3850,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
var event map[string]any
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
// JSON 解析失败,直接透传原始数据
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
eventType, _ := event["type"].(string)
@@ -4198,70 +3886,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
if mimicClaudeCode && eventType == "content_block_delta" {
if delta, ok := event["delta"].(map[string]any); ok {
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuffers[index] += partial
}
}
return nil, dataLine, nil
}
}
}
if mimicClaudeCode && eventType == "content_block_stop" {
if indexVal, ok := event["index"].(float64); ok {
index := int(indexVal)
if buffered := toolInputBuffers[index]; buffered != "" {
delete(toolInputBuffers, index)
transformed := transformToolInputJSON(buffered)
synthetic := map[string]any{
"type": "content_block_delta",
"index": index,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": transformed,
},
}
synthBytes, synthErr := json.Marshal(synthetic)
if synthErr == nil {
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
rewriteToolNamesInValue(event, toolNameMap)
stopBytes, stopErr := json.Marshal(event)
if stopErr == nil {
stopBlock := ""
if eventName != "" {
stopBlock = "event: " + eventName + "\n"
}
stopBlock += "data: " + string(stopBytes) + "\n\n"
return []string{synthBlock, stopBlock}, string(stopBytes), nil
}
}
}
}
}
if mimicClaudeCode {
rewriteToolNamesInValue(event, toolNameMap)
}
newData, err := json.Marshal(event)
if err != nil {
replaced := dataLine
if mimicClaudeCode {
replaced = replaceToolNamesInText(dataLine, toolNameMap)
}
// 序列化失败,直接透传原始数据
block := ""
if eventName != "" {
block = "event: " + eventName + "\n"
}
block += "data: " + replaced + "\n\n"
return []string{block}, replaced, nil
block += "data: " + dataLine + "\n\n"
return []string{block}, dataLine, nil
}
block := ""
@@ -4360,126 +3993,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
switch v := value.(type) {
case map[string]any:
changed := false
rewritten := make(map[string]any, len(v))
for key, item := range v {
newKey := normalizeParamNameForOpenCode(key, cache)
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
if newKey != key {
changed = true
}
rewritten[newKey] = newItem
}
if !changed {
return value, false
}
return rewritten, true
case []any:
changed := false
rewritten := make([]any, len(v))
for idx, item := range v {
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
rewritten[idx] = newItem
}
if !changed {
return value, false
}
return rewritten, true
default:
return value, false
}
}
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
changed := false
if blockType, _ := v["type"].(string); blockType == "tool_use" {
if name, ok := v["name"].(string); ok {
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped != name {
v["name"] = mapped
changed = true
}
}
if input, ok := v["input"].(map[string]any); ok {
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
if inputChanged {
if m, ok := rewrittenInput.(map[string]any); ok {
v["input"] = m
changed = true
}
}
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
case []any:
changed := false
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
changed = true
}
}
return changed
default:
return false
}
}
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
if text == "" {
return text
}
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
submatches := toolNameFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
name := submatches[1]
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
if mapped == name {
return match
}
return strings.Replace(match, name, mapped, 1)
})
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
submatches := modelFieldRe.FindStringSubmatch(match)
if len(submatches) < 2 {
return match
}
model := submatches[1]
mapped := claude.DenormalizeModelID(model)
if mapped == model {
return match
}
return strings.Replace(match, model, mapped, 1)
})
for mapped, original := range toolNameMap {
if mapped == "" || original == "" || mapped == original {
continue
}
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
}
return output
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens标准Claude API格式
var msgStart struct {
@@ -4523,7 +4036,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -4555,9 +4068,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -4595,28 +4105,6 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody
}
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
if len(body) == 0 {
return body
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
replaced := replaceToolNamesInText(string(body), toolNameMap)
if replaced == string(body) {
return body
}
return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
}
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
@@ -4977,7 +4465,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Antigravity 账户不支持 count_tokens 转发,直接返回空值