fix(openai): detect official codex client by headers
This commit is contained in:
@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
|
||||
// official Codex client family request.
|
||||
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
|
||||
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
|
||||
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
|
||||
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -834,8 +834,10 @@ func logOpenAIInstructionsRequiredDebug(
|
||||
}
|
||||
|
||||
userAgent := ""
|
||||
originator := ""
|
||||
if c != nil {
|
||||
userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||
originator = strings.TrimSpace(c.GetHeader("originator"))
|
||||
}
|
||||
|
||||
fields := []zap.Field{
|
||||
@@ -845,7 +847,7 @@ func logOpenAIInstructionsRequiredDebug(
|
||||
zap.Int("upstream_status_code", upstreamStatusCode),
|
||||
zap.String("upstream_error_message", msg),
|
||||
zap.String("request_user_agent", userAgent),
|
||||
zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)),
|
||||
zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)),
|
||||
}
|
||||
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
|
||||
|
||||
@@ -968,6 +970,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b
|
||||
return currentHash
|
||||
}
|
||||
|
||||
func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string {
|
||||
if c != nil {
|
||||
if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" {
|
||||
return originator
|
||||
}
|
||||
}
|
||||
if isOfficialClient {
|
||||
return "codex_cli_rs"
|
||||
}
|
||||
return "opencode"
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
@@ -1489,7 +1503,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||
clientTransport := GetOpenAIClientTransport(c)
|
||||
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
|
||||
@@ -2664,11 +2678,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
if isCodexCLI {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
} else {
|
||||
req.Header.Set("originator", "opencode")
|
||||
}
|
||||
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||
if isOpenAIResponsesCompactPath(c) {
|
||||
req.Header.Set("accept", "application/json")
|
||||
if req.Header.Get("version") == "" {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -1372,6 +1373,46 @@ func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testi
|
||||
require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String())
|
||||
}
|
||||
|
||||
func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
originator string
|
||||
wantOriginator string
|
||||
}{
|
||||
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
|
||||
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
|
||||
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
||||
if tt.userAgent != "" {
|
||||
c.Request.Header.Set("User-Agent", tt.userAgent)
|
||||
}
|
||||
if tt.originator != "" {
|
||||
c.Request.Header.Set("originator", tt.originator)
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
|
||||
}
|
||||
|
||||
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
||||
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantOriginator, req.Header.Get("originator"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 =============
|
||||
|
||||
@@ -1141,11 +1141,7 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
headers.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
if isCodexCLI {
|
||||
headers.Set("originator", "codex_cli_rs")
|
||||
} else {
|
||||
headers.Set("originator", "opencode")
|
||||
}
|
||||
headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||
}
|
||||
|
||||
betaValue := openAIWSBetaV2Value
|
||||
@@ -2543,7 +2539,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
}
|
||||
}
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
|
||||
baseAcquireReq := openAIWSAcquireRequest{
|
||||
Account: account,
|
||||
|
||||
@@ -458,6 +458,86 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
|
||||
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
originator string
|
||||
wantOriginator string
|
||||
}{
|
||||
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
|
||||
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
|
||||
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
if tt.userAgent != "" {
|
||||
c.Request.Header.Set("User-Agent", tt.userAgent)
|
||||
}
|
||||
if tt.originator != "" {
|
||||
c.Request.Header.Set("originator", tt.originator)
|
||||
}
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.URLAllowlist.Enabled = false
|
||||
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
|
||||
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||
|
||||
captureConn := &openAIWSCaptureConn{
|
||||
events: [][]byte{
|
||||
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||
},
|
||||
}
|
||||
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||
pool := newOpenAIWSConnPool(cfg)
|
||||
pool.setClientDialerForTest(captureDialer)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: cfg,
|
||||
httpUpstream: &httpUpstreamRecorder{},
|
||||
cache: &stubGatewayCache{},
|
||||
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
openaiWSPool: pool,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 129,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "oauth-token-1",
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
|
||||
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
|
||||
isCodexCLI := false
|
||||
if c != nil {
|
||||
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
||||
}
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
isCodexCLI = true
|
||||
|
||||
Reference in New Issue
Block a user