From dac6e520911c3d353132a75f802f1de8e5caeb34 Mon Sep 17 00:00:00 2001
From: gaoren002
Date: Sat, 25 Apr 2026 12:11:27 +0000
Subject: [PATCH] fix(openai): keep responses stream alive during pre-output failover

---
 .../service/openai_gateway_service.go      | 16 +++---
 .../service/openai_gateway_service_test.go | 40 +++++++++++++++++++
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 75a92f6e..5034a407 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4008,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 	if keepaliveTicker != nil {
 		keepaliveCh = keepaliveTicker.C
 	}
-	// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
-	lastDataAt := time.Now()
+	// Track downstream writes separately from upstream reads: pre-output failover
+	// can buffer response.created / response.in_progress, so keepalive must be
+	// based on downstream idle time.
+	lastDownstreamWriteAt := time.Now()
 
 	// 仅发送一次错误事件,避免多次写入导致协议混乱。
 	// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
@@ -4041,6 +4043,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 			return
 		}
 		clientOutputStarted = true
+		lastDownstreamWriteAt = time.Now()
 	}
 
 	needModelReplace := originalModel != mappedModel
@@ -4071,6 +4074,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
 			} else if hadBufferedData {
 				clientOutputStarted = true
+				lastDownstreamWriteAt = time.Now()
 			}
 		}
 		return resultWithUsage(), nil
@@ -4114,8 +4118,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 		if streamFailoverErr != nil {
 			return
 		}
-		lastDataAt = time.Now()
-
 		// Extract data from SSE line (supports both "data: " and "data:" formats)
 		if data, ok := extractOpenAISSEDataLine(line); ok {
@@ -4170,6 +4172,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 					logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
 				} else {
 					clientOutputStarted = true
+					lastDownstreamWriteAt = time.Now()
 				}
 			}
 		}
@@ -4197,6 +4200,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 					logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
 				} else {
 					clientOutputStarted = true
+					lastDownstreamWriteAt = time.Now()
 				}
 			}
 		}
@@ -4283,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 			if clientDisconnected {
 				continue
 			}
-			if time.Since(lastDataAt) < keepaliveInterval {
+			if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
 				continue
 			}
 			if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
@@ -4294,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 			if err := flushBuffered(); err != nil {
 				clientDisconnected = true
 				logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
+			} else {
+				lastDownstreamWriteAt = time.Now()
 			}
 		}
 	}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 0cf2392d..d54b00ab 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -1117,6 +1117,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
 	require.Empty(t, rec.Body.String())
 }
 
+func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	cfg := &config.Config{
+		Gateway: config.GatewayConfig{
+			StreamDataIntervalTimeout: 0,
+			StreamKeepaliveInterval:   1,
+			MaxLineSize:               defaultMaxLineSize,
+		},
+	}
+	svc := &OpenAIGatewayService{cfg: cfg}
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+	pr, pw := io.Pipe()
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       pr,
+		Header:     http.Header{},
+	}
+
+	go func() {
+		defer func() { _ = pw.Close() }()
+		_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+		for i := 0; i < 6; i++ {
+			time.Sleep(250 * time.Millisecond)
+			_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+		}
+		_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
+	}()
+
+	result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+	_ = pr.Close()
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.Contains(t, rec.Body.String(), ":\n\n")
+	require.Contains(t, rec.Body.String(), "response.completed")
+}
+
 func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	cfg := &config.Config{
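
Illustration only, not part of the patch: the core idea above is a keepalive gated on downstream idle time rather than on upstream read time. The minimal sketch below restates that pattern in isolation; keepaliveLoop, the wrote channel, and the done channel are invented names for this sketch and do not exist in the gateway code.

package keepalive

import (
	"bufio"
	"time"
)

// keepaliveLoop writes an SSE comment (":\n\n") whenever nothing has been
// written downstream for at least interval. The data path reports each
// successful downstream flush on wrote; upstream events that are only
// buffered (e.g. response.created / response.in_progress held back while
// pre-output failover is still possible) never reset the timer, so the
// client keeps receiving keepalives during that window.
func keepaliveLoop(w *bufio.Writer, interval time.Duration, wrote <-chan time.Time, done <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	lastDownstreamWriteAt := time.Now()
	for {
		select {
		case <-done:
			return
		case t := <-wrote:
			// A real event reached the client; restart the idle window.
			lastDownstreamWriteAt = t
		case <-ticker.C:
			if time.Since(lastDownstreamWriteAt) < interval {
				continue // something went out recently; skip this tick
			}
			if _, err := w.WriteString(":\n\n"); err != nil {
				return // downstream is gone; stop sending keepalives
			}
			if err := w.Flush(); err != nil {
				return
			}
			lastDownstreamWriteAt = time.Now()
		}
	}
}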