Merge pull request #1 from cyhhao/fix/responses-stream-cancel
fix(gateway): avoid invalid SSE error on canceled stream
This commit is contained in:
@@ -1046,8 +1046,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
|
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent {
|
if errorEventSent || clientDisconnected {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
@@ -1064,6 +1065,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||||
|
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("Context canceled during streaming, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
|
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -1085,12 +1097,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward line
|
// 写入客户端(客户端断开后继续 drain 上游)
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
// Record first token time
|
// Record first token time
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||||
@@ -1100,11 +1115,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
s.parseSSEUsage(data, usage)
|
s.parseSSEUsage(data, usage)
|
||||||
} else {
|
} else {
|
||||||
// Forward non-data lines as-is
|
// Forward non-data lines as-is
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -1112,6 +1130,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
@@ -1121,11 +1143,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
case <-keepaliveCh:
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,6 +33,25 @@ type stubConcurrencyCache struct {
|
|||||||
ConcurrencyCache
|
ConcurrencyCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cancelReadCloser struct{}
|
||||||
|
|
||||||
|
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||||
|
func (c cancelReadCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
type failingGinWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -174,6 +193,83 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: cancelReadCloser{},
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
_ = pr.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if result == nil || result.usage == nil {
|
||||||
|
t.Fatalf("expected usage result")
|
||||||
|
}
|
||||||
|
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
||||||
|
t.Fatalf("unexpected usage: %+v", *result.usage)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
|
|||||||
Reference in New Issue
Block a user