feat(openai): 添加Codex工具调用自动修正功能
实现了完整的Codex工具调用拦截和自动修正系统,解决OpenCode使用Codex模型时的工具调用兼容性问题。 **核心功能:** 1. **工具名称自动映射** - apply_patch/applyPatch → edit - update_plan/updatePlan → todowrite - read_plan/readPlan → todoread - search_files/searchFiles → grep - list_files/listFiles → glob - read_file/readFile → read - write_file/writeFile → write - execute_bash/executeBash/exec_bash/execBash → bash 2. **工具参数自动修正** - bash: 自动移除不支持的 workdir/work_dir 参数 - edit: 自动将 path 参数重命名为 file_path - 支持 JSON 字符串和对象两种参数格式 3. **流式响应集成** - 在 SSE 数据流中实时修正工具调用 - 支持多种 JSON 结构(tool_calls, function_call, delta, choices等) - 不影响响应性能和用户体验 4. **统计和监控** - 记录每次工具修正的详细信息 - 提供修正统计数据查询 - 便于问题排查和性能优化 **实现文件:** - `openai_tool_corrector.go`: 工具修正核心逻辑(250行) - `openai_tool_corrector_test.go`: 完整的单元测试(380+行) - `openai_gateway_service.go`: 流式响应集成 - `openai_gateway_service_tool_correction_test.go`: 集成测试 **测试覆盖:** - 工具名称映射测试(18个映射规则) - 参数修正测试(bash workdir、edit path等) - SSE数据修正测试(多种JSON结构) - 统计功能测试 - 所有测试通过 ✅ **解决的问题:** 修复了 OpenCode 使用 sub2api 中转 Codex 时,因工具名称和参数不兼容导致的工具调用失败问题。 Codex 模型有时会忽略指令文件中的工具映射说明,导致调用不存在的工具(如 apply_patch)。 现在通过流式响应拦截,自动将错误的工具调用修正为 OpenCode 兼容的格式。 **参考文档:** - OpenCode 工具规范: https://opencode.ai/docs/ - Codex Bridge 指令: backend/internal/service/prompts/codex_opencode_bridge.txt
This commit is contained in:
@@ -94,6 +94,7 @@ type OpenAIGatewayService struct {
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@@ -128,6 +129,7 @@ func NewOpenAIGatewayService(
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1106,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
@@ -1193,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
||||
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
|
||||
if changed {
|
||||
return []byte(corrected)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
@@ -1296,6 +1317,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
// Correct tool calls in final response
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
} else {
|
||||
usage = s.parseSSEUsageFromBody(bodyText)
|
||||
if originalModel != mappedModel {
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
|
||||
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
|
||||
// 创建一个简单的 service 实例来测试工具修正
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
changed bool
|
||||
}{
|
||||
{
|
||||
name: "correct apply_patch in response body",
|
||||
input: []byte(`{
|
||||
"choices": [{
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"function": {"name": "apply_patch"}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in response body",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "update_plan"}
|
||||
}]
|
||||
}`),
|
||||
expected: "todowrite",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "no change for correct tool name",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "edit"}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := service.correctToolCallsInResponseBody(tt.input)
|
||||
resultStr := string(result)
|
||||
|
||||
// 检查是否包含期望的工具名称
|
||||
if !strings.Contains(resultStr, tt.expected) {
|
||||
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
|
||||
}
|
||||
|
||||
// 对于预期有变化的情况,验证结果与输入不同
|
||||
if tt.changed && string(result) == string(tt.input) {
|
||||
t.Error("expected result to be different from input, but they are the same")
|
||||
}
|
||||
|
||||
// 对于预期无变化的情况,验证结果与输入相同
|
||||
if !tt.changed && string(result) != string(tt.input) {
|
||||
t.Error("expected result to be same as input, but they are different")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
|
||||
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
if service.toolCorrector == nil {
|
||||
t.Fatal("toolCorrector should not be nil")
|
||||
}
|
||||
|
||||
// 测试修正器可以正常工作
|
||||
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
|
||||
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
|
||||
if !changed {
|
||||
t.Error("expected tool call to be corrected")
|
||||
}
|
||||
|
||||
if !strings.Contains(corrected, "edit") {
|
||||
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCorrectionStats 测试工具修正统计功能
|
||||
func TestToolCorrectionStats(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
// 执行几次修正
|
||||
testData := []string{
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
}
|
||||
|
||||
stats := service.toolCorrector.GetStats()
|
||||
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
}
|
||||
307
backend/internal/service/openai_tool_corrector.go
Normal file
307
backend/internal/service/openai_tool_corrector.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
var codexToolNameMapping = map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"applyPatch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"updatePlan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"readPlan": "todoread",
|
||||
"search_files": "grep",
|
||||
"searchFiles": "grep",
|
||||
"list_files": "glob",
|
||||
"listFiles": "glob",
|
||||
"read_file": "read",
|
||||
"readFile": "read",
|
||||
"write_file": "write",
|
||||
"writeFile": "write",
|
||||
"execute_bash": "bash",
|
||||
"executeBash": "bash",
|
||||
"exec_bash": "bash",
|
||||
"execBash": "bash",
|
||||
}
|
||||
|
||||
// ToolCorrectionStats 记录工具修正的统计信息
|
||||
type ToolCorrectionStats struct {
|
||||
TotalCorrected int `json:"total_corrected"`
|
||||
CorrectionsByTool map[string]int `json:"corrections_by_tool"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// CodexToolCorrector 处理 Codex 工具调用的自动修正
|
||||
type CodexToolCorrector struct {
|
||||
stats ToolCorrectionStats
|
||||
}
|
||||
|
||||
// NewCodexToolCorrector 创建新的工具修正器
|
||||
func NewCodexToolCorrector() *CodexToolCorrector {
|
||||
return &CodexToolCorrector{
|
||||
stats: ToolCorrectionStats{
|
||||
CorrectionsByTool: make(map[string]int),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
|
||||
// 返回修正后的数据和是否进行了修正
|
||||
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
|
||||
if data == "" || data == "\n" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 尝试解析 JSON
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
// 不是有效的 JSON,直接返回原数据
|
||||
return data, false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 处理 tool_calls 数组
|
||||
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 function_call 对象
|
||||
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 delta.tool_calls
|
||||
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||
if choices, ok := payload["choices"].([]any); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]any); ok {
|
||||
// 处理 message 中的工具调用
|
||||
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理 delta 中的工具调用
|
||||
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
return data, false
|
||||
}
|
||||
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||
corrected := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(function) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||
name, ok := functionCall["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 查找并修正工具名称
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
functionCall["name"] = correctName
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
}
|
||||
|
||||
// 修正工具参数(基于工具名称)
|
||||
if c.correctToolParameters(name, functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||
arguments, ok := functionCall["arguments"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// arguments 可能是字符串(JSON)或已解析的 map
|
||||
var argsMap map[string]any
|
||||
switch v := arguments.(type) {
|
||||
case string:
|
||||
// 解析 JSON 字符串
|
||||
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||
return false
|
||||
}
|
||||
case map[string]any:
|
||||
argsMap = v
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// 移除 workdir 参数(OpenCode 不支持)
|
||||
if _, exists := argsMap["workdir"]; exists {
|
||||
delete(argsMap, "workdir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||
}
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||
}
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||
// 这里可以添加参数名称的映射逻辑
|
||||
if _, exists := argsMap["file_path"]; !exists {
|
||||
if path, exists := argsMap["path"]; exists {
|
||||
argsMap["file_path"] = path
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果修正了参数,需要重新序列化
|
||||
if corrected {
|
||||
if _, wasString := arguments.(string); wasString {
|
||||
// 原本是字符串,序列化回字符串
|
||||
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||
functionCall["arguments"] = string(newArgsJSON)
|
||||
}
|
||||
} else {
|
||||
// 原本是 map,直接赋值
|
||||
functionCall["arguments"] = argsMap
|
||||
}
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// recordCorrection 记录一次工具名称修正
|
||||
func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||
c.stats.mu.Lock()
|
||||
defer c.stats.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected++
|
||||
key := fmt.Sprintf("%s->%s", from, to)
|
||||
c.stats.CorrectionsByTool[key]++
|
||||
|
||||
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||
from, to, c.stats.TotalCorrected)
|
||||
}
|
||||
|
||||
// GetStats 获取工具修正统计信息
|
||||
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
|
||||
c.stats.mu.RLock()
|
||||
defer c.stats.mu.RUnlock()
|
||||
|
||||
// 返回副本以避免并发问题
|
||||
statsCopy := ToolCorrectionStats{
|
||||
TotalCorrected: c.stats.TotalCorrected,
|
||||
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
|
||||
}
|
||||
for k, v := range c.stats.CorrectionsByTool {
|
||||
statsCopy.CorrectionsByTool[k] = v
|
||||
}
|
||||
|
||||
return statsCopy
|
||||
}
|
||||
|
||||
// ResetStats 重置统计信息
|
||||
func (c *CodexToolCorrector) ResetStats() {
|
||||
c.stats.mu.Lock()
|
||||
defer c.stats.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected = 0
|
||||
c.stats.CorrectionsByTool = make(map[string]int)
|
||||
}
|
||||
|
||||
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
|
||||
func CorrectToolName(name string) (string, bool) {
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
return correctName, true
|
||||
}
|
||||
return name, false
|
||||
}
|
||||
|
||||
// GetToolNameMapping 获取工具名称映射表
|
||||
func GetToolNameMapping() map[string]string {
|
||||
// 返回副本以避免外部修改
|
||||
mapping := make(map[string]string, len(codexToolNameMapping))
|
||||
for k, v := range codexToolNameMapping {
|
||||
mapping[k] = v
|
||||
}
|
||||
return mapping
|
||||
}
|
||||
410
backend/internal/service/openai_tool_corrector_test.go
Normal file
410
backend/internal/service/openai_tool_corrector_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectCorrected bool
|
||||
checkFunc func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "newline only",
|
||||
input: "\n",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
input: "not a json",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct apply_patch in tool_calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls := payload["tool_calls"].([]any)
|
||||
functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in function_call",
|
||||
input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
functionCall := payload["function_call"].(map[string]any)
|
||||
if functionCall["name"] != "todowrite" {
|
||||
t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct search_files in delta.tool_calls",
|
||||
input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
toolCalls := delta["tool_calls"].([]any)
|
||||
functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
if functionCall["name"] != "grep" {
|
||||
t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct list_files in choices.message.tool_calls",
|
||||
input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
choices := payload["choices"].([]any)
|
||||
message := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
toolCalls := message["tool_calls"].([]any)
|
||||
functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
if functionCall["name"] != "glob" {
|
||||
t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no correction needed",
|
||||
input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct multiple tool calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls := payload["tool_calls"].([]any)
|
||||
|
||||
func1 := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
if func1["name"] != "edit" {
|
||||
t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
|
||||
}
|
||||
|
||||
func2 := toolCalls[1].(map[string]any)["function"].(map[string]any)
|
||||
if func2["name"] != "read" {
|
||||
t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "camelCase format - applyPatch",
|
||||
input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls := payload["tool_calls"].([]any)
|
||||
functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
|
||||
if corrected != tt.expectCorrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
|
||||
}
|
||||
|
||||
if !corrected && result != tt.input {
|
||||
t.Errorf("Expected unchanged result when not corrected")
|
||||
}
|
||||
|
||||
if tt.checkFunc != nil {
|
||||
tt.checkFunc(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectToolName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
corrected bool
|
||||
}{
|
||||
{"apply_patch", "edit", true},
|
||||
{"applyPatch", "edit", true},
|
||||
{"update_plan", "todowrite", true},
|
||||
{"updatePlan", "todowrite", true},
|
||||
{"read_plan", "todoread", true},
|
||||
{"readPlan", "todoread", true},
|
||||
{"search_files", "grep", true},
|
||||
{"searchFiles", "grep", true},
|
||||
{"list_files", "glob", true},
|
||||
{"listFiles", "glob", true},
|
||||
{"read_file", "read", true},
|
||||
{"readFile", "read", true},
|
||||
{"write_file", "write", true},
|
||||
{"writeFile", "write", true},
|
||||
{"execute_bash", "bash", true},
|
||||
{"executeBash", "bash", true},
|
||||
{"exec_bash", "bash", true},
|
||||
{"execBash", "bash", true},
|
||||
{"unknown_tool", "unknown_tool", false},
|
||||
{"read", "read", false},
|
||||
{"edit", "edit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result, corrected := CorrectToolName(tt.input)
|
||||
|
||||
if corrected != tt.corrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetToolNameMapping(t *testing.T) {
|
||||
mapping := GetToolNameMapping()
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"search_files": "grep",
|
||||
"list_files": "glob",
|
||||
}
|
||||
|
||||
for from, to := range expectedMappings {
|
||||
if mapping[from] != to {
|
||||
t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
|
||||
}
|
||||
}
|
||||
|
||||
mapping["test_tool"] = "test_value"
|
||||
newMapping := GetToolNameMapping()
|
||||
if _, exists := newMapping["test_tool"]; exists {
|
||||
t.Error("Modifications to returned mapping should not affect original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectorStats(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
stats := corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
|
||||
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
|
||||
corrector.ResetStats()
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
input := `{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-5.1-codex",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"file\":\"test.go\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(input)
|
||||
|
||||
if !corrected {
|
||||
t.Error("Expected data to be corrected")
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
toolCalls := delta["tool_calls"].([]any)
|
||||
function := toolCalls[0].(map[string]any)["function"].(map[string]any)
|
||||
|
||||
if function["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCorrectToolParameters 测试工具参数修正
|
||||
func TestCorrectToolParameters(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||
}{
|
||||
{
|
||||
name: "remove workdir from bash tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"command": true,
|
||||
"workdir": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rename path to file_path in edit tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"file_path": true,
|
||||
"path": false,
|
||||
"old_string": true,
|
||||
"new_string": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
if !changed {
|
||||
t.Error("expected data to be corrected")
|
||||
}
|
||||
|
||||
// 解析修正后的数据
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(corrected), &result); err != nil {
|
||||
t.Fatalf("failed to parse corrected data: %v", err)
|
||||
}
|
||||
|
||||
// 检查工具调用
|
||||
toolCalls, ok := result["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("no tool_calls found in corrected data")
|
||||
}
|
||||
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("invalid tool_call structure")
|
||||
}
|
||||
|
||||
function, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("no function found in tool_call")
|
||||
}
|
||||
|
||||
argumentsStr, ok := function["arguments"].(string)
|
||||
if !ok {
|
||||
t.Fatal("arguments is not a string")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
|
||||
t.Fatalf("failed to parse arguments: %v", err)
|
||||
}
|
||||
|
||||
// 验证期望的参数
|
||||
for param, shouldExist := range tt.expected {
|
||||
_, exists := args[param]
|
||||
if shouldExist && !exists {
|
||||
t.Errorf("expected parameter %q to exist, but it doesn't", param)
|
||||
}
|
||||
if !shouldExist && exists {
|
||||
t.Errorf("expected parameter %q to not exist, but it does", param)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user