fix(openai): preserve replay tool output continuation

This commit is contained in:
anzhen-tech
2026-05-07 14:59:42 +08:00
parent f3577bc69c
commit 16a315574d
2 changed files with 321 additions and 2 deletions

View File

@@ -1552,6 +1552,15 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
return true
}
// openAIWSRawItemsHasFunctionCallOutput reports whether at least one raw item
// has a top-level "type" field equal to "function_call_output". A nil or empty
// slice yields false.
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
	const wantType = "function_call_output"
	for idx := range items {
		if itemType := gjson.GetBytes(items[idx], "type").String(); itemType == wantType {
			return true
		}
	}
	return false
}
func buildOpenAIWSReplayInputSequence(
previousFullInput []json.RawMessage,
previousFullInputExists bool,
@@ -3117,6 +3126,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentTurnReplayInput := []json.RawMessage(nil)
currentTurnReplayInputExists := false
skipBeforeTurn := false
// hasCurrentOrReplayFunctionCallOutput reports whether payload's "input" array
// contains an item whose type is "function_call_output", or, failing that,
// whether the replay input recorded for the current turn contains one.
// It reads currentTurnReplayInput / currentTurnReplayInputExists at call time,
// so the result reflects whatever those variables hold when it is invoked.
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
// Fast path: the payload being sent already carries a function_call_output.
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
return true
}
// Otherwise fall back to the replay input, but only when one has been recorded.
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
}
resetSessionLease := func(markBroken bool) {
if sessionLease == nil {
return
@@ -3139,7 +3154,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
// 携带 function_call_output 的请求不能丢弃 previous_response_id
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
if hasCurrentOrReplayFunctionCallOutput(currentPayload) {
return false
}
if isStrictAffinityTurn(currentPayload) {
@@ -3298,6 +3313,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentTurnReplayInput = nextReplayInput
currentTurnReplayInputExists = nextReplayInputExists
}
replayHasFunctionCallOutput := currentTurnReplayInputExists &&
openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
hasFunctionCallOutput = hasFunctionCallOutput || replayHasFunctionCallOutput
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
shouldKeepPreviousResponseID := false
strictReason := ""
@@ -3416,7 +3434,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
// 携带 function_call_output 的请求不能丢弃 previous_response_id
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
hasFCOutput := hasFunctionCallOutput
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil || !removed {
@@ -3464,6 +3482,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}
}
if hasFCOutput && currentPreviousResponseID != "" {
logOpenAIWSModeInfo(
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
)
}
resetSessionLease(true)
return NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,

View File

@@ -1918,6 +1918,298 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr
require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID
// sends two store=false turns through ProxyResponsesWebSocketFromClient. The
// second turn carries a function_call_output together with previous_response_id.
// The test expects the proxy to return a policy-violation close error, to dial
// the upstream only once, and to write nothing to the second queued connection.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
// Override the package-level preflight ping idle threshold with zero and restore
// it on exit. NOTE(review): presumably zero makes every turn run a preflight
// ping — confirm against the implementation.
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
openAIWSIngressPreflightPingIdle = 0
defer func() {
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
}()
// Gateway configuration: WS mode enabled for OAuth and API-key accounts, up to
// two connections per account, 3-second dial/read/write timeouts.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
// First upstream connection: scripted with a single response.completed event.
// NOTE(review): openAIWSPreflightFailConn is defined elsewhere; judging by its
// name it fails the preflight ping — confirm.
firstConn := &openAIWSPreflightFailConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
// Second upstream connection: scripted with a previous_response_not_found error.
// The assertions at the end require that nothing is ever written to it.
secondConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"Previous response not found."}}`),
},
}
// The dialer is given both connections, in this order, and installed on the pool.
dialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{firstConn, secondConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(dialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 129,
Name: "openai-ingress-preflight-replay-function-output",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// Buffered with capacity 1: every handler path below sends exactly one value.
serverErrCh := make(chan error, 1)
// Test ingress server: accepts the client websocket, reads the first message,
// and passes it to the service under test.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
// Build a gin test context around a clone of the request with a fixed User-Agent.
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
// Read the first client message with a 3-second deadline.
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
// The proxy call's return value is the result the test inspects.
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
// Dial the test server as a websocket client (http:// rewritten to ws://).
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
// writeMessage sends one text frame to the ingress with a 3-second deadline.
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
// readMessage reads one frame from the ingress and requires it to be a text frame.
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
// Turn 1: no previous_response_id; the input holds a function_call and a function_call_output.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_turn_ping_replay_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
// Turn 2: references turn 1 via previous_response_id and carries a function_call_output itself.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_fc_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
// The proxy must end with a policy-violation close error within 5 seconds.
select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.ErrorAs(t, serverErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
case <-time.After(5 * time.Second):
// Failure message: "timed out waiting for the ingress websocket to finish".
t.Fatal("等待 ingress websocket 结束超时")
}
// Failure message: "a function_call_output that needs previous_response_id must
// not be retried on a new connection when the original connection is unavailable".
require.Equal(t, 1, dialer.DialCount(), "需要 previous_response_id 的 function_call_output 在原连接不可用时不应换新连接重试")
// Copy the recorded writes while holding the connection's mutex.
secondConn.mu.Lock()
secondWrites := append([]map[string]any(nil), secondConn.writes...)
secondConn.mu.Unlock()
// Failure message: "the old connection's previous_response_id must not be sent to
// a new upstream, otherwise previous_response_not_found is triggered".
require.Empty(t, secondWrites, "不能把旧连接的 previous_response_id 发送到新上游,否则会触发 previous_response_not_found")
}
// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput
// sends two store=false turns through ProxyResponsesWebSocketFromClient. Only the
// first turn's input contains a function_call_output; the second turn sends plain
// input_text plus previous_response_id. The test expects the proxy to return a
// policy-violation close error, to dial the upstream only once, and to write
// nothing to the second queued connection.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput(t *testing.T) {
gin.SetMode(gin.TestMode)
// Override the package-level preflight ping idle threshold with zero and restore
// it on exit. NOTE(review): presumably zero makes every turn run a preflight
// ping — confirm against the implementation.
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
openAIWSIngressPreflightPingIdle = 0
defer func() {
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
}()
// Gateway configuration: WS mode enabled for OAuth and API-key accounts, up to
// two connections per account, 3-second dial/read/write timeouts.
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
// First upstream connection: scripted with a single response.completed event.
// NOTE(review): openAIWSPreflightFailConn is defined elsewhere; judging by its
// name it fails the preflight ping — confirm.
firstConn := &openAIWSPreflightFailConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_only_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
// Second upstream connection: scripted with a "No tool call found" error. The
// assertions at the end require that nothing is ever written to it.
secondConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found for function call output with call_id call_replay_1.","param":"input"}}`),
},
}
// The dialer is given both connections, in this order, and installed on the pool.
dialer := &openAIWSQueueDialer{
conns: []openAIWSClientConn{firstConn, secondConn},
}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(dialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 130,
Name: "openai-ingress-preflight-replay-only-function-output",
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
// Buffered with capacity 1: every handler path below sends exactly one value.
serverErrCh := make(chan error, 1)
// Test ingress server: accepts the client websocket, reads the first message,
// and passes it to the service under test.
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
CompressionMode: coderws.CompressionContextTakeover,
})
if err != nil {
serverErrCh <- err
return
}
defer func() {
_ = conn.CloseNow()
}()
// Build a gin test context around a clone of the request with a fixed User-Agent.
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := r.Clone(r.Context())
req.Header = req.Header.Clone()
req.Header.Set("User-Agent", "unit-test-agent/1.0")
ginCtx.Request = req
// Read the first client message with a 3-second deadline.
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
msgType, firstMessage, readErr := conn.Read(readCtx)
cancel()
if readErr != nil {
serverErrCh <- readErr
return
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
serverErrCh <- errors.New("unsupported websocket client message type")
return
}
// The proxy call's return value is the result the test inspects.
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
}))
defer wsServer.Close()
// Dial the test server as a websocket client (http:// rewritten to ws://).
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
cancelDial()
require.NoError(t, err)
defer func() {
_ = clientConn.CloseNow()
}()
// writeMessage sends one text frame to the ingress with a 3-second deadline.
writeMessage := func(payload string) {
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
}
// readMessage reads one frame from the ingress and requires it to be a text frame.
readMessage := func() []byte {
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
msgType, message, readErr := clientConn.Read(readCtx)
require.NoError(t, readErr)
require.Equal(t, coderws.MessageText, msgType)
return message
}
// Turn 1: no previous_response_id; the input holds a function_call and a function_call_output.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
firstTurn := readMessage()
require.Equal(t, "resp_turn_ping_replay_only_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
// Turn 2: references turn 1 via previous_response_id; its own input is only input_text.
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_only_fc_1","input":[{"type":"input_text","text":"after tool output"}]}`)
// The proxy must end with a policy-violation close error within 5 seconds.
select {
case serverErr := <-serverErrCh:
require.Error(t, serverErr)
var closeErr *OpenAIWSClientCloseError
require.ErrorAs(t, serverErr, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
case <-time.After(5 * time.Second):
// Failure message: "timed out waiting for the ingress websocket to finish".
t.Fatal("等待 ingress websocket 结束超时")
}
// Failure message: "when the replay input carries a function_call_output, the
// request must not be retried on a new connection".
require.Equal(t, 1, dialer.DialCount(), "replay input 带 function_call_output 时不应换新连接重试")
// Copy the recorded writes while holding the connection's mutex.
secondConn.mu.Lock()
secondWrites := append([]map[string]any(nil), secondConn.writes...)
secondConn.mu.Unlock()
// Failure message: "a replay request that would trigger 'No tool call found' must
// not be sent to a new upstream".
require.Empty(t, secondWrites, "不能把会触发 No tool call found 的重放请求发到新上游")
}
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
gin.SetMode(gin.TestMode)