diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index 6dad1a75..9c9eab84 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -29,16 +29,16 @@ var codexToolNameMapping = map[string]string{ "execBash": "bash", } -// ToolCorrectionStats 记录工具修正的统计信息 +// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化) type ToolCorrectionStats struct { TotalCorrected int `json:"total_corrected"` CorrectionsByTool map[string]int `json:"corrections_by_tool"` - mu sync.RWMutex } // CodexToolCorrector 处理 Codex 工具调用的自动修正 type CodexToolCorrector struct { stats ToolCorrectionStats + mu sync.RWMutex } // NewCodexToolCorrector 创建新的工具修正器 @@ -251,8 +251,8 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall // recordCorrection 记录一次工具名称修正 func (c *CodexToolCorrector) recordCorrection(from, to string) { - c.stats.mu.Lock() - defer c.stats.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.stats.TotalCorrected++ key := fmt.Sprintf("%s->%s", from, to) @@ -264,8 +264,8 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) { // GetStats 获取工具修正统计信息 func (c *CodexToolCorrector) GetStats() ToolCorrectionStats { - c.stats.mu.RLock() - defer c.stats.mu.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() // 返回副本以避免并发问题 statsCopy := ToolCorrectionStats{ @@ -281,8 +281,8 @@ func (c *CodexToolCorrector) GetStats() ToolCorrectionStats { // ResetStats 重置统计信息 func (c *CodexToolCorrector) ResetStats() { - c.stats.mu.Lock() - defer c.stats.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.stats.TotalCorrected = 0 c.stats.CorrectionsByTool = make(map[string]int) diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index 7219c2af..a1c4530a 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -38,8 +38,18 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - toolCalls := payload["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + toolCalls, ok := payload["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in result") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "edit" { t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) } @@ -54,7 +64,10 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - functionCall := payload["function_call"].(map[string]any) + functionCall, ok := payload["function_call"].(map[string]any) + if !ok { + t.Fatal("Invalid function_call format") + } if functionCall["name"] != "todowrite" { t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"]) } @@ -131,8 +144,18 @@ func TestCorrectToolCallsInSSEData(t *testing.T) { if err := json.Unmarshal([]byte(result), &payload); err != nil { t.Fatalf("Failed to parse result: %v", err) } - toolCalls := payload["tool_calls"].([]any) - functionCall := toolCalls[0].(map[string]any)["function"].(map[string]any) + toolCalls, ok := payload["tool_calls"].([]any) + if !ok || len(toolCalls) == 0 { + t.Fatal("No tool_calls found in result") + } + toolCall, ok := toolCalls[0].(map[string]any) + if !ok { + t.Fatal("Invalid tool_call format") + } + functionCall, ok := toolCall["function"].(map[string]any) + if !ok { + t.Fatal("Invalid function format") + } if functionCall["name"] != "edit" { t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"]) }