diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index a897e219..e222b093 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -480,6 +480,61 @@ func TestForwardAsAnthropic_AttachesPreviousResponseIDForCompatContinuation(t *t require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String()) } +func TestForwardAsAnthropic_PreviousResponseIDKeepsMultiToolCallContext(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{} + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + }, + } + + firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"}],"stream":false}`) + upstream.resp = openAICompatSSECompletedResponse("resp_first_tools", "gpt-5.3-codex") + firstRec := httptest.NewRecorder() + firstCtx, _ := gin.CreateTestContext(firstRec) + firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody)) + firstCtx.Request.Header.Set("Content-Type", "application/json") + + firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, firstResult) + + secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"},{"role":"assistant","content":[{"type":"tool_use","id":"call_one","name":"Read","input":{"file_path":"a.go"}},{"type":"tool_use","id":"call_two","name":"Read","input":{"file_path":"b.go"}}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"call_one","content":"package a"},{"type":"tool_result","tool_use_id":"call_two","content":"package b"},{"type":"text","text":"continue"}]}],"tools":[{"name":"Read","description":"read a file","input_schema":{"type":"object","properties":{"file_path":{"type":"string"}}}}],"stream":false}`) + upstream.resp = openAICompatSSECompletedResponse("resp_second_tools", "gpt-5.3-codex") + secondRec := httptest.NewRecorder() + secondCtx, _ := gin.CreateTestContext(secondRec) + secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody)) + secondCtx.Request.Header.Set("Content-Type", "application/json") + + secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex") + require.NoError(t, err) + require.NotNil(t, secondResult) + require.Equal(t, "resp_first_tools", gjson.GetBytes(upstream.lastBody, "previous_response_id").String()) + + require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.1.type").String()) + require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.1.call_id").String()) + require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.2.type").String()) + require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.2.call_id").String()) + require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.3.type").String()) + require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.3.call_id").String()) + require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.4.type").String()) + require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.4.call_id").String()) + require.Equal(t, "continue", gjson.GetBytes(upstream.lastBody, "input.5.content.0.text").String()) +} + func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_messages_continuation.go b/backend/internal/service/openai_messages_continuation.go index 57d04784..b3c900fb 100644 --- a/backend/internal/service/openai_messages_continuation.go +++ b/backend/internal/service/openai_messages_continuation.go @@ -37,10 +37,7 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque return } - start := len(items) - 1 - for start > 0 && items[start].Type == "function_call_output" { - start-- - } + start := latestAnthropicCompatResponsesInputTurnStart(items) trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...) if len(trimmed) == len(items) { return @@ -50,6 +47,63 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque } } +func latestAnthropicCompatResponsesInputTurnStart(items []apicompat.ResponsesInputItem) int { + if len(items) == 0 { + return 0 + } + + start := len(items) - 1 + last := items[start] + switch { + case last.Type == "function_call_output": + for start > 0 && items[start-1].Type == "function_call_output" { + start-- + } + case last.Type == "message" && last.Role == "user": + for start > 0 && items[start-1].Type == "function_call_output" { + start-- + } + default: + return start + } + + return expandAnthropicCompatResponsesInputToolCallStart(items, start) +} + +func expandAnthropicCompatResponsesInputToolCallStart(items []apicompat.ResponsesInputItem, start int) int { + if start < 0 || start >= len(items) { + return start + } + + needed := make(map[string]struct{}) + for i := start; i < len(items); i++ { + if items[i].Type != "function_call_output" { + continue + } + callID := strings.TrimSpace(items[i].CallID) + if callID != "" { + needed[callID] = struct{}{} + } + } + if len(needed) == 0 { + return start + } + + expandedStart := start + for i := start - 1; i >= 0 && len(needed) > 0; i-- { + if items[i].Type != "function_call" { + continue + } + callID := strings.TrimSpace(items[i].CallID) + if _, ok := needed[callID]; !ok { + continue + } + delete(needed, callID) + expandedStart = i + } + return expandedStart +} + func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool { if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound { return false