From 98b65e67f21189f441f92dec88ed40b3ba7e8561 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Thu, 15 Jan 2026 21:42:13 +0800 Subject: [PATCH] fix(gateway): avoid injecting invalid SSE on client cancel --- .../service/openai_gateway_service.go | 6 +++ .../service/openai_gateway_service_test.go | 37 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 04a90fdd..d49be282 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1064,6 +1064,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if ev.err != nil { + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming, returning collected usage") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 42b88b7d..ead6e143 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -33,6 +33,11 @@ type stubConcurrencyCache struct { ConcurrencyCache } +type cancelReadCloser struct{} + +func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } +func (c cancelReadCloser) Close() error { return nil } + func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { return true, nil } @@ -174,6 +179,38 @@ func TestOpenAIStreamingTimeout(t *testing.T) { } } +func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: cancelReadCloser{}, + Header: http.Header{}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { + t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) + } +} + func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{