fix(gateway): 移除 PR #316 引入的工具名转换逻辑
移除响应阶段的工具名/schema/description 转换逻辑,修复第三方工具调用时 工具名被错误转换的问题(如 Task → task)。 移除内容: - 工具名相关正则变量(toolPrefixRe, toolNameBoundaryRe 等) - openCodeToolOverrides 和 claudeToolNameOverrides 映射表 - 工具名转换函数(normalizeToolNameForClaude, normalizeToolNameForOpenCode 等) - 响应体工具名替换函数(replaceToolNamesInText, replaceToolNamesInResponseBody 等) - 参数名转换函数(normalizeParamNameForOpenCode, rewriteParamKeysInValue) - 工具描述清理函数(sanitizeToolDescription) - 输入 schema 转换函数(normalizeToolInputSchema) - 模型 ID 正则替换函数(replaceModelIDInText) 保留内容: - 系统提示词清理(sanitizeSystemText) - Claude Code 指纹 headers 处理 - 模型 ID 映射(通过 JSON 对象操作)
This commit is contained in:
@@ -12,10 +12,3 @@ func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||
got := sanitizeSystemText(in)
|
||||
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||
}
|
||||
|
||||
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||
in := "OpenCode and opencode are mentioned."
|
||||
got := sanitizeToolDescription(in)
|
||||
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||
require.Equal(t, in, got)
|
||||
}
|
||||
|
||||
@@ -207,40 +207,6 @@ var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
|
||||
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
|
||||
|
||||
claudeToolNameOverrides = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"edit": "Edit",
|
||||
"write": "Write",
|
||||
"task": "Task",
|
||||
"glob": "Glob",
|
||||
"grep": "Grep",
|
||||
"webfetch": "WebFetch",
|
||||
"websearch": "WebSearch",
|
||||
"todowrite": "TodoWrite",
|
||||
"question": "AskUserQuestion",
|
||||
}
|
||||
openCodeToolOverrides = map[string]string{
|
||||
"Bash": "bash",
|
||||
"Read": "read",
|
||||
"Edit": "edit",
|
||||
"Write": "write",
|
||||
"Task": "task",
|
||||
"Glob": "glob",
|
||||
"Grep": "grep",
|
||||
"WebFetch": "webfetch",
|
||||
"WebSearch": "websearch",
|
||||
"TodoWrite": "todowrite",
|
||||
"AskUserQuestion": "question",
|
||||
}
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@@ -616,71 +582,6 @@ type claudeOAuthNormalizeOptions struct {
|
||||
stripSystemCacheControl bool
|
||||
}
|
||||
|
||||
func stripToolPrefix(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
return toolPrefixRe.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
func toSnakeCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
|
||||
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
|
||||
output = strings.Trim(output, "_")
|
||||
return strings.ToLower(output)
|
||||
}
|
||||
|
||||
func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
// 只对已知的工具名进行映射,未知工具名保持原样
|
||||
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728)
|
||||
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
|
||||
if !ok {
|
||||
return stripped
|
||||
}
|
||||
if cache != nil && mapped != stripped {
|
||||
cache[mapped] = stripped
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
// 优先从请求时建立的映射中查找
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
// 已知工具名的硬编码映射
|
||||
if mapped, ok := openCodeToolOverrides[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
|
||||
return stripped
|
||||
}
|
||||
|
||||
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[name]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
|
||||
// We intentionally avoid broad keyword replacement in system prompts to prevent
|
||||
// accidentally changing user-provided instructions.
|
||||
@@ -699,55 +600,6 @@ func sanitizeSystemText(text string) string {
|
||||
return text
|
||||
}
|
||||
|
||||
func sanitizeToolDescription(description string) string {
|
||||
if description == "" {
|
||||
return description
|
||||
}
|
||||
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
|
||||
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
|
||||
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
|
||||
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
|
||||
return description
|
||||
}
|
||||
|
||||
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
|
||||
schema, ok := inputSchema.(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
properties, ok := schema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
newProperties := make(map[string]any, len(properties))
|
||||
for key, value := range properties {
|
||||
snakeKey := toSnakeCase(key)
|
||||
newProperties[snakeKey] = value
|
||||
if snakeKey != key && cache != nil {
|
||||
cache[snakeKey] = key
|
||||
}
|
||||
}
|
||||
schema["properties"] = newProperties
|
||||
|
||||
if required, ok := schema["required"].([]any); ok {
|
||||
newRequired := make([]any, 0, len(required))
|
||||
for _, item := range required {
|
||||
name, ok := item.(string)
|
||||
if !ok {
|
||||
newRequired = append(newRequired, item)
|
||||
continue
|
||||
}
|
||||
snakeName := toSnakeCase(name)
|
||||
newRequired = append(newRequired, snakeName)
|
||||
if snakeName != name && cache != nil {
|
||||
cache[snakeName] = name
|
||||
}
|
||||
}
|
||||
schema["required"] = newRequired
|
||||
}
|
||||
}
|
||||
|
||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
blocks, ok := system.([]any)
|
||||
if !ok {
|
||||
@@ -768,24 +620,17 @@ func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) {
|
||||
if len(body) == 0 {
|
||||
return body, modelID, nil
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
|
||||
var reqRaw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &reqRaw); err != nil {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
// 同时解析为 map[string]any 用于修改非 messages 字段
|
||||
// 解析为 map[string]any 用于修改字段
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, modelID, nil
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
toolNameMap := make(map[string]string)
|
||||
modified := false
|
||||
|
||||
if system, ok := req["system"]; ok {
|
||||
@@ -827,115 +672,12 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
}
|
||||
|
||||
if rawTools, exists := req["tools"]; exists {
|
||||
switch tools := rawTools.(type) {
|
||||
case []any:
|
||||
for idx, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name, ok := toolMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
modified = true
|
||||
}
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
req["tools"] = tools
|
||||
case map[string]any:
|
||||
normalizedTools := make(map[string]any, len(tools))
|
||||
for name, value := range tools {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized == "" {
|
||||
normalized = name
|
||||
}
|
||||
if toolMap, ok := value.(map[string]any); ok {
|
||||
toolMap["name"] = normalized
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
}
|
||||
normalizedTools[normalized] = toolMap
|
||||
continue
|
||||
}
|
||||
normalizedTools[normalized] = value
|
||||
}
|
||||
req["tools"] = normalizedTools
|
||||
modified = true
|
||||
}
|
||||
} else {
|
||||
// 确保 tools 字段存在(即使为空数组)
|
||||
if _, exists := req["tools"]; !exists {
|
||||
req["tools"] = []any{}
|
||||
modified = true
|
||||
}
|
||||
|
||||
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
|
||||
messagesModified := false
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查此消息是否包含 thinking 块
|
||||
hasThinking := false
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// 如果包含 thinking 块,跳过此消息的修改
|
||||
if hasThinking {
|
||||
continue
|
||||
}
|
||||
// 只修改不包含 thinking 块的消息中的 tool_use
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
if name, ok := blockMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
blockMap["name"] = normalized
|
||||
messagesModified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.stripSystemCacheControl {
|
||||
if system, ok := req["system"]; ok {
|
||||
_ = stripCacheControlFromSystemBlocks(system)
|
||||
@@ -964,38 +706,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
modified = true
|
||||
}
|
||||
|
||||
if !modified && !messagesModified {
|
||||
return body, modelID, toolNameMap
|
||||
if !modified {
|
||||
return body, modelID
|
||||
}
|
||||
|
||||
// 如果 messages 没有被修改,保留原始 messages 字节
|
||||
if !messagesModified {
|
||||
// 序列化非 messages 字段
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
return body, modelID
|
||||
}
|
||||
// 替换回原始的 messages
|
||||
var newReq map[string]json.RawMessage
|
||||
if err := json.Unmarshal(newBody, &newReq); err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
if origMessages, ok := reqRaw["messages"]; ok {
|
||||
newReq["messages"] = origMessages
|
||||
}
|
||||
finalBody, err := json.Marshal(newReq)
|
||||
if err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
return finalBody, modelID, toolNameMap
|
||||
}
|
||||
|
||||
// messages 被修改了,需要完整序列化
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
}
|
||||
return newBody, modelID, toolNameMap
|
||||
return newBody, modelID
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||
@@ -2960,7 +2679,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
originalModel := reqModel
|
||||
var toolNameMap map[string]string
|
||||
|
||||
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||
@@ -2984,7 +2702,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
}
|
||||
|
||||
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
@@ -3371,7 +3089,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -3384,7 +3102,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3998,7 +3716,7 @@ type streamingResult struct {
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -4094,33 +3812,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
var toolInputBuffers map[int]string
|
||||
if mimicClaudeCode {
|
||||
toolInputBuffers = make(map[int]string)
|
||||
}
|
||||
|
||||
transformToolInputJSON := func(raw string) string {
|
||||
if !mimicClaudeCode {
|
||||
return raw
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
|
||||
return replaceToolNamesInText(raw, toolNameMap)
|
||||
}
|
||||
|
||||
rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap)
|
||||
if changed {
|
||||
if bytes, err := json.Marshal(rewritten); err == nil {
|
||||
return string(bytes)
|
||||
}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
processSSEEvent := func(lines []string) ([]string, string, error) {
|
||||
if len(lines) == 0 {
|
||||
@@ -4159,16 +3850,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
|
||||
replaced := dataLine
|
||||
if mimicClaudeCode {
|
||||
replaced = replaceToolNamesInText(dataLine, toolNameMap)
|
||||
}
|
||||
// JSON 解析失败,直接透传原始数据
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
}
|
||||
block += "data: " + replaced + "\n\n"
|
||||
return []string{block}, replaced, nil
|
||||
block += "data: " + dataLine + "\n\n"
|
||||
return []string{block}, dataLine, nil
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
@@ -4198,70 +3886,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode && eventType == "content_block_delta" {
|
||||
if delta, ok := event["delta"].(map[string]any); ok {
|
||||
if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" {
|
||||
if indexVal, ok := event["index"].(float64); ok {
|
||||
index := int(indexVal)
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuffers[index] += partial
|
||||
}
|
||||
}
|
||||
return nil, dataLine, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode && eventType == "content_block_stop" {
|
||||
if indexVal, ok := event["index"].(float64); ok {
|
||||
index := int(indexVal)
|
||||
if buffered := toolInputBuffers[index]; buffered != "" {
|
||||
delete(toolInputBuffers, index)
|
||||
|
||||
transformed := transformToolInputJSON(buffered)
|
||||
synthetic := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": transformed,
|
||||
},
|
||||
}
|
||||
|
||||
synthBytes, synthErr := json.Marshal(synthetic)
|
||||
if synthErr == nil {
|
||||
synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n"
|
||||
|
||||
rewriteToolNamesInValue(event, toolNameMap)
|
||||
stopBytes, stopErr := json.Marshal(event)
|
||||
if stopErr == nil {
|
||||
stopBlock := ""
|
||||
if eventName != "" {
|
||||
stopBlock = "event: " + eventName + "\n"
|
||||
}
|
||||
stopBlock += "data: " + string(stopBytes) + "\n\n"
|
||||
return []string{synthBlock, stopBlock}, string(stopBytes), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mimicClaudeCode {
|
||||
rewriteToolNamesInValue(event, toolNameMap)
|
||||
}
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
replaced := dataLine
|
||||
if mimicClaudeCode {
|
||||
replaced = replaceToolNamesInText(dataLine, toolNameMap)
|
||||
}
|
||||
// 序列化失败,直接透传原始数据
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
}
|
||||
block += "data: " + replaced + "\n\n"
|
||||
return []string{block}, replaced, nil
|
||||
block += "data: " + dataLine + "\n\n"
|
||||
return []string{block}, dataLine, nil
|
||||
}
|
||||
|
||||
block := ""
|
||||
@@ -4360,126 +3993,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
}
|
||||
|
||||
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
rewritten := make(map[string]any, len(v))
|
||||
for key, item := range v {
|
||||
newKey := normalizeParamNameForOpenCode(key, cache)
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
if newKey != key {
|
||||
changed = true
|
||||
}
|
||||
rewritten[newKey] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
case []any:
|
||||
changed := false
|
||||
rewritten := make([]any, len(v))
|
||||
for idx, item := range v {
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
rewritten[idx] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
if blockType, _ := v["type"].(string); blockType == "tool_use" {
|
||||
if name, ok := v["name"].(string); ok {
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped != name {
|
||||
v["name"] = mapped
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if input, ok := v["input"].(map[string]any); ok {
|
||||
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
|
||||
if inputChanged {
|
||||
if m, ok := rewrittenInput.(map[string]any); ok {
|
||||
v["input"] = m
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
case []any:
|
||||
changed := false
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
submatches := toolNameFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
name := submatches[1]
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped == name {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, name, mapped, 1)
|
||||
})
|
||||
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
|
||||
submatches := modelFieldRe.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
model := submatches[1]
|
||||
mapped := claude.DenormalizeModelID(model)
|
||||
if mapped == model {
|
||||
return match
|
||||
}
|
||||
return strings.Replace(match, model, mapped, 1)
|
||||
})
|
||||
|
||||
for mapped, original := range toolNameMap {
|
||||
if mapped == "" || original == "" || mapped == original {
|
||||
continue
|
||||
}
|
||||
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
|
||||
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析message_start获取input tokens(标准Claude API格式)
|
||||
var msgStart struct {
|
||||
@@ -4523,7 +4036,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -4555,9 +4068,6 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
if mimicClaudeCode {
|
||||
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -4595,28 +4105,6 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
return newBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
replaced := replaceToolNamesInText(string(body), toolNameMap)
|
||||
if replaced == string(body) {
|
||||
return body
|
||||
}
|
||||
return []byte(replaced)
|
||||
}
|
||||
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||
return body
|
||||
}
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -4977,7 +4465,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
if shouldMimicClaudeCode {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
|
||||
Reference in New Issue
Block a user