fix(openai): preserve replay tool output continuation
This commit is contained in:
@@ -1552,6 +1552,15 @@ func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage
|
||||
return true
|
||||
}
|
||||
|
||||
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
|
||||
for _, item := range items {
|
||||
if gjson.GetBytes(item, "type").String() == "function_call_output" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildOpenAIWSReplayInputSequence(
|
||||
previousFullInput []json.RawMessage,
|
||||
previousFullInputExists bool,
|
||||
@@ -3117,6 +3126,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
currentTurnReplayInput := []json.RawMessage(nil)
|
||||
currentTurnReplayInputExists := false
|
||||
skipBeforeTurn := false
|
||||
hasCurrentOrReplayFunctionCallOutput := func(payload []byte) bool {
|
||||
if gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
}
|
||||
return currentTurnReplayInputExists && openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
|
||||
}
|
||||
resetSessionLease := func(markBroken bool) {
|
||||
if sessionLease == nil {
|
||||
return
|
||||
@@ -3139,7 +3154,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||
if gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() {
|
||||
if hasCurrentOrReplayFunctionCallOutput(currentPayload) {
|
||||
return false
|
||||
}
|
||||
if isStrictAffinityTurn(currentPayload) {
|
||||
@@ -3298,6 +3313,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
currentTurnReplayInput = nextReplayInput
|
||||
currentTurnReplayInputExists = nextReplayInputExists
|
||||
}
|
||||
replayHasFunctionCallOutput := currentTurnReplayInputExists &&
|
||||
openAIWSRawItemsHasFunctionCallOutput(currentTurnReplayInput)
|
||||
hasFunctionCallOutput = hasFunctionCallOutput || replayHasFunctionCallOutput
|
||||
if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
|
||||
shouldKeepPreviousResponseID := false
|
||||
strictReason := ""
|
||||
@@ -3416,7 +3434,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
// 携带 function_call_output 的请求不能丢弃 previous_response_id:
|
||||
// 上游 API 需要 response chain 来匹配 tool_result 与之前的 tool_use,
|
||||
// 丢弃后会导致 "No tool call found for function call output" 400 错误。
|
||||
hasFCOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
|
||||
hasFCOutput := hasFunctionCallOutput
|
||||
if !turnPrevRecoveryTried && currentPreviousResponseID != "" && !hasFCOutput {
|
||||
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
|
||||
if dropErr != nil || !removed {
|
||||
@@ -3464,6 +3482,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasFCOutput && currentPreviousResponseID != "" {
|
||||
logOpenAIWSModeInfo(
|
||||
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=function_call_output action=fail_close previous_response_id=%s",
|
||||
account.ID,
|
||||
turn,
|
||||
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
|
||||
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
|
||||
)
|
||||
}
|
||||
resetSessionLease(true)
|
||||
return NewOpenAIWSClientCloseError(
|
||||
coderws.StatusPolicyViolation,
|
||||
|
||||
@@ -1918,6 +1918,298 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStr
|
||||
require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenFunctionCallOutputNeedsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||
openAIWSIngressPreflightPingIdle = 0
|
||||
defer func() {
|
||||
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
|
||||
}()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
firstConn := &openAIWSPreflightFailConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
secondConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"Previous response not found."}}`),
|
||||
},
|
||||
}
|
||||
dialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{firstConn, secondConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(dialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 129,
|
||||
Name: "openai-ingress-preflight-replay-function-output",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_turn_ping_replay_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_fc_1","input":[{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
var closeErr *OpenAIWSClientCloseError
|
||||
require.ErrorAs(t, serverErr, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
|
||||
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, dialer.DialCount(), "需要 previous_response_id 的 function_call_output 在原连接不可用时不应换新连接重试")
|
||||
secondConn.mu.Lock()
|
||||
secondWrites := append([]map[string]any(nil), secondConn.writes...)
|
||||
secondConn.mu.Unlock()
|
||||
require.Empty(t, secondWrites, "不能把旧连接的 previous_response_id 发送到新上游,否则会触发 previous_response_not_found")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPreflightPingFailClosesWhenReplayHasFunctionCallOutput(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
|
||||
openAIWSIngressPreflightPingIdle = 0
|
||||
defer func() {
|
||||
openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
|
||||
}()
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
|
||||
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
|
||||
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
|
||||
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
|
||||
|
||||
firstConn := &openAIWSPreflightFailConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_replay_only_fc_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
secondConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found for function call output with call_id call_replay_1.","param":"input"}}`),
|
||||
},
|
||||
}
|
||||
dialer := &openAIWSQueueDialer{
|
||||
conns: []openAIWSClientConn{firstConn, secondConn},
|
||||
}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(dialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 130,
|
||||
Name: "openai-ingress-preflight-replay-only-function-output",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.CloseNow()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := r.Clone(r.Context())
|
||||
req.Header = req.Header.Clone()
|
||||
req.Header.Set("User-Agent", "unit-test-agent/1.0")
|
||||
ginCtx.Request = req
|
||||
|
||||
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
msgType, firstMessage, readErr := conn.Read(readCtx)
|
||||
cancel()
|
||||
if readErr != nil {
|
||||
serverErrCh <- readErr
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
serverErrCh <- errors.New("unsupported websocket client message type")
|
||||
return
|
||||
}
|
||||
|
||||
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
|
||||
}))
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeMessage := func(payload string) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
|
||||
}
|
||||
readMessage := func() []byte {
|
||||
readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
msgType, message, readErr := clientConn.Read(readCtx)
|
||||
require.NoError(t, readErr)
|
||||
require.Equal(t, coderws.MessageText, msgType)
|
||||
return message
|
||||
}
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_other","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_replay_1","output":"ok"}]}`)
|
||||
firstTurn := readMessage()
|
||||
require.Equal(t, "resp_turn_ping_replay_only_fc_1", gjson.GetBytes(firstTurn, "response.id").String())
|
||||
|
||||
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_replay_only_fc_1","input":[{"type":"input_text","text":"after tool output"}]}`)
|
||||
|
||||
select {
|
||||
case serverErr := <-serverErrCh:
|
||||
require.Error(t, serverErr)
|
||||
var closeErr *OpenAIWSClientCloseError
|
||||
require.ErrorAs(t, serverErr, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
|
||||
require.Contains(t, closeErr.Reason(), "upstream continuation connection is unavailable")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("等待 ingress websocket 结束超时")
|
||||
}
|
||||
|
||||
require.Equal(t, 1, dialer.DialCount(), "replay input 带 function_call_output 时不应换新连接重试")
|
||||
secondConn.mu.Lock()
|
||||
secondWrites := append([]map[string]any(nil), secondConn.writes...)
|
||||
secondConn.mu.Unlock()
|
||||
require.Empty(t, secondWrites, "不能把会触发 No tool call found 的重放请求发到新上游")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user