fix(openai): 收敛自动透传请求头并增强 OAuth 安全兜底

This commit is contained in:
yangjianbo
2026-02-12 20:12:15 +08:00
parent 1ae49b9ead
commit d411cf4472
2 changed files with 56 additions and 29 deletions

View File

@@ -34,12 +34,13 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.98.0"
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
)
// OpenAI allowed headers whitelist (for non-OAuth accounts)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
@@ -49,6 +50,19 @@ var openaiAllowedHeaders = map[string]bool{
"session_id": true,
}
// OpenAI passthrough allowed headers whitelist.
// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。
var openaiPassthroughAllowedHeaders = map[string]bool{
"accept": true,
"accept-language": true,
"content-type": true,
"conversation_id": true,
"openai-beta": true,
"user-agent": true,
"originator": true,
"session_id": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1149,15 +1163,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
return nil, err
}
// 透传客户端请求头(尽可能原样),并做安全剔除
// 透传客户端请求头(安全白名单)
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
if c != nil && c.Request != nil {
for key, values := range c.Request.Header {
lower := strings.ToLower(key)
if isOpenAIPassthroughBlockedRequestHeader(lower) {
continue
}
if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) {
lower := strings.ToLower(strings.TrimSpace(key))
if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) {
continue
}
for _, v := range values {
@@ -1174,16 +1185,41 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
// OAuth 透传到 ChatGPT internal API 时补齐必要头。
if account.Type == AccountTypeOAuth {
promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
req.Host = "chatgpt.com"
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "text/event-stream")
}
if req.Header.Get("OpenAI-Beta") == "" {
req.Header.Set("OpenAI-Beta", "responses=experimental")
}
if req.Header.Get("originator") == "" {
req.Header.Set("originator", "codex_cli_rs")
}
if promptCacheKey != "" {
if req.Header.Get("conversation_id") == "" {
req.Header.Set("conversation_id", promptCacheKey)
}
if req.Header.Get("session_id") == "" {
req.Header.Set("session_id", promptCacheKey)
}
}
}
// 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
req.Header.Set("user-agent", customUA)
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", codexCLIUserAgent)
}
// OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。
if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) {
req.Header.Set("user-agent", codexCLIUserAgent)
}
if req.Header.Get("content-type") == "" {
@@ -1233,23 +1269,14 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Contex
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
switch lowerKey {
// hop-by-hop
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
return true
// 入站鉴权与潜在泄露
case "authorization", "x-api-key", "x-goog-api-key", "cookie", "proxy-authorization":
return true
// 由 Go http client 自动协商压缩;透传模式需避免上游返回压缩体影响 SSE/usage 解析
case "accept-encoding":
return true
// 由 HTTP 库管理
case "host", "content-length":
return true
default:
func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool {
if lowerKey == "" {
return false
}
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
return allowTimeoutHeaders
}
return openaiPassthroughAllowedHeaders[lowerKey]
}
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
@@ -1555,7 +1582,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
// 若开启 ForceCodexCLI则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
req.Header.Set("user-agent", codexCLIUserAgent)
}
// Ensure required headers exist

View File

@@ -189,7 +189,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key"))
require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding"))
require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
// 3) required OAuth headers are present
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
@@ -344,7 +344,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require.True(t, arr[len(arr)-1].Passthrough)
}
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEnabled(t *testing.T) {
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -383,7 +383,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEna
_, err := svc.Forward(context.Background(), c, account, inputBody)
require.NoError(t, err)
require.Equal(t, inputBody, upstream.lastBody)
require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) {
@@ -533,7 +533,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String())
require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization"))
require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}
func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) {
@@ -656,7 +656,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t
require.NoError(t, err)
require.NotNil(t, upstream.lastReq)
require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}
func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) {
@@ -700,5 +700,5 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
require.NoError(t, err)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
}