From da89583cccd6387d4fdc40b13bda365edbc84829 Mon Sep 17 00:00:00 2001 From: admin Date: Sat, 7 Mar 2026 14:12:38 +0800 Subject: [PATCH] fix(openai): detect official codex client by headers --- backend/internal/pkg/openai/request.go | 6 ++ backend/internal/pkg/openai/request_test.go | 23 ++++++ .../service/openai_gateway_service.go | 24 ++++-- .../service/openai_gateway_service_test.go | 41 ++++++++++ .../internal/service/openai_ws_forwarder.go | 8 +- .../openai_ws_forwarder_success_test.go | 80 +++++++++++++++++++ .../openai_ws_v2_passthrough_adapter.go | 2 +- 7 files changed, 170 insertions(+), 14 deletions(-) diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index c24d1273..dd8fe566 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool { return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) } +// IsCodexOfficialClientByHeaders checks whether the request headers indicate an +// official Codex client family request. +func IsCodexOfficialClientByHeaders(userAgent, originator string) bool { + return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator) +} + func normalizeCodexClientHeader(value string) string { return strings.ToLower(strings.TrimSpace(value)) } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go index 508bf561..b4562a07 100644 --- a/backend/internal/pkg/openai/request_test.go +++ b/backend/internal/pkg/openai/request_test.go @@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) { }) } } + +func TestIsCodexOfficialClientByHeaders(t *testing.T) { + tests := []struct { + name string + ua string + originator string + want bool + }{ + {name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true}, + {name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true}, + {name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true}, + {name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index de9cad51..fe04b0c4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -834,8 +834,10 @@ func logOpenAIInstructionsRequiredDebug( } userAgent := "" + originator := "" if c != nil { userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + originator = strings.TrimSpace(c.GetHeader("originator")) } fields := []zap.Field{ @@ -845,7 +847,7 @@ func logOpenAIInstructionsRequiredDebug( zap.Int("upstream_status_code", upstreamStatusCode), zap.String("upstream_error_message", msg), zap.String("request_user_agent", userAgent), - zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), } fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) @@ -968,6 +970,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b return currentHash } +func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { + if c != nil { + if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { + return originator + } + } + if isOfficialClient { + return "codex_cli_rs" + } + return "opencode" +} + // BindStickySession sets session -> account binding with standard TTL. func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { @@ -1489,7 +1503,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) clientTransport := GetOpenAIClientTransport(c) // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 @@ -2664,11 +2678,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } if account.Type == AccountTypeOAuth { req.Header.Set("OpenAI-Beta", "responses=experimental") - if isCodexCLI { - req.Header.Set("originator", "codex_cli_rs") - } else { - req.Header.Set("originator", "opencode") - } + req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) if isOpenAIResponsesCompactPath(c) { req.Header.Set("accept", "application/json") if req.Header.Get("version") == "" { diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1be1eb50..43e2f39d 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -1372,6 +1373,46 @@ func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testi require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String()) } +func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, + {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`))) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + svc := &OpenAIGatewayService{} + account := &Account{ + Type: AccountTypeOAuth, + Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"}, + } + + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) + req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI) + require.NoError(t, err) + require.Equal(t, tt.wantOriginator, req.Header.Get("originator")) + }) + } +} + // ==================== P1-08 修复:model 替换性能优化测试 ==================== // ==================== P1-08 修复:model 替换性能优化测试 ============= diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 91218cd2..f9e93f85 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1141,11 +1141,7 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { headers.Set("chatgpt-account-id", chatgptAccountID) } - if isCodexCLI { - headers.Set("originator", "codex_cli_rs") - } else { - headers.Set("originator", "opencode") - } + headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) } betaValue := openAIWSBetaV2Value @@ -2543,7 +2539,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) baseAcquireReq := openAIWSAcquireRequest{ Account: account, diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 1beb9ae9..912fade9 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -458,6 +458,86 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) } +func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + userAgent string + originator string + wantOriginator string + }{ + {name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"}, + {name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"}, + {name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + if tt.userAgent != "" { + c.Request.Header.Set("User-Agent", tt.userAgent) + } + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 129, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator")) + }) + } +} + func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index 29a2640d..c18c921f 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -107,7 +107,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( isCodexCLI := false if c != nil { - isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) } if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { isCodexCLI = true