Merge pull request #1960 from gaoren002/fix/openai-stream-keepalive-downstream-idle

fix(openai): keep responses stream alive during pre-output failover
Wesley Liddick
2026-04-25 20:24:25 +08:00
committed by GitHub
2 changed files with 51 additions and 5 deletions
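The gist of the change, as a minimal self-contained Go sketch (names such as streamWithKeepalive and shouldForward are illustrative, not the gateway's actual API): keepalive comment frames are gated on how long the downstream connection has gone without a write, so upstream preamble that is buffered during pre-output failover no longer suppresses them.

// Minimal sketch (not the gateway's real code): keepalives keyed off downstream idle time.
package main

import (
	"bufio"
	"fmt"
	"os"
	"time"
)

// streamWithKeepalive forwards events to stdout and emits an SSE comment frame
// (":") whenever nothing has been written downstream for keepaliveInterval.
func streamWithKeepalive(events <-chan string, shouldForward func(string) bool, keepaliveInterval time.Duration) {
	w := bufio.NewWriter(os.Stdout)
	defer w.Flush()

	// Clock of the last *downstream* write. Upstream events that are only
	// buffered (e.g. response.created held back during pre-output failover)
	// must not reset it, or the client would see a silent gap.
	lastDownstreamWriteAt := time.Now()

	ticker := time.NewTicker(keepaliveInterval / 4)
	defer ticker.Stop()

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				return
			}
			if !shouldForward(ev) {
				continue // buffered elsewhere in the real code; not a downstream write
			}
			fmt.Fprintf(w, "data: %s\n\n", ev)
			w.Flush()
			lastDownstreamWriteAt = time.Now()
		case <-ticker.C:
			if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
				continue
			}
			fmt.Fprint(w, ":\n\n") // SSE comment frame keeps proxies and clients alive
			w.Flush()
			lastDownstreamWriteAt = time.Now()
		}
	}
}

func main() {
	events := make(chan string)
	go func() {
		defer close(events)
		events <- `{"type":"response.created"}` // preamble: buffered, not forwarded
		time.Sleep(3 * time.Second)             // upstream stalls before first output
		events <- `{"type":"response.completed"}`
	}()
	shouldForward := func(ev string) bool { return ev != `{"type":"response.created"}` }
	streamWithKeepalive(events, shouldForward, time.Second)
}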


@@ -4008,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
// Record when upstream data was last received, used to control how often keepalives are sent
lastDataAt := time.Now()
// Track downstream writes separately from upstream reads: pre-output failover
// can buffer response.created / response.in_progress, so keepalive must be
// based on downstream idle time.
lastDownstreamWriteAt := time.Now()
// Only send the error event once, to avoid protocol confusion from multiple writes.
// Note: OpenAI `/v1/responses` streaming events must conform to the OpenAI Responses schema
@@ -4041,6 +4043,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
}
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
needModelReplace := originalModel != mappedModel
@@ -4071,6 +4074,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
} else if hadBufferedData {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
return resultWithUsage(), nil
@@ -4114,8 +4118,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if streamFailoverErr != nil {
return
}
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if data, ok := extractOpenAISSEDataLine(line); ok {
@@ -4170,6 +4172,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
}
@@ -4197,6 +4200,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
} else {
clientOutputStarted = true
lastDownstreamWriteAt = time.Now()
}
}
}
@@ -4283,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval {
if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
@@ -4294,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
} else {
lastDownstreamWriteAt = time.Now()
}
}
}


@@ -1117,6 +1117,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
require.Empty(t, rec.Body.String())
}
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 1,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
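// The preamble-only in_progress events below arrive more often than keepalives
// would be sent; pre-output failover buffers them, so only a keepalive keyed
// off downstream idle time can reach the client during this window.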
for i := 0; i < 6; i++ {
time.Sleep(250 * time.Millisecond)
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
}
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
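// The bare SSE comment frame is the keepalive written while only buffered preamble had arrived.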
require.Contains(t, rec.Body.String(), ":\n\n")
require.Contains(t, rec.Body.String(), "response.completed")
}
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{