Merge pull request #2356 from jack-atlas/fix/openai-messages-multi-tool-continuation

Preserve multi-tool context in OpenAI messages continuation
This commit is contained in:
Wesley Liddick
2026-05-11 23:03:24 +08:00
committed by GitHub
2 changed files with 113 additions and 4 deletions

View File

@@ -480,6 +480,61 @@ func TestForwardAsAnthropic_AttachesPreviousResponseIDForCompatContinuation(t *t
require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String())
}
func TestForwardAsAnthropic_PreviousResponseIDKeepsMultiToolCallContext(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
upstream := &httpUpstreamRecorder{}
svc := &OpenAIGatewayService{
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
account := &Account{
ID: 1,
Name: "openai-apikey",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://api.openai.com/v1",
},
}
firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"}],"stream":false}`)
upstream.resp = openAICompatSSECompletedResponse("resp_first_tools", "gpt-5.3-codex")
firstRec := httptest.NewRecorder()
firstCtx, _ := gin.CreateTestContext(firstRec)
firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody))
firstCtx.Request.Header.Set("Content-Type", "application/json")
firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex")
require.NoError(t, err)
require.NotNil(t, firstResult)
secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"},{"role":"assistant","content":[{"type":"tool_use","id":"call_one","name":"Read","input":{"file_path":"a.go"}},{"type":"tool_use","id":"call_two","name":"Read","input":{"file_path":"b.go"}}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"call_one","content":"package a"},{"type":"tool_result","tool_use_id":"call_two","content":"package b"},{"type":"text","text":"continue"}]}],"tools":[{"name":"Read","description":"read a file","input_schema":{"type":"object","properties":{"file_path":{"type":"string"}}}}],"stream":false}`)
upstream.resp = openAICompatSSECompletedResponse("resp_second_tools", "gpt-5.3-codex")
secondRec := httptest.NewRecorder()
secondCtx, _ := gin.CreateTestContext(secondRec)
secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody))
secondCtx.Request.Header.Set("Content-Type", "application/json")
secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex")
require.NoError(t, err)
require.NotNil(t, secondResult)
require.Equal(t, "resp_first_tools", gjson.GetBytes(upstream.lastBody, "previous_response_id").String())
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.1.type").String())
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.1.call_id").String())
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.2.type").String())
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.2.call_id").String())
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.3.type").String())
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.3.call_id").String())
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.4.type").String())
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.4.call_id").String())
require.Equal(t, "continue", gjson.GetBytes(upstream.lastBody, "input.5.content.0.text").String())
}
func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)

View File

@@ -37,10 +37,7 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
return
}
start := len(items) - 1
for start > 0 && items[start].Type == "function_call_output" {
start--
}
start := latestAnthropicCompatResponsesInputTurnStart(items)
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
if len(trimmed) == len(items) {
return
@@ -50,6 +47,63 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
}
}
func latestAnthropicCompatResponsesInputTurnStart(items []apicompat.ResponsesInputItem) int {
if len(items) == 0 {
return 0
}
start := len(items) - 1
last := items[start]
switch {
case last.Type == "function_call_output":
for start > 0 && items[start-1].Type == "function_call_output" {
start--
}
case last.Type == "message" && last.Role == "user":
for start > 0 && items[start-1].Type == "function_call_output" {
start--
}
default:
return start
}
return expandAnthropicCompatResponsesInputToolCallStart(items, start)
}
func expandAnthropicCompatResponsesInputToolCallStart(items []apicompat.ResponsesInputItem, start int) int {
if start < 0 || start >= len(items) {
return start
}
needed := make(map[string]struct{})
for i := start; i < len(items); i++ {
if items[i].Type != "function_call_output" {
continue
}
callID := strings.TrimSpace(items[i].CallID)
if callID != "" {
needed[callID] = struct{}{}
}
}
if len(needed) == 0 {
return start
}
expandedStart := start
for i := start - 1; i >= 0 && len(needed) > 0; i-- {
if items[i].Type != "function_call" {
continue
}
callID := strings.TrimSpace(items[i].CallID)
if _, ok := needed[callID]; !ok {
continue
}
delete(needed, callID)
expandedStart = i
}
return expandedStart
}
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
return false