fix(openai): detect official codex client by headers

This commit is contained in:
admin
2026-03-07 14:12:38 +08:00
parent 6411645ffc
commit da89583ccc
7 changed files with 170 additions and 14 deletions

View File

@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
} }
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
// official Codex client family request.
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
}
func normalizeCodexClientHeader(value string) string { func normalizeCodexClientHeader(value string) string {
return strings.ToLower(strings.TrimSpace(value)) return strings.ToLower(strings.TrimSpace(value))
} }

View File

@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
}) })
} }
} }
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
tests := []struct {
name string
ua string
originator string
want bool
}{
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
}
})
}
}

View File

@@ -834,8 +834,10 @@ func logOpenAIInstructionsRequiredDebug(
} }
userAgent := "" userAgent := ""
originator := ""
if c != nil { if c != nil {
userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
originator = strings.TrimSpace(c.GetHeader("originator"))
} }
fields := []zap.Field{ fields := []zap.Field{
@@ -845,7 +847,7 @@ func logOpenAIInstructionsRequiredDebug(
zap.Int("upstream_status_code", upstreamStatusCode), zap.Int("upstream_status_code", upstreamStatusCode),
zap.String("upstream_error_message", msg), zap.String("upstream_error_message", msg),
zap.String("request_user_agent", userAgent), zap.String("request_user_agent", userAgent),
zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)),
} }
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
@@ -968,6 +970,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b
return currentHash return currentHash
} }
func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string {
if c != nil {
if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" {
return originator
}
}
if isOfficialClient {
return "codex_cli_rs"
}
return "opencode"
}
// BindStickySession sets session -> account binding with standard TTL. // BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 { if sessionHash == "" || accountID <= 0 {
@@ -1489,7 +1503,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
clientTransport := GetOpenAIClientTransport(c) clientTransport := GetOpenAIClientTransport(c)
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
@@ -2664,11 +2678,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
} }
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
req.Header.Set("OpenAI-Beta", "responses=experimental") req.Header.Set("OpenAI-Beta", "responses=experimental")
if isCodexCLI { req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
req.Header.Set("originator", "codex_cli_rs")
} else {
req.Header.Set("originator", "opencode")
}
if isOpenAIResponsesCompactPath(c) { if isOpenAIResponsesCompactPath(c) {
req.Header.Set("accept", "application/json") req.Header.Set("accept", "application/json")
if req.Header.Get("version") == "" { if req.Header.Get("version") == "" {

View File

@@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -1372,6 +1373,46 @@ func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testi
require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String()) require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String())
} }
func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
userAgent string
originator string
wantOriginator string
}{
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
if tt.userAgent != "" {
c.Request.Header.Set("User-Agent", tt.userAgent)
}
if tt.originator != "" {
c.Request.Header.Set("originator", tt.originator)
}
svc := &OpenAIGatewayService{}
account := &Account{
Type: AccountTypeOAuth,
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
}
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI)
require.NoError(t, err)
require.Equal(t, tt.wantOriginator, req.Header.Get("originator"))
})
}
}
// ==================== P1-08 修复model 替换性能优化测试 ==================== // ==================== P1-08 修复model 替换性能优化测试 ====================
// ==================== P1-08 修复model 替换性能优化测试 ============= // ==================== P1-08 修复model 替换性能优化测试 =============

View File

@@ -1141,11 +1141,7 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
headers.Set("chatgpt-account-id", chatgptAccountID) headers.Set("chatgpt-account-id", chatgptAccountID)
} }
if isCodexCLI { headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
headers.Set("originator", "codex_cli_rs")
} else {
headers.Set("originator", "opencode")
}
} }
betaValue := openAIWSBetaV2Value betaValue := openAIWSBetaV2Value
@@ -2543,7 +2539,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
} }
} }
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
baseAcquireReq := openAIWSAcquireRequest{ baseAcquireReq := openAIWSAcquireRequest{
Account: account, Account: account,

View File

@@ -458,6 +458,86 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
} }
func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
userAgent string
originator string
wantOriginator string
}{
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
if tt.userAgent != "" {
c.Request.Header.Set("User-Agent", tt.userAgent)
}
if tt.originator != "" {
c.Request.Header.Set("originator", tt.originator)
}
cfg := &config.Config{}
cfg.Security.URLAllowlist.Enabled = false
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
captureConn := &openAIWSCaptureConn{
events: [][]byte{
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
},
}
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
pool := newOpenAIWSConnPool(cfg)
pool.setClientDialerForTest(captureDialer)
svc := &OpenAIGatewayService{
cfg: cfg,
httpUpstream: &httpUpstreamRecorder{},
cache: &stubGatewayCache{},
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
toolCorrector: NewCodexToolCorrector(),
openaiWSPool: pool,
}
account := &Account{
ID: 129,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token-1",
},
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator"))
})
}
}
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@@ -107,7 +107,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
isCodexCLI := false isCodexCLI := false
if c != nil { if c != nil {
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
} }
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
isCodexCLI = true isCodexCLI = true