diff --git a/backend/internal/service/gateway_sanitize_test.go b/backend/internal/service/gateway_sanitize_test.go index 8fa971ca..a62bc8c7 100644 --- a/backend/internal/service/gateway_sanitize_test.go +++ b/backend/internal/service/gateway_sanitize_test.go @@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) { got := sanitizeSystemText(in) require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got) } - -func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) { - in := "OpenCode and opencode are mentioned." - got := sanitizeToolDescription(in) - // We no longer rewrite tool descriptions; only redact obvious path leaks. - require.Equal(t, in, got) -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index bbfb1723..308f0f18 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -207,40 +207,6 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) - toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`) - toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`) - toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`) - toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`) - modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`) - toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`) - toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`) - - claudeToolNameOverrides = map[string]string{ - "bash": "Bash", - "read": "Read", - "edit": "Edit", - "write": "Write", - "task": "Task", - "glob": "Glob", - "grep": "Grep", - "webfetch": "WebFetch", - "websearch": "WebSearch", - "todowrite": "TodoWrite", - "question": "AskUserQuestion", - } - openCodeToolOverrides = map[string]string{ - "Bash": "bash", - "Read": "read", - "Edit": "edit", - "Write": "write", - "Task": "task", - "Glob": "glob", - "Grep": "grep", - "WebFetch": "webfetch", - "WebSearch": "websearch", - "TodoWrite": "todowrite", - "AskUserQuestion": "question", - } // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -616,71 +582,6 @@ type claudeOAuthNormalizeOptions struct { stripSystemCacheControl bool } -func stripToolPrefix(value string) string { - if value == "" { - return value - } - return toolPrefixRe.ReplaceAllString(value, "") -} - -func toSnakeCase(value string) string { - if value == "" { - return value - } - output := toolNameCamelRe.ReplaceAllString(value, "$1_$2") - output = toolNameBoundaryRe.ReplaceAllString(output, "_") - output = strings.Trim(output, "_") - return strings.ToLower(output) -} - -func normalizeToolNameForClaude(name string, cache map[string]string) string { - if name == "" { - return name - } - stripped := stripToolPrefix(name) - // 只对已知的工具名进行映射,未知工具名保持原样 - // 避免破坏 Anthropic 特殊工具(如 text_editor_20250728) - mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] - if !ok { - return stripped - } - if cache != nil && mapped != stripped { - cache[mapped] = stripped - } - return mapped -} - -func normalizeToolNameForOpenCode(name string, cache map[string]string) string { - if name == "" { - return name - } - stripped := stripToolPrefix(name) - // 优先从请求时建立的映射中查找 - if cache != nil { - if mapped, ok := cache[stripped]; ok { - return mapped - } - } - // 已知工具名的硬编码映射 - if mapped, ok := openCodeToolOverrides[stripped]; ok { - return mapped - } - // 未知工具名保持原样,避免破坏 Anthropic 特殊工具 - return stripped -} - -func normalizeParamNameForOpenCode(name string, cache map[string]string) string { - if name == "" { - return name - } - if cache != nil { - if mapped, ok := cache[name]; ok { - return mapped - } - } - return name -} - // sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). // We intentionally avoid broad keyword replacement in system prompts to prevent // accidentally changing user-provided instructions. @@ -699,55 +600,6 @@ func sanitizeSystemText(text string) string { return text } -func sanitizeToolDescription(description string) string { - if description == "" { - return description - } - description = toolDescAbsPathRe.ReplaceAllString(description, "[path]") - description = toolDescWinPathRe.ReplaceAllString(description, "[path]") - // Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings). - // Tool names/skill names may rely on exact wording, and rewriting can be misleading. - return description -} - -func normalizeToolInputSchema(inputSchema any, cache map[string]string) { - schema, ok := inputSchema.(map[string]any) - if !ok { - return - } - properties, ok := schema["properties"].(map[string]any) - if !ok { - return - } - - newProperties := make(map[string]any, len(properties)) - for key, value := range properties { - snakeKey := toSnakeCase(key) - newProperties[snakeKey] = value - if snakeKey != key && cache != nil { - cache[snakeKey] = key - } - } - schema["properties"] = newProperties - - if required, ok := schema["required"].([]any); ok { - newRequired := make([]any, 0, len(required)) - for _, item := range required { - name, ok := item.(string) - if !ok { - newRequired = append(newRequired, item) - continue - } - snakeName := toSnakeCase(name) - newRequired = append(newRequired, snakeName) - if snakeName != name && cache != nil { - cache[snakeName] = name - } - } - schema["required"] = newRequired - } -} - func stripCacheControlFromSystemBlocks(system any) bool { blocks, ok := system.([]any) if !ok { @@ -768,24 +620,17 @@ func stripCacheControlFromSystemBlocks(system any) bool { return changed } -func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) { +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { if len(body) == 0 { - return body, modelID, nil + return body, modelID } - // 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改 - var reqRaw map[string]json.RawMessage - if err := json.Unmarshal(body, &reqRaw); err != nil { - return body, modelID, nil - } - - // 同时解析为 map[string]any 用于修改非 messages 字段 + // 解析为 map[string]any 用于修改字段 var req map[string]any if err := json.Unmarshal(body, &req); err != nil { - return body, modelID, nil + return body, modelID } - toolNameMap := make(map[string]string) modified := false if system, ok := req["system"]; ok { @@ -827,115 +672,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu } } - if rawTools, exists := req["tools"]; exists { - switch tools := rawTools.(type) { - case []any: - for idx, tool := range tools { - toolMap, ok := tool.(map[string]any) - if !ok { - continue - } - if name, ok := toolMap["name"].(string); ok { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized != "" && normalized != name { - toolMap["name"] = normalized - modified = true - } - } - if desc, ok := toolMap["description"].(string); ok { - sanitized := sanitizeToolDescription(desc) - if sanitized != desc { - toolMap["description"] = sanitized - modified = true - } - } - if schema, ok := toolMap["input_schema"]; ok { - normalizeToolInputSchema(schema, toolNameMap) - modified = true - } - tools[idx] = toolMap - } - req["tools"] = tools - case map[string]any: - normalizedTools := make(map[string]any, len(tools)) - for name, value := range tools { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized == "" { - normalized = name - } - if toolMap, ok := value.(map[string]any); ok { - toolMap["name"] = normalized - if desc, ok := toolMap["description"].(string); ok { - sanitized := sanitizeToolDescription(desc) - if sanitized != desc { - toolMap["description"] = sanitized - } - } - if schema, ok := toolMap["input_schema"]; ok { - normalizeToolInputSchema(schema, toolNameMap) - } - normalizedTools[normalized] = toolMap - continue - } - normalizedTools[normalized] = value - } - req["tools"] = normalizedTools - modified = true - } - } else { + // 确保 tools 字段存在(即使为空数组) + if _, exists := req["tools"]; !exists { req["tools"] = []any{} modified = true } - // 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节 - messagesModified := false - if messages, ok := req["messages"].([]any); ok { - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) - if !ok { - continue - } - content, ok := msgMap["content"].([]any) - if !ok { - continue - } - // 检查此消息是否包含 thinking 块 - hasThinking := false - for _, block := range content { - blockMap, ok := block.(map[string]any) - if !ok { - continue - } - blockType, _ := blockMap["type"].(string) - if blockType == "thinking" || blockType == "redacted_thinking" { - hasThinking = true - break - } - } - // 如果包含 thinking 块,跳过此消息的修改 - if hasThinking { - continue - } - // 只修改不包含 thinking 块的消息中的 tool_use - for _, block := range content { - blockMap, ok := block.(map[string]any) - if !ok { - continue - } - if blockType, _ := blockMap["type"].(string); blockType != "tool_use" { - continue - } - if name, ok := blockMap["name"].(string); ok { - normalized := normalizeToolNameForClaude(name, toolNameMap) - if normalized != "" && normalized != name { - blockMap["name"] = normalized - messagesModified = true - } - } - } - } - } - if opts.stripSystemCacheControl { if system, ok := req["system"]; ok { _ = stripCacheControlFromSystemBlocks(system) @@ -964,38 +706,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu modified = true } - if !modified && !messagesModified { - return body, modelID, toolNameMap + if !modified { + return body, modelID } - // 如果 messages 没有被修改,保留原始 messages 字节 - if !messagesModified { - // 序列化非 messages 字段 - newBody, err := json.Marshal(req) - if err != nil { - return body, modelID, toolNameMap - } - // 替换回原始的 messages - var newReq map[string]json.RawMessage - if err := json.Unmarshal(newBody, &newReq); err != nil { - return newBody, modelID, toolNameMap - } - if origMessages, ok := reqRaw["messages"]; ok { - newReq["messages"] = origMessages - } - finalBody, err := json.Marshal(newReq) - if err != nil { - return newBody, modelID, toolNameMap - } - return finalBody, modelID, toolNameMap - } - - // messages 被修改了,需要完整序列化 newBody, err := json.Marshal(req) if err != nil { - return body, modelID, toolNameMap + return body, modelID } - return newBody, modelID, toolNameMap + return newBody, modelID } func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { @@ -2960,7 +2679,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A reqModel := parsed.Model reqStream := parsed.Stream originalModel := reqModel - var toolNameMap map[string]string isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -2984,7 +2702,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } } - body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // 强制执行 cache_control 块数量限制(最多 4 个) @@ -3371,7 +3089,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var firstTokenMs *int var clientDisconnect bool if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -3384,7 +3102,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A firstTokenMs = streamResult.firstTokenMs clientDisconnect = streamResult.clientDisconnect } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { return nil, err } @@ -3998,7 +3716,7 @@ type streamingResult struct { clientDisconnect bool // 客户端是否在流式传输过程中断开 } -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) { +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -4094,33 +3812,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage pendingEventLines := make([]string, 0, 4) - var toolInputBuffers map[int]string - if mimicClaudeCode { - toolInputBuffers = make(map[int]string) - } - - transformToolInputJSON := func(raw string) string { - if !mimicClaudeCode { - return raw - } - raw = strings.TrimSpace(raw) - if raw == "" { - return raw - } - - var parsed any - if err := json.Unmarshal([]byte(raw), &parsed); err != nil { - return replaceToolNamesInText(raw, toolNameMap) - } - - rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap) - if changed { - if bytes, err := json.Marshal(rewritten); err == nil { - return string(bytes) - } - } - return raw - } processSSEEvent := func(lines []string) ([]string, string, error) { if len(lines) == 0 { @@ -4159,16 +3850,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http var event map[string]any if err := json.Unmarshal([]byte(dataLine), &event); err != nil { - replaced := dataLine - if mimicClaudeCode { - replaced = replaceToolNamesInText(dataLine, toolNameMap) - } + // JSON 解析失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } - block += "data: " + replaced + "\n\n" - return []string{block}, replaced, nil + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil } eventType, _ := event["type"].(string) @@ -4198,70 +3886,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } - if mimicClaudeCode && eventType == "content_block_delta" { - if delta, ok := event["delta"].(map[string]any); ok { - if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { - if indexVal, ok := event["index"].(float64); ok { - index := int(indexVal) - if partial, ok := delta["partial_json"].(string); ok { - toolInputBuffers[index] += partial - } - } - return nil, dataLine, nil - } - } - } - - if mimicClaudeCode && eventType == "content_block_stop" { - if indexVal, ok := event["index"].(float64); ok { - index := int(indexVal) - if buffered := toolInputBuffers[index]; buffered != "" { - delete(toolInputBuffers, index) - - transformed := transformToolInputJSON(buffered) - synthetic := map[string]any{ - "type": "content_block_delta", - "index": index, - "delta": map[string]any{ - "type": "input_json_delta", - "partial_json": transformed, - }, - } - - synthBytes, synthErr := json.Marshal(synthetic) - if synthErr == nil { - synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n" - - rewriteToolNamesInValue(event, toolNameMap) - stopBytes, stopErr := json.Marshal(event) - if stopErr == nil { - stopBlock := "" - if eventName != "" { - stopBlock = "event: " + eventName + "\n" - } - stopBlock += "data: " + string(stopBytes) + "\n\n" - return []string{synthBlock, stopBlock}, string(stopBytes), nil - } - } - } - } - } - - if mimicClaudeCode { - rewriteToolNamesInValue(event, toolNameMap) - } newData, err := json.Marshal(event) if err != nil { - replaced := dataLine - if mimicClaudeCode { - replaced = replaceToolNamesInText(dataLine, toolNameMap) - } + // 序列化失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } - block += "data: " + replaced + "\n\n" - return []string{block}, replaced, nil + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil } block := "" @@ -4360,126 +3993,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } -func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) { - switch v := value.(type) { - case map[string]any: - changed := false - rewritten := make(map[string]any, len(v)) - for key, item := range v { - newKey := normalizeParamNameForOpenCode(key, cache) - newItem, childChanged := rewriteParamKeysInValue(item, cache) - if childChanged { - changed = true - } - if newKey != key { - changed = true - } - rewritten[newKey] = newItem - } - if !changed { - return value, false - } - return rewritten, true - case []any: - changed := false - rewritten := make([]any, len(v)) - for idx, item := range v { - newItem, childChanged := rewriteParamKeysInValue(item, cache) - if childChanged { - changed = true - } - rewritten[idx] = newItem - } - if !changed { - return value, false - } - return rewritten, true - default: - return value, false - } -} - -func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool { - switch v := value.(type) { - case map[string]any: - changed := false - if blockType, _ := v["type"].(string); blockType == "tool_use" { - if name, ok := v["name"].(string); ok { - mapped := normalizeToolNameForOpenCode(name, toolNameMap) - if mapped != name { - v["name"] = mapped - changed = true - } - } - if input, ok := v["input"].(map[string]any); ok { - rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap) - if inputChanged { - if m, ok := rewrittenInput.(map[string]any); ok { - v["input"] = m - changed = true - } - } - } - } - for _, item := range v { - if rewriteToolNamesInValue(item, toolNameMap) { - changed = true - } - } - return changed - case []any: - changed := false - for _, item := range v { - if rewriteToolNamesInValue(item, toolNameMap) { - changed = true - } - } - return changed - default: - return false - } -} - -func replaceToolNamesInText(text string, toolNameMap map[string]string) string { - if text == "" { - return text - } - output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string { - submatches := toolNameFieldRe.FindStringSubmatch(match) - if len(submatches) < 2 { - return match - } - name := submatches[1] - mapped := normalizeToolNameForOpenCode(name, toolNameMap) - if mapped == name { - return match - } - return strings.Replace(match, name, mapped, 1) - }) - output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string { - submatches := modelFieldRe.FindStringSubmatch(match) - if len(submatches) < 2 { - return match - } - model := submatches[1] - mapped := claude.DenormalizeModelID(model) - if mapped == model { - return match - } - return strings.Replace(match, model, mapped, 1) - }) - - for mapped, original := range toolNameMap { - if mapped == "" || original == "" || mapped == original { - continue - } - output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":") - output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":") - } - - return output -} - func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { // 解析message_start获取input tokens(标准Claude API格式) var msgStart struct { @@ -4523,7 +4036,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { } } -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) { +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -4555,9 +4068,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - if mimicClaudeCode { - body = s.replaceToolNamesInResponseBody(body, toolNameMap) - } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -4595,28 +4105,6 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo return newBody } -func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte { - if len(body) == 0 { - return body - } - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - replaced := replaceToolNamesInText(string(body), toolNameMap) - if replaced == string(body) { - return body - } - return []byte(replaced) - } - if !rewriteToolNamesInValue(resp, toolNameMap) { - return body - } - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - return newBody -} - // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult @@ -4977,7 +4465,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, if shouldMimicClaudeCode { normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} - body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // Antigravity 账户不支持 count_tokens 转发,直接返回空值