fix(openai): keep responses stream alive during pre-output failover
This commit is contained in:
@@ -4008,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||
lastDataAt := time.Now()
|
||||
// Track downstream writes separately from upstream reads: pre-output failover
|
||||
// can buffer response.created / response.in_progress, so keepalive must be
|
||||
// based on downstream idle time.
|
||||
lastDownstreamWriteAt := time.Now()
|
||||
|
||||
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||
@@ -4041,6 +4043,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
return
|
||||
}
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
@@ -4071,6 +4074,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
} else if hadBufferedData {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
@@ -4114,8 +4118,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if streamFailoverErr != nil {
|
||||
return
|
||||
}
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
|
||||
@@ -4170,6 +4172,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4197,6 +4200,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
clientOutputStarted = true
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4283,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if clientDisconnected {
|
||||
continue
|
||||
}
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
|
||||
@@ -4294,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := flushBuffered(); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
|
||||
} else {
|
||||
lastDownstreamWriteAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1117,6 +1117,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
|
||||
require.Empty(t, rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 1,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||
}
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
require.Contains(t, rec.Body.String(), "response.completed")
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
Reference in New Issue
Block a user