Merge pull request #2356 from jack-atlas/fix/openai-messages-multi-tool-continuation
Preserve multi-tool context in OpenAI messages continuation
This commit is contained in:
@@ -480,6 +480,61 @@ func TestForwardAsAnthropic_AttachesPreviousResponseIDForCompatContinuation(t *t
|
||||
require.Equal(t, "second", gjson.GetBytes(upstream.lastBody, "input.1.content.0.text").String())
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_PreviousResponseIDKeepsMultiToolCallContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
upstream := &httpUpstreamRecorder{}
|
||||
svc := &OpenAIGatewayService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
},
|
||||
}
|
||||
|
||||
firstBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"}],"stream":false}`)
|
||||
upstream.resp = openAICompatSSECompletedResponse("resp_first_tools", "gpt-5.3-codex")
|
||||
firstRec := httptest.NewRecorder()
|
||||
firstCtx, _ := gin.CreateTestContext(firstRec)
|
||||
firstCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(firstBody))
|
||||
firstCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
firstResult, err := svc.ForwardAsAnthropic(context.Background(), firstCtx, account, firstBody, "stable-cache-key", "gpt-5.3-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstResult)
|
||||
|
||||
secondBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":16,"messages":[{"role":"user","content":"inspect files"},{"role":"assistant","content":[{"type":"tool_use","id":"call_one","name":"Read","input":{"file_path":"a.go"}},{"type":"tool_use","id":"call_two","name":"Read","input":{"file_path":"b.go"}}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"call_one","content":"package a"},{"type":"tool_result","tool_use_id":"call_two","content":"package b"},{"type":"text","text":"continue"}]}],"tools":[{"name":"Read","description":"read a file","input_schema":{"type":"object","properties":{"file_path":{"type":"string"}}}}],"stream":false}`)
|
||||
upstream.resp = openAICompatSSECompletedResponse("resp_second_tools", "gpt-5.3-codex")
|
||||
secondRec := httptest.NewRecorder()
|
||||
secondCtx, _ := gin.CreateTestContext(secondRec)
|
||||
secondCtx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(secondBody))
|
||||
secondCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
secondResult, err := svc.ForwardAsAnthropic(context.Background(), secondCtx, account, secondBody, "stable-cache-key", "gpt-5.3-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondResult)
|
||||
require.Equal(t, "resp_first_tools", gjson.GetBytes(upstream.lastBody, "previous_response_id").String())
|
||||
|
||||
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.1.type").String())
|
||||
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.1.call_id").String())
|
||||
require.Equal(t, "function_call", gjson.GetBytes(upstream.lastBody, "input.2.type").String())
|
||||
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.2.call_id").String())
|
||||
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.3.type").String())
|
||||
require.Equal(t, "call_one", gjson.GetBytes(upstream.lastBody, "input.3.call_id").String())
|
||||
require.Equal(t, "function_call_output", gjson.GetBytes(upstream.lastBody, "input.4.type").String())
|
||||
require.Equal(t, "call_two", gjson.GetBytes(upstream.lastBody, "input.4.call_id").String())
|
||||
require.Equal(t, "continue", gjson.GetBytes(upstream.lastBody, "input.5.content.0.text").String())
|
||||
}
|
||||
|
||||
func TestForwardAsAnthropic_ReplaysWithoutContinuationWhenPreviousResponseMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -37,10 +37,7 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
|
||||
return
|
||||
}
|
||||
|
||||
start := len(items) - 1
|
||||
for start > 0 && items[start].Type == "function_call_output" {
|
||||
start--
|
||||
}
|
||||
start := latestAnthropicCompatResponsesInputTurnStart(items)
|
||||
trimmed := append([]apicompat.ResponsesInputItem(nil), items[start:]...)
|
||||
if len(trimmed) == len(items) {
|
||||
return
|
||||
@@ -50,6 +47,63 @@ func trimAnthropicCompatResponsesInputToLatestTurn(req *apicompat.ResponsesReque
|
||||
}
|
||||
}
|
||||
|
||||
func latestAnthropicCompatResponsesInputTurnStart(items []apicompat.ResponsesInputItem) int {
|
||||
if len(items) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
start := len(items) - 1
|
||||
last := items[start]
|
||||
switch {
|
||||
case last.Type == "function_call_output":
|
||||
for start > 0 && items[start-1].Type == "function_call_output" {
|
||||
start--
|
||||
}
|
||||
case last.Type == "message" && last.Role == "user":
|
||||
for start > 0 && items[start-1].Type == "function_call_output" {
|
||||
start--
|
||||
}
|
||||
default:
|
||||
return start
|
||||
}
|
||||
|
||||
return expandAnthropicCompatResponsesInputToolCallStart(items, start)
|
||||
}
|
||||
|
||||
func expandAnthropicCompatResponsesInputToolCallStart(items []apicompat.ResponsesInputItem, start int) int {
|
||||
if start < 0 || start >= len(items) {
|
||||
return start
|
||||
}
|
||||
|
||||
needed := make(map[string]struct{})
|
||||
for i := start; i < len(items); i++ {
|
||||
if items[i].Type != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID := strings.TrimSpace(items[i].CallID)
|
||||
if callID != "" {
|
||||
needed[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(needed) == 0 {
|
||||
return start
|
||||
}
|
||||
|
||||
expandedStart := start
|
||||
for i := start - 1; i >= 0 && len(needed) > 0; i-- {
|
||||
if items[i].Type != "function_call" {
|
||||
continue
|
||||
}
|
||||
callID := strings.TrimSpace(items[i].CallID)
|
||||
if _, ok := needed[callID]; !ok {
|
||||
continue
|
||||
}
|
||||
delete(needed, callID)
|
||||
expandedStart = i
|
||||
}
|
||||
return expandedStart
|
||||
}
|
||||
|
||||
func isOpenAICompatPreviousResponseNotFound(statusCode int, upstreamMsg string, upstreamBody []byte) bool {
|
||||
if statusCode != http.StatusBadRequest && statusCode != http.StatusNotFound {
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user