From ed2eba90282d5bbd9aa788845b8a1337d92baed7 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 12 Feb 2026 14:16:18 +0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E9=BB=98=E8=AE=A4=E8=BF=87?= =?UTF-8?q?=E6=BB=A4OpenAI=E9=80=8F=E4=BC=A0=E8=B6=85=E6=97=B6=E5=A4=B4?= =?UTF-8?q?=E5=B9=B6=E8=A1=A5=E5=85=85=E6=96=AD=E6=B5=81=E5=91=8A=E8=AD=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/config/config.go | 4 + .../service/openai_gateway_service.go | 81 +++++++- .../service/openai_oauth_passthrough_test.go | 191 ++++++++++++++++++ deploy/config.example.yaml | 3 + 4 files changed, 278 insertions(+), 1 deletion(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index f095f317..3f3deefc 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -280,6 +280,9 @@ type GatewayConfig struct { // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 + // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 + OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -995,6 +998,7 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 1de60665..4cd0f171 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1020,6 +1020,23 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( reqModel, reqStream, ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + log.Printf( + "[WARN] [OpenAI passthrough] 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流: account=%d headers=%s", + account.ID, + strings.Join(timeoutHeaders, ", "), + ) + } else { + log.Printf( + "[WARN] [OpenAI passthrough] 检测到超时相关请求头,将按配置过滤以降低断流风险: account=%d headers=%s", + account.ID, + strings.Join(timeoutHeaders, ", "), + ) + } + } + } // Get access token token, _, err := s.GetAccessToken(ctx, account) @@ -1135,12 +1152,16 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( } // 透传客户端请求头(尽可能原样),并做安全剔除。 + allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() if c != nil && c.Request != nil { for key, values := range c.Request.Header { lower := strings.ToLower(key) if isOpenAIPassthroughBlockedRequestHeader(lower) { continue } + if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) { + continue + } for _, v := range values { req.Header.Add(key, v) } @@ -1233,6 +1254,38 @@ func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool { } } +func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { + switch lowerKey { + case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders +} + +func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { + if h == nil { + return nil + } + var matched []string + for key, values := range h { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + entry := lowerKey + if len(values) > 0 { + entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) + } + matched = append(matched, entry) + } + } + sort.Strings(matched) + return matched +} + type openaiStreamingResultPassthrough struct { usage *OpenAIUsage firstTokenMs *int @@ -1265,6 +1318,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( usage := &OpenAIUsage{} var firstTokenMs *int clientDisconnected := false + sawDone := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -1278,7 +1333,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( for scanner.Scan() { line := scanner.Text() if data, ok := extractOpenAISSEDataLine(line); ok { - if firstTokenMs == nil && strings.TrimSpace(data) != "" { + trimmedData := strings.TrimSpace(data) + if trimmedData == "[DONE]" { + sawDone = true + } + if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } @@ -1300,14 +1359,34 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf( + "[WARN] [OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, + upstreamRequestID, + err, + ctx.Err(), + ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } if errors.Is(err, bufio.ErrTooLong) { log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err } + log.Printf( + "[WARN] [OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", + account.ID, + upstreamRequestID, + err, + ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } + if !clientDisconnected && !sawDone && ctx.Err() == nil { + log.Printf( + "[WARN] [OpenAI passthrough] 上游流在未收到 [DONE] 时结束,疑似断流: account=%d request_id=%s", + account.ID, + upstreamRequestID, + ) + } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index f6932469..970c4b84 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -4,9 +4,12 @@ import ( "bytes" "context" "io" + "log" "net/http" "net/http/httptest" + "os" "strings" + "sync" "testing" "time" @@ -43,6 +46,27 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc return u.Do(req, proxyURL, accountID, accountConcurrency) } +var stdLogCaptureMu sync.Mutex + +func captureStdLog(t *testing.T) (*bytes.Buffer, func()) { + t.Helper() + stdLogCaptureMu.Lock() + buf := &bytes.Buffer{} + prevWriter := log.Writer() + prevFlags := log.Flags() + log.SetFlags(0) + log.SetOutput(buf) + return buf, func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + // 防御性恢复,避免其他测试改动了底层 writer。 + if prevWriter == nil { + log.SetOutput(os.Stderr) + } + stdLogCaptureMu.Unlock() + } +} + func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchanged(t *testing.T) { gin.SetMode(gin.TestMode) @@ -459,3 +483,170 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) } + +func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { + gin.SetMode(gin.TestMode) + logBuf, restore := captureStdLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "10000") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-timeout"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 321, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.Contains(t, logBuf.String(), "检测到超时相关请求头,将按配置过滤以降低断流风险") + require.Contains(t, logBuf.String(), "x-stainless-timeout=10000") +} + +func TestOpenAIGatewayService_OAuthPassthrough_WarnWhenStreamEndsWithoutDone(t *testing.T) { + gin.SetMode(gin.TestMode) + logBuf, restore := captureStdLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + // 注意:刻意不发送 [DONE],模拟上游中途断流。 + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-truncate"}}, + Body: io.NopCloser(strings.NewReader("data: {\"type\":\"response.output_text.delta\",\"delta\":\"h\"}\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 654, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.Contains(t, logBuf.String(), "上游流在未收到 [DONE] 时结束,疑似断流") + require.Contains(t, logBuf.String(), "rid-truncate") +} + +func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 111, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForceCodexCLI: false, + OpenAIPassthroughAllowTimeoutHeaders: true, + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 222, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index b60082b9..9ab3bfd0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -187,6 +187,9 @@ gateway: # # 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。 force_codex_cli: false + # OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout) + # 默认 false:过滤超时头,降低上游提前断流风险。 + openai_passthrough_allow_timeout_headers: false # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts