diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c394ed1f..76746d2b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1183,7 +1183,10 @@ func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool { case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer": return true // 入站鉴权与潜在泄露 - case "authorization", "x-api-key", "x-goog-api-key", "cookie": + case "authorization", "x-api-key", "x-goog-api-key", "cookie", "proxy-authorization": + return true + // 由 Go http client 自动协商压缩;透传模式需避免上游返回压缩体影响 SSE/usage 解析 + case "accept-encoding": return true // 由 HTTP 库管理 case "host", "content-length": @@ -1224,6 +1227,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( usage := &OpenAIUsage{} var firstTokenMs *int + clientDisconnected := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -1245,13 +1249,20 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( s.parseSSEUsage(data, usage) } - if _, err := fmt.Fprintln(w, line); err != nil { - // 客户端断开时停止写入 - break + if !clientDisconnected { + if _, err := fmt.Fprintln(w, line); err != nil { + clientDisconnected = true + log.Printf("[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else { + flusher.Flush() + } } - flusher.Flush() } if err := scanner.Err(); err != nil { + if clientDisconnected { + log.Printf("[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 56eb8bd8..96805123 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -54,6 +54,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang c.Request.Header.Set("Cookie", "secret=1") c.Request.Header.Set("X-Api-Key", "sk-inbound") c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound") + c.Request.Header.Set("Accept-Encoding", "gzip") + c.Request.Header.Set("Proxy-Authorization", "Basic abc") c.Request.Header.Set("X-Test", "keep") originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`) @@ -108,6 +110,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang require.Empty(t, upstream.lastReq.Header.Get("Cookie")) require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key")) require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) + require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) // 3) required OAuth headers are present @@ -362,3 +366,58 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test require.NotNil(t, result.FirstTokenMs) require.GreaterOrEqual(t, *result.FirstTokenMs, 0) } + +func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + // 首次写入成功,后续写入失败,模拟客户端中途断开。 + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 1} + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_oauth_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +}