From 98b65e67f21189f441f92dec88ed40b3ba7e8561 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 21:42:13 +0800
Subject: [PATCH 1/2] fix(gateway): avoid injecting invalid SSE on client cancel

---
 .../service/openai_gateway_service.go      |  6 +++
 .../service/openai_gateway_service_test.go | 37 +++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 04a90fdd..d49be282 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -1064,6 +1064,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
 			}
 			if ev.err != nil {
+				// When the client disconnects or cancels the request, the upstream read usually returns context canceled.
+				// SSE events on /v1/responses must conform to the OpenAI protocol; do not inject a custom error event here, or downstream SDKs may fail to parse the stream.
+				if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
+					log.Printf("Context canceled during streaming, returning collected usage")
+					return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+				}
 				if errors.Is(ev.err, bufio.ErrTooLong) {
 					log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
 					sendErrorEvent("response_too_large")
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 42b88b7d..ead6e143 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -33,6 +33,11 @@ type stubConcurrencyCache struct {
 	ConcurrencyCache
 }
 
+type cancelReadCloser struct{}
+
+func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
+func (c cancelReadCloser) Close() error { return nil }
+
 func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
 	return true, nil
 }
@@ -174,6 +179,38 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
 	}
 }
 
+func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	cfg := &config.Config{
+		Gateway: config.GatewayConfig{
+			StreamDataIntervalTimeout: 0,
+			StreamKeepaliveInterval:   0,
+			MaxLineSize:               defaultMaxLineSize,
+		},
+	}
+	svc := &OpenAIGatewayService{cfg: cfg}
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
+
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       cancelReadCloser{},
+		Header:     http.Header{},
+	}
+
+	_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
+	if err != nil {
+		t.Fatalf("expected nil error, got %v", err)
+	}
+	if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
+		t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
+	}
+}
+
 func TestOpenAIStreamingTooLong(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	cfg := &config.Config{

From c11f14f3a030c30846183704ccd6193785899bd4 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 21:51:14 +0800
Subject: [PATCH 2/2] fix(gateway): drain upstream after client disconnect

---
 .../service/openai_gateway_service.go      | 43 ++++++++++----
 .../service/openai_gateway_service_test.go | 59 +++++++++++++++++++
 2 files changed, 91 insertions(+), 11 deletions(-)

diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index d49be282..fb811e9e 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -1046,8 +1046,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 
 	// Send the error event at most once; repeated writes would corrupt the SSE protocol (best-effort notification to the client when a write fails)
 	errorEventSent := false
+	clientDisconnected := false // keep draining upstream after the client disconnects so usage can still be collected
 	sendErrorEvent := func(reason string) {
-		if errorEventSent {
+		if errorEventSent || clientDisconnected {
 			return
 		}
 		errorEventSent = true
@@ -1070,6 +1071,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				log.Printf("Context canceled during streaming, returning collected usage")
 				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
 			}
+			// Once the client has disconnected, an upstream error only affects the client experience, not billing; return the usage collected so far.
+			if clientDisconnected {
+				log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
+				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+			}
 			if errors.Is(ev.err, bufio.ErrTooLong) {
 				log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
 				sendErrorEvent("response_too_large")
@@ -1091,12 +1097,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				line = s.replaceModelInSSELine(line, mappedModel, originalModel)
 			}
 
-			// Forward line
-			if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
-				sendErrorEvent("write_failed")
-				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+			// Write to the client (keep draining upstream after the client disconnects)
+			if !clientDisconnected {
+				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+					clientDisconnected = true
+					log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+				} else {
+					flusher.Flush()
+				}
 			}
-			flusher.Flush()
 
 			// Record first token time
 			if firstTokenMs == nil && data != "" && data != "[DONE]" {
@@ -1106,11 +1115,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 				s.parseSSEUsage(data, usage)
 			} else {
 				// Forward non-data lines as-is
-				if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
-					sendErrorEvent("write_failed")
-					return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+				if !clientDisconnected {
+					if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+						clientDisconnected = true
+						log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+					} else {
+						flusher.Flush()
+					}
 				}
-				flusher.Flush()
 			}
 
 		case <-intervalCh:
@@ -1118,6 +1130,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
 			lastRead := reader.LastReadTime()
 			if time.Since(lastRead) < streamInterval {
 				continue
 			}
+			if clientDisconnected {
+				log.Printf("Upstream timeout after client disconnect, returning collected usage")
+				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+			}
 			log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
 			// Handle the stream timeout; this may mark the account temporarily unschedulable or as errored
 			if s.rateLimitService != nil {
 				s.rateLimitService.HandleStreamTimeout(ctx, account, resp.Header, originalModel)
 			}
 			sendErrorEvent("stream_data_interval_timeout")
 			return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
 
 		case <-keepaliveCh:
+			if clientDisconnected {
+				continue
+			}
 			if time.Since(lastDataAt) < keepaliveInterval {
 				continue
 			}
 			if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
-				return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+				clientDisconnected = true
+				log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+				continue
 			}
 			flusher.Flush()
 		}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index ead6e143..3ec37544 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -38,6 +38,20 @@ type cancelReadCloser struct{}
 
 func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
 func (c cancelReadCloser) Close() error { return nil }
 
+type failingGinWriter struct {
+	gin.ResponseWriter
+	failAfter int
+	writes    int
+}
+
+func (w *failingGinWriter) Write(p []byte) (int, error) {
+	if w.writes >= w.failAfter {
+		return 0, errors.New("write failed")
+	}
+	w.writes++
+	return w.ResponseWriter.Write(p)
+}
+
 func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
 	return true, nil
 }
@@ -211,6 +225,51 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
 	}
 }
 
+func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	cfg := &config.Config{
+		Gateway: config.GatewayConfig{
+			StreamDataIntervalTimeout: 0,
+			StreamKeepaliveInterval:   0,
+			MaxLineSize:               defaultMaxLineSize,
+		},
+	}
+	svc := &OpenAIGatewayService{cfg: cfg}
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+	c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
+
+	pr, pw := io.Pipe()
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Body:       pr,
+		Header:     http.Header{},
+	}
+
+	go func() {
+		defer func() { _ = pw.Close() }()
+		_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+		_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
+	}()
+
+	result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
+	_ = pr.Close()
+	if err != nil {
+		t.Fatalf("expected nil error, got %v", err)
+	}
+	if result == nil || result.usage == nil {
+		t.Fatalf("expected usage result")
+	}
+	if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
+		t.Fatalf("unexpected usage: %+v", *result.usage)
+	}
+	if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
+		t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
+	}
+}
+
 func TestOpenAIStreamingTooLong(t *testing.T) {
 	gin.SetMode(gin.TestMode)
 	cfg := &config.Config{
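
Review note, outside the patches above: both commits hinge on one pattern — after the first failed write to the client, flip a flag, stop writing (including keepalives and error events), but keep reading the upstream SSE body so usage is still counted for billing. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions; drainSSE, usageTally, and failAfter are illustrative names, not the gateway's actual API, and the usage parsing is reduced to an event counter.

package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

type usageTally struct{ events int }

// drainSSE forwards SSE lines to w until a write fails, then keeps reading r
// to the end so the tally reflects everything the upstream actually sent.
func drainSSE(r io.Reader, w io.Writer) (usageTally, error) {
	var tally usageTally
	clientDisconnected := false
	sc := bufio.NewScanner(r)
	for sc.Scan() {
		line := sc.Text()
		if strings.HasPrefix(line, "data:") {
			tally.events++ // stand-in for parsing usage out of the event payload
		}
		if clientDisconnected {
			continue // client is gone: drain upstream, write nothing
		}
		if _, err := fmt.Fprintln(w, line); err != nil {
			clientDisconnected = true // remember, and do not inject an SSE error event
		}
	}
	return tally, sc.Err()
}

// failAfter is a writer that fails every write after the first n, standing in
// for a client that disconnects mid-stream.
type failAfter struct {
	n      int
	writes int
}

func (f *failAfter) Write(p []byte) (int, error) {
	if f.writes >= f.n {
		return 0, fmt.Errorf("client disconnected")
	}
	f.writes++
	return len(p), nil
}

func main() {
	upstream := strings.NewReader("data: {\"type\":\"a\"}\n\ndata: {\"type\":\"b\"}\n\n")
	tally, err := drainSSE(upstream, &failAfter{n: 1})
	fmt.Printf("events=%d err=%v\n", tally.events, err) // events=2 even though only one write succeeded
}

The single flag is the load-bearing design choice: it gates the data path, the keepalive ticker, and sendErrorEvent alike, so once the client is gone the handler emits nothing further yet still returns complete usage when the upstream closes or times out.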