diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 724c6dc4..63fb233e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -34,12 +34,13 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL + codexCLIUserAgent = "codex_cli_rs/0.98.0" // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 OpenAIParsedRequestBodyKey = "openai_parsed_request_body" ) -// OpenAI allowed headers whitelist (for non-OAuth accounts) +// OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ "accept-language": true, "content-type": true, @@ -49,6 +50,19 @@ var openaiAllowedHeaders = map[string]bool{ "session_id": true, } +// OpenAI passthrough allowed headers whitelist. +// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 +var openaiPassthroughAllowedHeaders = map[string]bool{ + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, +} + // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers type OpenAICodexUsageSnapshot struct { PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` @@ -1149,15 +1163,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( return nil, err } - // 透传客户端请求头(尽可能原样),并做安全剔除。 + // 透传客户端请求头(安全白名单)。 allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() if c != nil && c.Request != nil { for key, values := range c.Request.Header { - lower := strings.ToLower(key) - if isOpenAIPassthroughBlockedRequestHeader(lower) { - continue - } - if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) { + lower := strings.ToLower(strings.TrimSpace(key)) + if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { continue } for _, v := range values { @@ -1174,16 +1185,41 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( // OAuth 透传到 ChatGPT internal API 时补齐必要头。 if account.Type == AccountTypeOAuth { + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) req.Host = "chatgpt.com" if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "text/event-stream") + } if req.Header.Get("OpenAI-Beta") == "" { req.Header.Set("OpenAI-Beta", "responses=experimental") } if req.Header.Get("originator") == "" { req.Header.Set("originator", "codex_cli_rs") } + if promptCacheKey != "" { + if req.Header.Get("conversation_id") == "" { + req.Header.Set("conversation_id", promptCacheKey) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", promptCacheKey) + } + } + } + + // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 + if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { + req.Header.Set("user-agent", codexCLIUserAgent) } if req.Header.Get("content-type") == "" { @@ -1233,23 +1269,14 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Contex return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } -func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool { - switch lowerKey { - // hop-by-hop - case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer": - return true - // 入站鉴权与潜在泄露 - case "authorization", "x-api-key", "x-goog-api-key", "cookie", "proxy-authorization": - return true - // 由 Go http client 自动协商压缩;透传模式需避免上游返回压缩体影响 SSE/usage 解析 - case "accept-encoding": - return true - // 由 HTTP 库管理 - case "host", "content-length": - return true - default: +func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { + if lowerKey == "" { return false } + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + return allowTimeoutHeaders + } + return openaiPassthroughAllowedHeaders[lowerKey] } func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { @@ -1555,7 +1582,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { - req.Header.Set("user-agent", "codex_cli_rs/0.98.0") + req.Header.Set("user-agent", codexCLIUserAgent) } // Ensure required headers exist diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index f0bda0ef..1722952c 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -189,7 +189,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) - require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) // 3) required OAuth headers are present require.Equal(t, "chatgpt.com", upstream.lastReq.Host) @@ -344,7 +344,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, arr[len(arr)-1].Passthrough) } -func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEnabled(t *testing.T) { +func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -383,7 +383,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEna _, err := svc.Forward(context.Background(), c, account, inputBody) require.NoError(t, err) require.Equal(t, inputBody, upstream.lastBody) - require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) } func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { @@ -533,7 +533,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) - require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) } func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { @@ -656,7 +656,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t require.NoError(t, err) require.NotNil(t, upstream.lastReq) require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) - require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) } func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { @@ -700,5 +700,5 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured require.NoError(t, err) require.NotNil(t, upstream.lastReq) require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) - require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) }