From 539b41f42101486568b2594557c2cc95a2122478 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 15 Jan 2026 23:52:50 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat(openai):=20=E6=B7=BB=E5=8A=A0Codex?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E8=87=AA=E5=8A=A8=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现了完整的Codex工具调用拦截和自动修正系统,解决OpenCode使用Codex模型时的工具调用兼容性问题。 **核心功能:** 1. **工具名称自动映射** - apply_patch/applyPatch → edit - update_plan/updatePlan → todowrite - read_plan/readPlan → todoread - search_files/searchFiles → grep - list_files/listFiles → glob - read_file/readFile → read - write_file/writeFile → write - execute_bash/executeBash/exec_bash/execBash → bash 2. **工具参数自动修正** - bash: 自动移除不支持的 workdir/work_dir 参数 - edit: 自动将 path 参数重命名为 file_path - 支持 JSON 字符串和对象两种参数格式 3. **流式响应集成** - 在 SSE 数据流中实时修正工具调用 - 支持多种 JSON 结构(tool_calls, function_call, delta, choices等) - 不影响响应性能和用户体验 4. **统计和监控** - 记录每次工具修正的详细信息 - 提供修正统计数据查询 - 便于问题排查和性能优化 **实现文件:** - `openai_tool_corrector.go`: 工具修正核心逻辑(250行) - `openai_tool_corrector_test.go`: 完整的单元测试(380+行) - `openai_gateway_service.go`: 流式响应集成 - `openai_gateway_service_tool_correction_test.go`: 集成测试 **测试覆盖:** - 工具名称映射测试(18个映射规则) - 参数修正测试(bash workdir、edit path等) - SSE数据修正测试(多种JSON结构) - 统计功能测试 - 所有测试通过 ✅ **解决的问题:** 修复了 OpenCode 使用 sub2api 中转 Codex 时,因工具名称和参数不兼容导致的工具调用失败问题。 Codex 模型有时会忽略指令文件中的工具映射说明,导致调用不存在的工具(如 apply_patch)。 现在通过流式响应拦截,自动将错误的工具调用修正为 OpenCode 兼容的格式。 **参考文档:** - OpenCode 工具规范: https://opencode.ai/docs/ - Codex Bridge 指令: backend/internal/service/prompts/codex_opencode_bridge.txt --- .../service/openai_gateway_service.go | 23 + ...ai_gateway_service_tool_correction_test.go | 133 ++++++ .../internal/service/openai_tool_corrector.go | 307 +++++++++++++ .../service/openai_tool_corrector_test.go | 410 ++++++++++++++++++ 4 files changed, 873 insertions(+) create mode 100644 backend/internal/service/openai_gateway_service_tool_correction_test.go create mode 100644 backend/internal/service/openai_tool_corrector.go create mode 100644 backend/internal/service/openai_tool_corrector_test.go diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 859480de..c7d94882 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -94,6 +94,7 @@ type OpenAIGatewayService struct { httpUpstream HTTPUpstream deferredService *DeferredService openAITokenProvider *OpenAITokenProvider + toolCorrector *CodexToolCorrector } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -128,6 +129,7 @@ func NewOpenAIGatewayService( httpUpstream: httpUpstream, deferredService: deferredService, openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), } } @@ -1106,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp line = s.replaceModelInSSELine(line, mappedModel, originalModel) } + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { + line = "data: " + correctedData + } + // Forward line if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { sendErrorEvent("write_failed") @@ -1193,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st return line } +// correctToolCallsInResponseBody 修正响应体中的工具调用 +func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte { + if len(body) == 0 { + return body + } + + bodyStr := string(body) + corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) + if changed { + return []byte(corrected) + } + return body +} + func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { // Parse response.completed event for usage (OpenAI Responses format) var event struct { @@ -1296,6 +1317,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } + // Correct tool calls in final response + body = s.correctToolCallsInResponseBody(body) } else { usage = s.parseSSEUsageFromBody(bodyText) if originalModel != mappedModel { diff --git a/backend/internal/service/openai_gateway_service_tool_correction_test.go b/backend/internal/service/openai_gateway_service_tool_correction_test.go new file mode 100644 index 00000000..d4491cfe --- /dev/null +++ b/backend/internal/service/openai_gateway_service_tool_correction_test.go @@ -0,0 +1,133 @@ +package service + +import ( + "strings" + "testing" +) + +// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成 +func TestOpenAIGatewayService_ToolCorrection(t *testing.T) { + // 创建一个简单的 service 实例来测试工具修正 + service := &OpenAIGatewayService{ + toolCorrector: NewCodexToolCorrector(), + } + + tests := []struct { + name string + input []byte + expected string + changed bool + }{ + { + name: "correct apply_patch in response body", + input: []byte(`{ + "choices": [{ + "message": { + "tool_calls": [{ + "function": {"name": "apply_patch"} + }] + } + }] + }`), + expected: "edit", + changed: true, + }, + { + name: "correct update_plan in response body", + input: []byte(`{ + "tool_calls": [{ + "function": {"name": "update_plan"} + }] + }`), + expected: "todowrite", + changed: true, + }, + { + name: "no change for correct tool name", + input: []byte(`{ + "tool_calls": [{ + "function": {"name": "edit"} + }] + }`), + expected: "edit", + changed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := service.correctToolCallsInResponseBody(tt.input) + resultStr := string(result) + + // 检查是否包含期望的工具名称 + if !strings.Contains(resultStr, tt.expected) { + t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr) + } + + // 对于预期有变化的情况,验证结果与输入不同 + if tt.changed && string(result) == string(tt.input) { + t.Error("expected result to be different from input, but they are the same") + } + + // 对于预期无变化的情况,验证结果与输入相同 + if !tt.changed && string(result) != string(tt.input) { + t.Error("expected result to be same as input, but they are different") + } + }) + } +} + +// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化 +func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) { + service := &OpenAIGatewayService{ + toolCorrector: NewCodexToolCorrector(), + } + + if service.toolCorrector == nil { + t.Fatal("toolCorrector should not be nil") + } + + // 测试修正器可以正常工作 + data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}` + corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data) + + if !changed { + t.Error("expected tool call to be corrected") + } + + if !strings.Contains(corrected, "edit") { + t.Errorf("expected corrected data to contain 'edit', got %q", corrected) + } +} + +// TestToolCorrectionStats 测试工具修正统计功能 +func TestToolCorrectionStats(t *testing.T) { + service := &OpenAIGatewayService{ + toolCorrector: NewCodexToolCorrector(), + } + + // 执行几次修正 + testData := []string{ + `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`, + `{"tool_calls":[{"function":{"name":"update_plan"}}]}`, + `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`, + } + + for _, data := range testData { + service.toolCorrector.CorrectToolCallsInSSEData(data) + } + + stats := service.toolCorrector.GetStats() + + if stats.TotalCorrected != 3 { + t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected) + } + + if stats.CorrectionsByTool["apply_patch->edit"] != 2 { + t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"]) + } + + if stats.CorrectionsByTool["update_plan->todowrite"] != 1 { + t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"]) + } +} diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go new file mode 100644 index 00000000..6dad1a75 --- /dev/null +++ b/backend/internal/service/openai_tool_corrector.go @@ -0,0 +1,307 @@ +package service + +import ( + "encoding/json" + "fmt" + "log" + "sync" +) + +// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 +var codexToolNameMapping = map[string]string{ + "apply_patch": "edit", + "applyPatch": "edit", + "update_plan": "todowrite", + "updatePlan": "todowrite", + "read_plan": "todoread", + "readPlan": "todoread", + "search_files": "grep", + "searchFiles": "grep", + "list_files": "glob", + "listFiles": "glob", + "read_file": "read", + "readFile": "read", + "write_file": "write", + "writeFile": "write", + "execute_bash": "bash", + "executeBash": "bash", + "exec_bash": "bash", + "execBash": "bash", +} + +// ToolCorrectionStats 记录工具修正的统计信息 +type ToolCorrectionStats struct { + TotalCorrected int `json:"total_corrected"` + CorrectionsByTool map[string]int `json:"corrections_by_tool"` + mu sync.RWMutex +} + +// CodexToolCorrector 处理 Codex 工具调用的自动修正 +type CodexToolCorrector struct { + stats ToolCorrectionStats +} + +// NewCodexToolCorrector 创建新的工具修正器 +func NewCodexToolCorrector() *CodexToolCorrector { + return &CodexToolCorrector{ + stats: ToolCorrectionStats{ + CorrectionsByTool: make(map[string]int), + }, + } +} + +// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用 +// 返回修正后的数据和是否进行了修正 +func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) { + if data == "" || data == "\n" { + return data, false + } + + // 尝试解析 JSON + var payload map[string]any + if err := json.Unmarshal([]byte(data), &payload); err != nil { + // 不是有效的 JSON,直接返回原数据 + return data, false + } + + corrected := false + + // 处理 tool_calls 数组 + if toolCalls, ok := payload["tool_calls"].([]any); ok { + if c.correctToolCallsArray(toolCalls) { + corrected = true + } + } + + // 处理 function_call 对象 + if functionCall, ok := payload["function_call"].(map[string]any); ok { + if c.correctFunctionCall(functionCall) { + corrected = true + } + } + + // 处理 delta.tool_calls + if delta, ok := payload["delta"].(map[string]any); ok { + if toolCalls, ok := delta["tool_calls"].([]any); ok { + if c.correctToolCallsArray(toolCalls) { + corrected = true + } + } + if functionCall, ok := delta["function_call"].(map[string]any); ok { + if c.correctFunctionCall(functionCall) { + corrected = true + } + } + } + + // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls + if choices, ok := payload["choices"].([]any); ok { + for _, choice := range choices { + if choiceMap, ok := choice.(map[string]any); ok { + // 处理 message 中的工具调用 + if message, ok := choiceMap["message"].(map[string]any); ok { + if toolCalls, ok := message["tool_calls"].([]any); ok { + if c.correctToolCallsArray(toolCalls) { + corrected = true + } + } + if functionCall, ok := message["function_call"].(map[string]any); ok { + if c.correctFunctionCall(functionCall) { + corrected = true + } + } + } + // 处理 delta 中的工具调用 + if delta, ok := choiceMap["delta"].(map[string]any); ok { + if toolCalls, ok := delta["tool_calls"].([]any); ok { + if c.correctToolCallsArray(toolCalls) { + corrected = true + } + } + if functionCall, ok := delta["function_call"].(map[string]any); ok { + if c.correctFunctionCall(functionCall) { + corrected = true + } + } + } + } + } + } + + if !corrected { + return data, false + } + + // 序列化回 JSON + correctedBytes, err := json.Marshal(payload) + if err != nil { + log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err) + return data, false + } + + return string(correctedBytes), true +} + +// correctToolCallsArray 修正工具调用数组中的工具名称 +func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool { + corrected := false + for _, toolCall := range toolCalls { + if toolCallMap, ok := toolCall.(map[string]any); ok { + if function, ok := toolCallMap["function"].(map[string]any); ok { + if c.correctFunctionCall(function) { + corrected = true + } + } + } + } + return corrected +} + +// correctFunctionCall 修正单个函数调用的工具名称和参数 +func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool { + name, ok := functionCall["name"].(string) + if !ok || name == "" { + return false + } + + corrected := false + + // 查找并修正工具名称 + if correctName, found := codexToolNameMapping[name]; found { + functionCall["name"] = correctName + c.recordCorrection(name, correctName) + corrected = true + name = correctName // 使用修正后的名称进行参数修正 + } + + // 修正工具参数(基于工具名称) + if c.correctToolParameters(name, functionCall) { + corrected = true + } + + return corrected +} + +// correctToolParameters 修正工具参数以符合 OpenCode 规范 +func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool { + arguments, ok := functionCall["arguments"] + if !ok { + return false + } + + // arguments 可能是字符串(JSON)或已解析的 map + var argsMap map[string]any + switch v := arguments.(type) { + case string: + // 解析 JSON 字符串 + if err := json.Unmarshal([]byte(v), &argsMap); err != nil { + return false + } + case map[string]any: + argsMap = v + default: + return false + } + + corrected := false + + // 根据工具名称应用特定的参数修正规则 + switch toolName { + case "bash": + // 移除 workdir 参数(OpenCode 不支持) + if _, exists := argsMap["workdir"]; exists { + delete(argsMap, "workdir") + corrected = true + log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool") + } + if _, exists := argsMap["work_dir"]; exists { + delete(argsMap, "work_dir") + corrected = true + log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool") + } + + case "edit": + // OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称 + // 这里可以添加参数名称的映射逻辑 + if _, exists := argsMap["file_path"]; !exists { + if path, exists := argsMap["path"]; exists { + argsMap["file_path"] = path + delete(argsMap, "path") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool") + } + } + } + + // 如果修正了参数,需要重新序列化 + if corrected { + if _, wasString := arguments.(string); wasString { + // 原本是字符串,序列化回字符串 + if newArgsJSON, err := json.Marshal(argsMap); err == nil { + functionCall["arguments"] = string(newArgsJSON) + } + } else { + // 原本是 map,直接赋值 + functionCall["arguments"] = argsMap + } + } + + return corrected +} + +// recordCorrection 记录一次工具名称修正 +func (c *CodexToolCorrector) recordCorrection(from, to string) { + c.stats.mu.Lock() + defer c.stats.mu.Unlock() + + c.stats.TotalCorrected++ + key := fmt.Sprintf("%s->%s", from, to) + c.stats.CorrectionsByTool[key]++ + + log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", + from, to, c.stats.TotalCorrected) +} + +// GetStats 获取工具修正统计信息 +func (c *CodexToolCorrector) GetStats() ToolCorrectionStats { + c.stats.mu.RLock() + defer c.stats.mu.RUnlock() + + // 返回副本以避免并发问题 + statsCopy := ToolCorrectionStats{ + TotalCorrected: c.stats.TotalCorrected, + CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)), + } + for k, v := range c.stats.CorrectionsByTool { + statsCopy.CorrectionsByTool[k] = v + } + + return statsCopy +} + +// ResetStats 重置统计信息 +func (c *CodexToolCorrector) ResetStats() { + c.stats.mu.Lock() + defer c.stats.mu.Unlock() + + c.stats.TotalCorrected = 0 + c.stats.CorrectionsByTool = make(map[string]int) +} + +// CorrectToolName 直接修正工具名称(用于非 SSE 场景) +func CorrectToolName(name string) (string, bool) { + if correctName, found := codexToolNameMapping[name]; found { + return correctName, true + } + return name, false +} + +// GetToolNameMapping 获取工具名称映射表 +func GetToolNameMapping() map[string]string { + // 返回副本以避免外部修改 + mapping := make(map[string]string, len(codexToolNameMapping)) + for k, v := range codexToolNameMapping { + mapping[k] = v + } + return mapping +} diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go new file mode 100644 index 00000000..7219c2af --- /dev/null +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -0,0 +1,410 @@ +package service + +import ( + "encoding/json" + "testing" +) + +func TestCorrectToolCallsInSSEData(t *testing.T) { + corrector := NewCodexToolCorrector() + + tests := []struct { + name string + input string + expectCorrected bool + checkFunc func(t *testing.T, result string) + }{ + { + name: "empty string", + input: "", + expectCorrected: false, + }, + { + name: "newline only", + input: "\n", + expectCorrected: false, + }, + { + name: "invalid json", + input: "not a json", + expectCorrected: false, + }, + { + name: "correct apply_patch in tool_calls", + input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + toolCalls := payload["tool_calls"].([]any) + functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + if functionCall["name"] != "edit" { + t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) + } + }, + }, + { + name: "correct update_plan in function_call", + input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + functionCall := payload["function_call"].(map[string]any) + if functionCall["name"] != "todowrite" { + t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"]) + } + }, + }, + { + name: "correct search_files in delta.tool_calls", + input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + delta := payload["delta"].(map[string]any) + toolCalls := delta["tool_calls"].([]any) + functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + if functionCall["name"] != "grep" { + t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"]) + } + }, + }, + { + name: "correct list_files in choices.message.tool_calls", + input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + choices := payload["choices"].([]any) + message := choices[0].(map[string]any)["message"].(map[string]any) + toolCalls := message["tool_calls"].([]any) + functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + if functionCall["name"] != "glob" { + t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"]) + } + }, + }, + { + name: "no correction needed", + input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`, + expectCorrected: false, + }, + { + name: "correct multiple tool calls", + input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + toolCalls := payload["tool_calls"].([]any) + + func1 := toolCalls[0].(map[string]any)["function"].(map[string]any) + if func1["name"] != "edit" { + t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"]) + } + + func2 := toolCalls[1].(map[string]any)["function"].(map[string]any) + if func2["name"] != "read" { + t.Errorf("Expected second tool name 'read', got '%v'", func2["name"]) + } + }, + }, + { + name: "camelCase format - applyPatch", + input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`, + expectCorrected: true, + checkFunc: func(t *testing.T, result string) { + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + toolCalls := payload["tool_calls"].([]any) + functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + if functionCall["name"] != "edit" { + t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, corrected := corrector.CorrectToolCallsInSSEData(tt.input) + + if corrected != tt.expectCorrected { + t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected) + } + + if !corrected && result != tt.input { + t.Errorf("Expected unchanged result when not corrected") + } + + if tt.checkFunc != nil { + tt.checkFunc(t, result) + } + }) + } +} + +func TestCorrectToolName(t *testing.T) { + tests := []struct { + input string + expected string + corrected bool + }{ + {"apply_patch", "edit", true}, + {"applyPatch", "edit", true}, + {"update_plan", "todowrite", true}, + {"updatePlan", "todowrite", true}, + {"read_plan", "todoread", true}, + {"readPlan", "todoread", true}, + {"search_files", "grep", true}, + {"searchFiles", "grep", true}, + {"list_files", "glob", true}, + {"listFiles", "glob", true}, + {"read_file", "read", true}, + {"readFile", "read", true}, + {"write_file", "write", true}, + {"writeFile", "write", true}, + {"execute_bash", "bash", true}, + {"executeBash", "bash", true}, + {"exec_bash", "bash", true}, + {"execBash", "bash", true}, + {"unknown_tool", "unknown_tool", false}, + {"read", "read", false}, + {"edit", "edit", false}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, corrected := CorrectToolName(tt.input) + + if corrected != tt.corrected { + t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected) + } + + if result != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, result) + } + }) + } +} + +func TestGetToolNameMapping(t *testing.T) { + mapping := GetToolNameMapping() + + expectedMappings := map[string]string{ + "apply_patch": "edit", + "update_plan": "todowrite", + "read_plan": "todoread", + "search_files": "grep", + "list_files": "glob", + } + + for from, to := range expectedMappings { + if mapping[from] != to { + t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from]) + } + } + + mapping["test_tool"] = "test_value" + newMapping := GetToolNameMapping() + if _, exists := newMapping["test_tool"]; exists { + t.Error("Modifications to returned mapping should not affect original") + } +} + +func TestCorrectorStats(t *testing.T) { + corrector := NewCodexToolCorrector() + + stats := corrector.GetStats() + if stats.TotalCorrected != 0 { + t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected) + } + if len(stats.CorrectionsByTool) != 0 { + t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool)) + } + + corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`) + corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`) + corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`) + + stats = corrector.GetStats() + if stats.TotalCorrected != 3 { + t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected) + } + + if stats.CorrectionsByTool["apply_patch->edit"] != 2 { + t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"]) + } + + if stats.CorrectionsByTool["update_plan->todowrite"] != 1 { + t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"]) + } + + corrector.ResetStats() + stats = corrector.GetStats() + if stats.TotalCorrected != 0 { + t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected) + } + if len(stats.CorrectionsByTool) != 0 { + t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool)) + } +} + +func TestComplexSSEData(t *testing.T) { + corrector := NewCodexToolCorrector() + + input := `{ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-5.1-codex", + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + "function": { + "name": "apply_patch", + "arguments": "{\"file\":\"test.go\"}" + } + } + ] + }, + "finish_reason": null + } + ] + }` + + result, corrected := corrector.CorrectToolCallsInSSEData(input) + + if !corrected { + t.Error("Expected data to be corrected") + } + + var payload map[string]any + if err := json.Unmarshal([]byte(result), &payload); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + choices := payload["choices"].([]any) + delta := choices[0].(map[string]any)["delta"].(map[string]any) + toolCalls := delta["tool_calls"].([]any) + function := toolCalls[0].(map[string]any)["function"].(map[string]any) + + if function["name"] != "edit" { + t.Errorf("Expected tool name 'edit', got '%v'", function["name"]) + } +} + +// TestCorrectToolParameters 测试工具参数修正 +func TestCorrectToolParameters(t *testing.T) { + corrector := NewCodexToolCorrector() + + tests := []struct { + name string + input string + expected map[string]bool // key: 期待存在的参数, value: true表示应该存在 + }{ + { + name: "remove workdir from bash tool", + input: `{ + "tool_calls": [{ + "function": { + "name": "bash", + "arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}" + } + }] + }`, + expected: map[string]bool{ + "command": true, + "workdir": false, + }, + }, + { + name: "rename path to file_path in edit tool", + input: `{ + "tool_calls": [{ + "function": { + "name": "apply_patch", + "arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}" + } + }] + }`, + expected: map[string]bool{ + "file_path": true, + "path": false, + "old_string": true, + "new_string": true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input) + if !changed { + t.Error("expected data to be corrected") + } + + // 解析修正后的数据 + var result map[string]any + if err := json.Unmarshal([]byte(corrected), &result); err != nil { + t.Fatalf("failed to parse corrected data: %v", err) + } + + // 检查工具调用 + toolCalls, ok := result["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("no tool_calls found in corrected data") + } + + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("invalid tool_call structure") + } + + function, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("no function found in tool_call") + } + + argumentsStr, ok := function["arguments"].(string) + if !ok { + t.Fatal("arguments is not a string") + } + + var args map[string]any + if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil { + t.Fatalf("failed to parse arguments: %v", err) + } + + // 验证期望的参数 + for param, shouldExist := range tt.expected { + _, exists := args[param] + if shouldExist && !exists { + t.Errorf("expected parameter %q to exist, but it doesn't", param) + } + if !shouldExist && exists { + t.Errorf("expected parameter %q to not exist, but it does", param) + } + } + }) + } +} From c4f6c89b65f8b1bc5531e7cc944892b953ec5e9c Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Fri, 16 Jan 2026 00:02:22 +0800 Subject: [PATCH 2/3] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8Dgolangci-lin?= =?UTF-8?q?t=E6=A3=80=E6=9F=A5=E5=8F=91=E7=8E=B0=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复了4个lint问题: 1. errcheck (3处): 在测试中添加类型断言的ok检查 2. govet copylocks (1处): 将mutex从ToolCorrectionStats移到CodexToolCorrector **详细修改:** 1. **openai_tool_corrector_test.go** - 添加了类型断言的ok检查,避免panic - 在解析JSON后检查payload结构的有效性 - 改进错误处理和测试可靠性 2. **openai_tool_corrector.go** - 将sync.RWMutex从ToolCorrectionStats移到CodexToolCorrector - 避免在GetStats()返回时复制mutex - 保持线程安全的同时符合Go最佳实践 **测试验证:** - 所有单元测试通过 ✅ - go vet 检查通过 ✅ - 代码编译正常 ✅ --- .../internal/service/openai_tool_corrector.go | 16 ++++----- .../service/openai_tool_corrector_test.go | 33 ++++++++++++++++--- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index 6dad1a75..9c9eab84 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -29,16 +29,16 @@ var codexToolNameMapping = map[string]string{ "execBash": "bash", } -// ToolCorrectionStats 记录工具修正的统计信息 +// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化) type ToolCorrectionStats struct { TotalCorrected int `json:"total_corrected"` CorrectionsByTool map[string]int `json:"corrections_by_tool"` - mu sync.RWMutex } // CodexToolCorrector 处理 Codex 工具调用的自动修正 type CodexToolCorrector struct { stats ToolCorrectionStats + mu sync.RWMutex } // NewCodexToolCorrector 创建新的工具修正器 @@ -251,8 +251,8 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall // recordCorrection 记录一次工具名称修正 func (c *CodexToolCorrector) recordCorrection(from, to string) { - c.stats.mu.Lock() - defer c.stats.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.stats.TotalCorrected++ key := fmt.Sprintf("%s->%s", from, to) @@ -264,8 +264,8 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) { // GetStats 获取工具修正统计信息 func (c *CodexToolCorrector) GetStats() ToolCorrectionStats { - c.stats.mu.RLock() - defer c.stats.mu.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() // 返回副本以避免并发问题 statsCopy := ToolCorrectionStats{ @@ -281,8 +281,8 @@ func (c *CodexToolCorrector) GetStats() ToolCorrectionStats { // ResetStats 重置统计信息 func (c *CodexToolCorrector) ResetStats() { - c.stats.mu.Lock() - defer c.stats.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.stats.TotalCorrected = 0 c.stats.CorrectionsByTool = make(map[string]int) diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index 7219c2af..a1c4530a 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -38,8 +38,18 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - toolCalls := payload["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + toolCalls, ok := payload["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in result") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "edit" { t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) } @@ -54,7 +64,10 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - functionCall := payload["function_call"].(map[string]any) + functionCall, ok := payload["function_call"].(map[string]any) + if !ok { + t.Fatal("Invalid function_call format") + } if functionCall["name"] != "todowrite" { t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"]) } @@ -131,8 +144,18 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - toolCalls := payload["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + toolCalls, ok := payload["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in result") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "edit" { t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) } From 415840088e80ffed0a3592553e9e645a8fed7c74 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Fri, 16 Jan 2026 00:14:19 +0800 Subject: [PATCH 3/3] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8D=E5=89=A9?= =?UTF-8?q?=E4=BD=99=E7=9A=84errcheck=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复了测试文件中剩余的6处类型断言未检查错误: - 第115-118行:choices.message.tool_calls 的类型断言链 - 第140和145行:multiple tool calls 测试的类型断言 - 第343和345行:ComplexSSEData 测试的类型断言 **修复模式:** 所有类型断言都改为使用 ok 检查: ```go // 修复前 choices := payload["choices"].([]any) // 修复后 choices, ok := payload["choices"].([]any) if !ok || len(choices) == 0 { t.Fatal("No choices found in result") } ``` **测试验证:** - ✅ TestCorrectToolCallsInSSEData - 所有子测试通过 - ✅ TestComplexSSEData - 通过 - ✅ TestCorrectToolParameters - 通过 - ✅ 所有类型断言都有 ok 检查 - ✅ 添加了数组长度验证 现在所有 errcheck 错误都已修复。 --- .../service/openai_tool_corrector_test.go | 98 ++++++++++++++++--- 1 file changed, 84 insertions(+), 14 deletions(-) diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index a1c4530a..3e885b4b 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -82,9 +82,22 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - delta := payload["delta"].(map[string]any) - toolCalls := delta["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + delta, ok := payload["delta"].(map[string]any) + if !ok { + t.Fatal("Invalid delta format") + } + toolCalls, ok := delta["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in delta") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "grep" { t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"]) } @@ -99,10 +112,30 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - choices := payload["choices"].([]any) - message := choices[0].(map[string]any)["message"].(map[string]any) - toolCalls := message["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + t.Fatal("No choices found in result") + } + choice, ok := choices[0].(map[string]any) + if !ok { + t.Fatal("Invalid choice format") + } + message, ok := choice["message"].(map[string]any) + if !ok { + t.Fatal("Invalid message format") + } + toolCalls, ok := message["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in message") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "glob" { t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"]) } @@ -122,14 +155,31 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - toolCalls := payload["tool_calls"].([]any) + toolCalls, ok := payload["tool_calls"].([]any) + if !ok || len(toolCalls) < 2 { + t.Fatal("Expected at least 2 tool_calls") + } - func1 := toolCalls[0].(map[string]any)["function"].(map[string]any) + toolCall1, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid first tool_call format") + } + func1, ok := toolCall1["function"].(map[string]any) + if !ok { + t.Fatal("Invalid first function format") + } if func1["name"] != "edit" { t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"]) } - func2 := toolCalls[1].(map[string]any)["function"].(map[string]any) + toolCall2, ok := toolCalls[1].(map[string]any) + if !ok { + t.Fatal("Invalid second tool_call format") + } + func2, ok := toolCall2["function"].(map[string]any) + if !ok { + t.Fatal("Invalid second function format") + } if func2["name"] != "read" { t.Errorf("Expected second tool name 'read', got '%v'", func2["name"]) } @@ -326,10 +376,30 @@ func TestComplexSSEData(t *testing.T) { t.Fatalf("Failed to parse result: %v", err) } - choices := payload["choices"].([]any) - delta := choices[0].(map[string]any)["delta"].(map[string]any) - toolCalls := delta["tool_calls"].([]any) - function := toolCalls[0].(map[string]any)["function"].(map[string]any) + choices, ok := payload["choices"].([]any) + if !ok || len(choices) == 0 { + t.Fatal("No choices found in result") + } + choice, ok := choices[0].(map[string]any) + if !ok { + t.Fatal("Invalid choice format") + } + delta, ok := choice["delta"].(map[string]any) + if !ok { + t.Fatal("Invalid delta format") + } + toolCalls, ok := delta["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in delta") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + function, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if function["name"] != "edit" { t.Errorf("Expected tool name 'edit', got '%v'", function["name"])