fix(openai): keep responses stream alive during pre-output failover
This commit is contained in:
@@ -4008,8 +4008,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if keepaliveTicker != nil {
|
if keepaliveTicker != nil {
|
||||||
keepaliveCh = keepaliveTicker.C
|
keepaliveCh = keepaliveTicker.C
|
||||||
}
|
}
|
||||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
// Track downstream writes separately from upstream reads: pre-output failover
|
||||||
lastDataAt := time.Now()
|
// can buffer response.created / response.in_progress, so keepalive must be
|
||||||
|
// based on downstream idle time.
|
||||||
|
lastDownstreamWriteAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||||
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||||
@@ -4041,6 +4043,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
clientOutputStarted = true
|
clientOutputStarted = true
|
||||||
|
lastDownstreamWriteAt = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
@@ -4071,6 +4074,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||||
} else if hadBufferedData {
|
} else if hadBufferedData {
|
||||||
clientOutputStarted = true
|
clientOutputStarted = true
|
||||||
|
lastDownstreamWriteAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return resultWithUsage(), nil
|
return resultWithUsage(), nil
|
||||||
@@ -4114,8 +4118,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if streamFailoverErr != nil {
|
if streamFailoverErr != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
lastDataAt = time.Now()
|
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||||
|
|
||||||
@@ -4170,6 +4172,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||||
} else {
|
} else {
|
||||||
clientOutputStarted = true
|
clientOutputStarted = true
|
||||||
|
lastDownstreamWriteAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4197,6 +4200,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
|
||||||
} else {
|
} else {
|
||||||
clientOutputStarted = true
|
clientOutputStarted = true
|
||||||
|
lastDownstreamWriteAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4283,7 +4287,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if clientDisconnected {
|
if clientDisconnected {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
|
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
|
||||||
@@ -4294,6 +4298,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if err := flushBuffered(); err != nil {
|
if err := flushBuffered(); err != nil {
|
||||||
clientDisconnected = true
|
clientDisconnected = true
|
||||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
|
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
lastDownstreamWriteAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1117,6 +1117,46 @@ func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T)
|
|||||||
require.Empty(t, rec.Body.String())
|
require.Empty(t, rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 1,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
|
||||||
|
}
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
|
||||||
|
_ = pr.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||||
|
require.Contains(t, rec.Body.String(), "response.completed")
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
|
|||||||
Reference in New Issue
Block a user