Merge pull request #304 from IanShaw027/feature/codex-tool-correction
feat(openai): 添加Codex工具调用自动修正功能
This commit is contained in:
@@ -94,6 +94,7 @@ type OpenAIGatewayService struct {
|
|||||||
httpUpstream HTTPUpstream
|
httpUpstream HTTPUpstream
|
||||||
deferredService *DeferredService
|
deferredService *DeferredService
|
||||||
openAITokenProvider *OpenAITokenProvider
|
openAITokenProvider *OpenAITokenProvider
|
||||||
|
toolCorrector *CodexToolCorrector
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||||
@@ -128,6 +129,7 @@ func NewOpenAIGatewayService(
|
|||||||
httpUpstream: httpUpstream,
|
httpUpstream: httpUpstream,
|
||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
openAITokenProvider: openAITokenProvider,
|
openAITokenProvider: openAITokenProvider,
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1106,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||||
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||||
|
line = "data: " + correctedData
|
||||||
|
}
|
||||||
|
|
||||||
// Forward line
|
// Forward line
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
sendErrorEvent("write_failed")
|
sendErrorEvent("write_failed")
|
||||||
@@ -1193,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
||||||
|
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyStr := string(body)
|
||||||
|
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
|
||||||
|
if changed {
|
||||||
|
return []byte(corrected)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||||
// Parse response.completed event for usage (OpenAI Responses format)
|
// Parse response.completed event for usage (OpenAI Responses format)
|
||||||
var event struct {
|
var event struct {
|
||||||
@@ -1296,6 +1317,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
|||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
// Correct tool calls in final response
|
||||||
|
body = s.correctToolCallsInResponseBody(body)
|
||||||
} else {
|
} else {
|
||||||
usage = s.parseSSEUsageFromBody(bodyText)
|
usage = s.parseSSEUsageFromBody(bodyText)
|
||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
|
||||||
|
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
|
||||||
|
// 创建一个简单的 service 实例来测试工具修正
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected string
|
||||||
|
changed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "correct apply_patch in response body",
|
||||||
|
input: []byte(`{
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "apply_patch"}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "edit",
|
||||||
|
changed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct update_plan in response body",
|
||||||
|
input: []byte(`{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "update_plan"}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "todowrite",
|
||||||
|
changed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no change for correct tool name",
|
||||||
|
input: []byte(`{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {"name": "edit"}
|
||||||
|
}]
|
||||||
|
}`),
|
||||||
|
expected: "edit",
|
||||||
|
changed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := service.correctToolCallsInResponseBody(tt.input)
|
||||||
|
resultStr := string(result)
|
||||||
|
|
||||||
|
// 检查是否包含期望的工具名称
|
||||||
|
if !strings.Contains(resultStr, tt.expected) {
|
||||||
|
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于预期有变化的情况,验证结果与输入不同
|
||||||
|
if tt.changed && string(result) == string(tt.input) {
|
||||||
|
t.Error("expected result to be different from input, but they are the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于预期无变化的情况,验证结果与输入相同
|
||||||
|
if !tt.changed && string(result) != string(tt.input) {
|
||||||
|
t.Error("expected result to be same as input, but they are different")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
|
||||||
|
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.toolCorrector == nil {
|
||||||
|
t.Fatal("toolCorrector should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试修正器可以正常工作
|
||||||
|
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
|
||||||
|
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||||
|
|
||||||
|
if !changed {
|
||||||
|
t.Error("expected tool call to be corrected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(corrected, "edit") {
|
||||||
|
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolCorrectionStats 测试工具修正统计功能
|
||||||
|
func TestToolCorrectionStats(t *testing.T) {
|
||||||
|
service := &OpenAIGatewayService{
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行几次修正
|
||||||
|
testData := []string{
|
||||||
|
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||||
|
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
|
||||||
|
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, data := range testData {
|
||||||
|
service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := service.toolCorrector.GetStats()
|
||||||
|
|
||||||
|
if stats.TotalCorrected != 3 {
|
||||||
|
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||||
|
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||||
|
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||||
|
}
|
||||||
|
}
|
||||||
307
backend/internal/service/openai_tool_corrector.go
Normal file
307
backend/internal/service/openai_tool_corrector.go
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||||
|
var codexToolNameMapping = map[string]string{
|
||||||
|
"apply_patch": "edit",
|
||||||
|
"applyPatch": "edit",
|
||||||
|
"update_plan": "todowrite",
|
||||||
|
"updatePlan": "todowrite",
|
||||||
|
"read_plan": "todoread",
|
||||||
|
"readPlan": "todoread",
|
||||||
|
"search_files": "grep",
|
||||||
|
"searchFiles": "grep",
|
||||||
|
"list_files": "glob",
|
||||||
|
"listFiles": "glob",
|
||||||
|
"read_file": "read",
|
||||||
|
"readFile": "read",
|
||||||
|
"write_file": "write",
|
||||||
|
"writeFile": "write",
|
||||||
|
"execute_bash": "bash",
|
||||||
|
"executeBash": "bash",
|
||||||
|
"exec_bash": "bash",
|
||||||
|
"execBash": "bash",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
|
||||||
|
type ToolCorrectionStats struct {
|
||||||
|
TotalCorrected int `json:"total_corrected"`
|
||||||
|
CorrectionsByTool map[string]int `json:"corrections_by_tool"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodexToolCorrector 处理 Codex 工具调用的自动修正
|
||||||
|
type CodexToolCorrector struct {
|
||||||
|
stats ToolCorrectionStats
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCodexToolCorrector 创建新的工具修正器
|
||||||
|
func NewCodexToolCorrector() *CodexToolCorrector {
|
||||||
|
return &CodexToolCorrector{
|
||||||
|
stats: ToolCorrectionStats{
|
||||||
|
CorrectionsByTool: make(map[string]int),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
|
||||||
|
// 返回修正后的数据和是否进行了修正
|
||||||
|
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
|
||||||
|
if data == "" || data == "\n" {
|
||||||
|
return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析 JSON
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||||
|
// 不是有效的 JSON,直接返回原数据
|
||||||
|
return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
corrected := false
|
||||||
|
|
||||||
|
// 处理 tool_calls 数组
|
||||||
|
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||||
|
if c.correctToolCallsArray(toolCalls) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 function_call 对象
|
||||||
|
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 delta.tool_calls
|
||||||
|
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||||
|
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||||
|
if c.correctToolCallsArray(toolCalls) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||||
|
if choices, ok := payload["choices"].([]any); ok {
|
||||||
|
for _, choice := range choices {
|
||||||
|
if choiceMap, ok := choice.(map[string]any); ok {
|
||||||
|
// 处理 message 中的工具调用
|
||||||
|
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||||
|
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||||
|
if c.correctToolCallsArray(toolCalls) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 处理 delta 中的工具调用
|
||||||
|
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||||
|
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||||
|
if c.correctToolCallsArray(toolCalls) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !corrected {
|
||||||
|
return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化回 JSON
|
||||||
|
correctedBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||||
|
return data, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(correctedBytes), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||||
|
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||||
|
corrected := false
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||||
|
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||||
|
if c.correctFunctionCall(function) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||||
|
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||||
|
name, ok := functionCall["name"].(string)
|
||||||
|
if !ok || name == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
corrected := false
|
||||||
|
|
||||||
|
// 查找并修正工具名称
|
||||||
|
if correctName, found := codexToolNameMapping[name]; found {
|
||||||
|
functionCall["name"] = correctName
|
||||||
|
c.recordCorrection(name, correctName)
|
||||||
|
corrected = true
|
||||||
|
name = correctName // 使用修正后的名称进行参数修正
|
||||||
|
}
|
||||||
|
|
||||||
|
// 修正工具参数(基于工具名称)
|
||||||
|
if c.correctToolParameters(name, functionCall) {
|
||||||
|
corrected = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||||
|
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||||
|
arguments, ok := functionCall["arguments"]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// arguments 可能是字符串(JSON)或已解析的 map
|
||||||
|
var argsMap map[string]any
|
||||||
|
switch v := arguments.(type) {
|
||||||
|
case string:
|
||||||
|
// 解析 JSON 字符串
|
||||||
|
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
argsMap = v
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
corrected := false
|
||||||
|
|
||||||
|
// 根据工具名称应用特定的参数修正规则
|
||||||
|
switch toolName {
|
||||||
|
case "bash":
|
||||||
|
// 移除 workdir 参数(OpenCode 不支持)
|
||||||
|
if _, exists := argsMap["workdir"]; exists {
|
||||||
|
delete(argsMap, "workdir")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||||
|
}
|
||||||
|
if _, exists := argsMap["work_dir"]; exists {
|
||||||
|
delete(argsMap, "work_dir")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
case "edit":
|
||||||
|
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||||
|
// 这里可以添加参数名称的映射逻辑
|
||||||
|
if _, exists := argsMap["file_path"]; !exists {
|
||||||
|
if path, exists := argsMap["path"]; exists {
|
||||||
|
argsMap["file_path"] = path
|
||||||
|
delete(argsMap, "path")
|
||||||
|
corrected = true
|
||||||
|
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果修正了参数,需要重新序列化
|
||||||
|
if corrected {
|
||||||
|
if _, wasString := arguments.(string); wasString {
|
||||||
|
// 原本是字符串,序列化回字符串
|
||||||
|
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||||
|
functionCall["arguments"] = string(newArgsJSON)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 原本是 map,直接赋值
|
||||||
|
functionCall["arguments"] = argsMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return corrected
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordCorrection 记录一次工具名称修正
|
||||||
|
func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.stats.TotalCorrected++
|
||||||
|
key := fmt.Sprintf("%s->%s", from, to)
|
||||||
|
c.stats.CorrectionsByTool[key]++
|
||||||
|
|
||||||
|
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||||
|
from, to, c.stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats 获取工具修正统计信息
|
||||||
|
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
// 返回副本以避免并发问题
|
||||||
|
statsCopy := ToolCorrectionStats{
|
||||||
|
TotalCorrected: c.stats.TotalCorrected,
|
||||||
|
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
|
||||||
|
}
|
||||||
|
for k, v := range c.stats.CorrectionsByTool {
|
||||||
|
statsCopy.CorrectionsByTool[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return statsCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStats 重置统计信息
|
||||||
|
func (c *CodexToolCorrector) ResetStats() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.stats.TotalCorrected = 0
|
||||||
|
c.stats.CorrectionsByTool = make(map[string]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
|
||||||
|
func CorrectToolName(name string) (string, bool) {
|
||||||
|
if correctName, found := codexToolNameMapping[name]; found {
|
||||||
|
return correctName, true
|
||||||
|
}
|
||||||
|
return name, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetToolNameMapping 获取工具名称映射表
|
||||||
|
func GetToolNameMapping() map[string]string {
|
||||||
|
// 返回副本以避免外部修改
|
||||||
|
mapping := make(map[string]string, len(codexToolNameMapping))
|
||||||
|
for k, v := range codexToolNameMapping {
|
||||||
|
mapping[k] = v
|
||||||
|
}
|
||||||
|
return mapping
|
||||||
|
}
|
||||||
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||||
|
corrector := NewCodexToolCorrector()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectCorrected bool
|
||||||
|
checkFunc func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expectCorrected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "newline only",
|
||||||
|
input: "\n",
|
||||||
|
expectCorrected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid json",
|
||||||
|
input: "not a json",
|
||||||
|
expectCorrected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct apply_patch in tool_calls",
|
||||||
|
input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
toolCalls, ok := payload["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("No tool_calls found in result")
|
||||||
|
}
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid tool_call format")
|
||||||
|
}
|
||||||
|
functionCall, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function format")
|
||||||
|
}
|
||||||
|
if functionCall["name"] != "edit" {
|
||||||
|
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct update_plan in function_call",
|
||||||
|
input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
functionCall, ok := payload["function_call"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function_call format")
|
||||||
|
}
|
||||||
|
if functionCall["name"] != "todowrite" {
|
||||||
|
t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct search_files in delta.tool_calls",
|
||||||
|
input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
delta, ok := payload["delta"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid delta format")
|
||||||
|
}
|
||||||
|
toolCalls, ok := delta["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("No tool_calls found in delta")
|
||||||
|
}
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid tool_call format")
|
||||||
|
}
|
||||||
|
functionCall, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function format")
|
||||||
|
}
|
||||||
|
if functionCall["name"] != "grep" {
|
||||||
|
t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct list_files in choices.message.tool_calls",
|
||||||
|
input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
choices, ok := payload["choices"].([]any)
|
||||||
|
if !ok || len(choices) == 0 {
|
||||||
|
t.Fatal("No choices found in result")
|
||||||
|
}
|
||||||
|
choice, ok := choices[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid choice format")
|
||||||
|
}
|
||||||
|
message, ok := choice["message"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid message format")
|
||||||
|
}
|
||||||
|
toolCalls, ok := message["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("No tool_calls found in message")
|
||||||
|
}
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid tool_call format")
|
||||||
|
}
|
||||||
|
functionCall, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function format")
|
||||||
|
}
|
||||||
|
if functionCall["name"] != "glob" {
|
||||||
|
t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no correction needed",
|
||||||
|
input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
|
||||||
|
expectCorrected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "correct multiple tool calls",
|
||||||
|
input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
toolCalls, ok := payload["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) < 2 {
|
||||||
|
t.Fatal("Expected at least 2 tool_calls")
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall1, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid first tool_call format")
|
||||||
|
}
|
||||||
|
func1, ok := toolCall1["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid first function format")
|
||||||
|
}
|
||||||
|
if func1["name"] != "edit" {
|
||||||
|
t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall2, ok := toolCalls[1].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid second tool_call format")
|
||||||
|
}
|
||||||
|
func2, ok := toolCall2["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid second function format")
|
||||||
|
}
|
||||||
|
if func2["name"] != "read" {
|
||||||
|
t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "camelCase format - applyPatch",
|
||||||
|
input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
|
||||||
|
expectCorrected: true,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
toolCalls, ok := payload["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("No tool_calls found in result")
|
||||||
|
}
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid tool_call format")
|
||||||
|
}
|
||||||
|
functionCall, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function format")
|
||||||
|
}
|
||||||
|
if functionCall["name"] != "edit" {
|
||||||
|
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||||
|
|
||||||
|
if corrected != tt.expectCorrected {
|
||||||
|
t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !corrected && result != tt.input {
|
||||||
|
t.Errorf("Expected unchanged result when not corrected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkFunc != nil {
|
||||||
|
tt.checkFunc(t, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorrectToolName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
corrected bool
|
||||||
|
}{
|
||||||
|
{"apply_patch", "edit", true},
|
||||||
|
{"applyPatch", "edit", true},
|
||||||
|
{"update_plan", "todowrite", true},
|
||||||
|
{"updatePlan", "todowrite", true},
|
||||||
|
{"read_plan", "todoread", true},
|
||||||
|
{"readPlan", "todoread", true},
|
||||||
|
{"search_files", "grep", true},
|
||||||
|
{"searchFiles", "grep", true},
|
||||||
|
{"list_files", "glob", true},
|
||||||
|
{"listFiles", "glob", true},
|
||||||
|
{"read_file", "read", true},
|
||||||
|
{"readFile", "read", true},
|
||||||
|
{"write_file", "write", true},
|
||||||
|
{"writeFile", "write", true},
|
||||||
|
{"execute_bash", "bash", true},
|
||||||
|
{"executeBash", "bash", true},
|
||||||
|
{"exec_bash", "bash", true},
|
||||||
|
{"execBash", "bash", true},
|
||||||
|
{"unknown_tool", "unknown_tool", false},
|
||||||
|
{"read", "read", false},
|
||||||
|
{"edit", "edit", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result, corrected := CorrectToolName(tt.input)
|
||||||
|
|
||||||
|
if corrected != tt.corrected {
|
||||||
|
t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetToolNameMapping(t *testing.T) {
|
||||||
|
mapping := GetToolNameMapping()
|
||||||
|
|
||||||
|
expectedMappings := map[string]string{
|
||||||
|
"apply_patch": "edit",
|
||||||
|
"update_plan": "todowrite",
|
||||||
|
"read_plan": "todoread",
|
||||||
|
"search_files": "grep",
|
||||||
|
"list_files": "glob",
|
||||||
|
}
|
||||||
|
|
||||||
|
for from, to := range expectedMappings {
|
||||||
|
if mapping[from] != to {
|
||||||
|
t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mapping["test_tool"] = "test_value"
|
||||||
|
newMapping := GetToolNameMapping()
|
||||||
|
if _, exists := newMapping["test_tool"]; exists {
|
||||||
|
t.Error("Modifications to returned mapping should not affect original")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorrectorStats(t *testing.T) {
|
||||||
|
corrector := NewCodexToolCorrector()
|
||||||
|
|
||||||
|
stats := corrector.GetStats()
|
||||||
|
if stats.TotalCorrected != 0 {
|
||||||
|
t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
if len(stats.CorrectionsByTool) != 0 {
|
||||||
|
t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
|
||||||
|
}
|
||||||
|
|
||||||
|
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||||
|
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||||
|
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
|
||||||
|
|
||||||
|
stats = corrector.GetStats()
|
||||||
|
if stats.TotalCorrected != 3 {
|
||||||
|
t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||||
|
t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||||
|
t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||||
|
}
|
||||||
|
|
||||||
|
corrector.ResetStats()
|
||||||
|
stats = corrector.GetStats()
|
||||||
|
if stats.TotalCorrected != 0 {
|
||||||
|
t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
|
||||||
|
}
|
||||||
|
if len(stats.CorrectionsByTool) != 0 {
|
||||||
|
t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexSSEData(t *testing.T) {
|
||||||
|
corrector := NewCodexToolCorrector()
|
||||||
|
|
||||||
|
input := `{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "gpt-5.1-codex",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"function": {
|
||||||
|
"name": "apply_patch",
|
||||||
|
"arguments": "{\"file\":\"test.go\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"finish_reason": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, corrected := corrector.CorrectToolCallsInSSEData(input)
|
||||||
|
|
||||||
|
if !corrected {
|
||||||
|
t.Error("Expected data to be corrected")
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||||
|
t.Fatalf("Failed to parse result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
choices, ok := payload["choices"].([]any)
|
||||||
|
if !ok || len(choices) == 0 {
|
||||||
|
t.Fatal("No choices found in result")
|
||||||
|
}
|
||||||
|
choice, ok := choices[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid choice format")
|
||||||
|
}
|
||||||
|
delta, ok := choice["delta"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid delta format")
|
||||||
|
}
|
||||||
|
toolCalls, ok := delta["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("No tool_calls found in delta")
|
||||||
|
}
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid tool_call format")
|
||||||
|
}
|
||||||
|
function, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Invalid function format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if function["name"] != "edit" {
|
||||||
|
t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCorrectToolParameters 测试工具参数修正
|
||||||
|
func TestCorrectToolParameters(t *testing.T) {
|
||||||
|
corrector := NewCodexToolCorrector()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "remove workdir from bash tool",
|
||||||
|
input: `{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {
|
||||||
|
"name": "bash",
|
||||||
|
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`,
|
||||||
|
expected: map[string]bool{
|
||||||
|
"command": true,
|
||||||
|
"workdir": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rename path to file_path in edit tool",
|
||||||
|
input: `{
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {
|
||||||
|
"name": "apply_patch",
|
||||||
|
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`,
|
||||||
|
expected: map[string]bool{
|
||||||
|
"file_path": true,
|
||||||
|
"path": false,
|
||||||
|
"old_string": true,
|
||||||
|
"new_string": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||||
|
if !changed {
|
||||||
|
t.Error("expected data to be corrected")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析修正后的数据
|
||||||
|
var result map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(corrected), &result); err != nil {
|
||||||
|
t.Fatalf("failed to parse corrected data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查工具调用
|
||||||
|
toolCalls, ok := result["tool_calls"].([]any)
|
||||||
|
if !ok || len(toolCalls) == 0 {
|
||||||
|
t.Fatal("no tool_calls found in corrected data")
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall, ok := toolCalls[0].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("invalid tool_call structure")
|
||||||
|
}
|
||||||
|
|
||||||
|
function, ok := toolCall["function"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("no function found in tool_call")
|
||||||
|
}
|
||||||
|
|
||||||
|
argumentsStr, ok := function["arguments"].(string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("arguments is not a string")
|
||||||
|
}
|
||||||
|
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
|
||||||
|
t.Fatalf("failed to parse arguments: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证期望的参数
|
||||||
|
for param, shouldExist := range tt.expected {
|
||||||
|
_, exists := args[param]
|
||||||
|
if shouldExist && !exists {
|
||||||
|
t.Errorf("expected parameter %q to exist, but it doesn't", param)
|
||||||
|
}
|
||||||
|
if !shouldExist && exists {
|
||||||
|
t.Errorf("expected parameter %q to not exist, but it does", param)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user