fix(openai): 收敛自动透传请求头并增强 OAuth 安全兜底

This commit is contained in:
yangjianbo
2026-02-12 20:12:15 +08:00
parent 1ae49b9ead
commit d411cf4472
2 changed files with 56 additions and 29 deletions

View File

@@ -34,12 +34,13 @@ const (
// OpenAI Platform API for API Key accounts (fallback) // OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.98.0"
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body" OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
) )
// OpenAI allowed headers whitelist (for non-OAuth accounts) // OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{ var openaiAllowedHeaders = map[string]bool{
"accept-language": true, "accept-language": true,
"content-type": true, "content-type": true,
@@ -49,6 +50,19 @@ var openaiAllowedHeaders = map[string]bool{
"session_id": true, "session_id": true,
} }
// OpenAI passthrough allowed headers whitelist.
// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。
var openaiPassthroughAllowedHeaders = map[string]bool{
"accept": true,
"accept-language": true,
"content-type": true,
"conversation_id": true,
"openai-beta": true,
"user-agent": true,
"originator": true,
"session_id": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct { type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1149,15 +1163,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
return nil, err return nil, err
} }
// 透传客户端请求头(尽可能原样),并做安全剔除 // 透传客户端请求头(安全白名单)
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
if c != nil && c.Request != nil { if c != nil && c.Request != nil {
for key, values := range c.Request.Header { for key, values := range c.Request.Header {
lower := strings.ToLower(key) lower := strings.ToLower(strings.TrimSpace(key))
if isOpenAIPassthroughBlockedRequestHeader(lower) { if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) {
continue
}
if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) {
continue continue
} }
for _, v := range values { for _, v := range values {
@@ -1174,16 +1185,41 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
// OAuth 透传到 ChatGPT internal API 时补齐必要头。 // OAuth 透传到 ChatGPT internal API 时补齐必要头。
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
req.Host = "chatgpt.com" req.Host = "chatgpt.com"
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID) req.Header.Set("chatgpt-account-id", chatgptAccountID)
} }
if req.Header.Get("accept") == "" {
req.Header.Set("accept", "text/event-stream")
}
if req.Header.Get("OpenAI-Beta") == "" { if req.Header.Get("OpenAI-Beta") == "" {
req.Header.Set("OpenAI-Beta", "responses=experimental") req.Header.Set("OpenAI-Beta", "responses=experimental")
} }
if req.Header.Get("originator") == "" { if req.Header.Get("originator") == "" {
req.Header.Set("originator", "codex_cli_rs") req.Header.Set("originator", "codex_cli_rs")
} }
if promptCacheKey != "" {
if req.Header.Get("conversation_id") == "" {
req.Header.Set("conversation_id", promptCacheKey)
}
if req.Header.Get("session_id") == "" {
req.Header.Set("session_id", promptCacheKey)
}
}
}
// 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
req.Header.Set("user-agent", customUA)
}
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", codexCLIUserAgent)
}
// OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。
if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) {
req.Header.Set("user-agent", codexCLIUserAgent)
} }
if req.Header.Get("content-type") == "" { if req.Header.Get("content-type") == "" {
@@ -1233,23 +1269,14 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Contex
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
} }
func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool { func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool {
switch lowerKey { if lowerKey == "" {
// hop-by-hop
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
return true
// 入站鉴权与潜在泄露
case "authorization", "x-api-key", "x-goog-api-key", "cookie", "proxy-authorization":
return true
// 由 Go http client 自动协商压缩;透传模式需避免上游返回压缩体影响 SSE/usage 解析
case "accept-encoding":
return true
// 由 HTTP 库管理
case "host", "content-length":
return true
default:
return false return false
} }
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
return allowTimeoutHeaders
}
return openaiPassthroughAllowedHeaders[lowerKey]
} }
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
@@ -1555,7 +1582,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
// 若开启 ForceCodexCLI则强制将上游 User-Agent 伪装为 Codex CLI。 // 若开启 ForceCodexCLI则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", "codex_cli_rs/0.98.0") req.Header.Set("user-agent", codexCLIUserAgent)
} }
// Ensure required headers exist // Ensure required headers exist

View File

@@ -189,7 +189,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key"))
require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding"))
require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
// 3) required OAuth headers are present // 3) required OAuth headers are present
require.Equal(t, "chatgpt.com", upstream.lastReq.Host) require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
@@ -344,7 +344,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
require.True(t, arr[len(arr)-1].Passthrough) require.True(t, arr[len(arr)-1].Passthrough)
} }
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEnabled(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@@ -383,7 +383,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEna
_, err := svc.Forward(context.Background(), c, account, inputBody) _, err := svc.Forward(context.Background(), c, account, inputBody)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, inputBody, upstream.lastBody) require.Equal(t, inputBody, upstream.lastBody)
require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
} }
func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) {
@@ -533,7 +533,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd
require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String())
require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization"))
require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
} }
func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) {
@@ -656,7 +656,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *t
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, upstream.lastReq) require.NotNil(t, upstream.lastReq)
require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
} }
func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) {
@@ -700,5 +700,5 @@ func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, upstream.lastReq) require.NotNil(t, upstream.lastReq)
require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout"))
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) require.Empty(t, upstream.lastReq.Header.Get("X-Test"))
} }