fix(openai): detect official codex client by headers
This commit is contained in:
@@ -58,6 +58,12 @@ func IsCodexOfficialClientOriginator(originator string) bool {
|
|||||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsCodexOfficialClientByHeaders checks whether the request headers indicate an
|
||||||
|
// official Codex client family request.
|
||||||
|
func IsCodexOfficialClientByHeaders(userAgent, originator string) bool {
|
||||||
|
return IsCodexOfficialClientRequest(userAgent) || IsCodexOfficialClientOriginator(originator)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeCodexClientHeader(value string) string {
|
func normalizeCodexClientHeader(value string) string {
|
||||||
return strings.ToLower(strings.TrimSpace(value))
|
return strings.ToLower(strings.TrimSpace(value))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,3 +85,26 @@ func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsCodexOfficialClientByHeaders(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ua string
|
||||||
|
originator string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "仅 originator 命中 desktop", originator: "Codex Desktop", want: true},
|
||||||
|
{name: "仅 originator 命中 vscode", originator: "codex_vscode", want: true},
|
||||||
|
{name: "仅 ua 命中 desktop", ua: "Codex Desktop/1.2.3", want: true},
|
||||||
|
{name: "ua 与 originator 都未命中", ua: "curl/8.0.1", originator: "my_client", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsCodexOfficialClientByHeaders(tt.ua, tt.originator)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("IsCodexOfficialClientByHeaders(%q, %q) = %v, want %v", tt.ua, tt.originator, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -834,8 +834,10 @@ func logOpenAIInstructionsRequiredDebug(
|
|||||||
}
|
}
|
||||||
|
|
||||||
userAgent := ""
|
userAgent := ""
|
||||||
|
originator := ""
|
||||||
if c != nil {
|
if c != nil {
|
||||||
userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
|
userAgent = strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||||
|
originator = strings.TrimSpace(c.GetHeader("originator"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := []zap.Field{
|
fields := []zap.Field{
|
||||||
@@ -845,7 +847,7 @@ func logOpenAIInstructionsRequiredDebug(
|
|||||||
zap.Int("upstream_status_code", upstreamStatusCode),
|
zap.Int("upstream_status_code", upstreamStatusCode),
|
||||||
zap.String("upstream_error_message", msg),
|
zap.String("upstream_error_message", msg),
|
||||||
zap.String("request_user_agent", userAgent),
|
zap.String("request_user_agent", userAgent),
|
||||||
zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)),
|
zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)),
|
||||||
}
|
}
|
||||||
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
|
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody)
|
||||||
|
|
||||||
@@ -968,6 +970,18 @@ func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, b
|
|||||||
return currentHash
|
return currentHash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string {
|
||||||
|
if c != nil {
|
||||||
|
if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" {
|
||||||
|
return originator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isOfficialClient {
|
||||||
|
return "codex_cli_rs"
|
||||||
|
}
|
||||||
|
return "opencode"
|
||||||
|
}
|
||||||
|
|
||||||
// BindStickySession sets session -> account binding with standard TTL.
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||||
if sessionHash == "" || accountID <= 0 {
|
if sessionHash == "" || accountID <= 0 {
|
||||||
@@ -1489,7 +1503,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
|
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||||
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
|
||||||
clientTransport := GetOpenAIClientTransport(c)
|
clientTransport := GetOpenAIClientTransport(c)
|
||||||
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
|
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
|
||||||
@@ -2664,11 +2678,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
|||||||
}
|
}
|
||||||
if account.Type == AccountTypeOAuth {
|
if account.Type == AccountTypeOAuth {
|
||||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||||
if isCodexCLI {
|
req.Header.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||||
req.Header.Set("originator", "codex_cli_rs")
|
|
||||||
} else {
|
|
||||||
req.Header.Set("originator", "opencode")
|
|
||||||
}
|
|
||||||
if isOpenAIResponsesCompactPath(c) {
|
if isOpenAIResponsesCompactPath(c) {
|
||||||
req.Header.Set("accept", "application/json")
|
req.Header.Set("accept", "application/json")
|
||||||
if req.Header.Get("version") == "" {
|
if req.Header.Get("version") == "" {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -1372,6 +1373,46 @@ func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testi
|
|||||||
require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String())
|
require.Equal(t, "https://example.com/v1/responses/compact", req.URL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIBuildUpstreamRequestOAuthOfficialClientOriginatorCompatibility(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userAgent string
|
||||||
|
originator string
|
||||||
|
wantOriginator string
|
||||||
|
}{
|
||||||
|
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
|
||||||
|
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
|
||||||
|
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader([]byte(`{"model":"gpt-5"}`)))
|
||||||
|
if tt.userAgent != "" {
|
||||||
|
c.Request.Header.Set("User-Agent", tt.userAgent)
|
||||||
|
}
|
||||||
|
if tt.originator != "" {
|
||||||
|
c.Request.Header.Set("originator", tt.originator)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
account := &Account{
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{"chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
||||||
|
req, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte(`{"model":"gpt-5"}`), "token", false, "", isCodexCLI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.wantOriginator, req.Header.Get("originator"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||||
|
|
||||||
// ==================== P1-08 修复:model 替换性能优化测试 =============
|
// ==================== P1-08 修复:model 替换性能优化测试 =============
|
||||||
|
|||||||
@@ -1141,11 +1141,7 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
|
|||||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||||
headers.Set("chatgpt-account-id", chatgptAccountID)
|
headers.Set("chatgpt-account-id", chatgptAccountID)
|
||||||
}
|
}
|
||||||
if isCodexCLI {
|
headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
|
||||||
headers.Set("originator", "codex_cli_rs")
|
|
||||||
} else {
|
|
||||||
headers.Set("originator", "opencode")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
betaValue := openAIWSBetaV2Value
|
betaValue := openAIWSBetaV2Value
|
||||||
@@ -2543,7 +2539,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||||
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
|
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
|
||||||
baseAcquireReq := openAIWSAcquireRequest{
|
baseAcquireReq := openAIWSAcquireRequest{
|
||||||
Account: account,
|
Account: account,
|
||||||
|
|||||||
@@ -458,6 +458,86 @@ func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T
|
|||||||
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
|
require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_Forward_WSv2_OAuthOriginatorCompatibility(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
userAgent string
|
||||||
|
originator string
|
||||||
|
wantOriginator string
|
||||||
|
}{
|
||||||
|
{name: "desktop originator preserved", originator: "Codex Desktop", wantOriginator: "Codex Desktop"},
|
||||||
|
{name: "vscode originator preserved", originator: "codex_vscode", wantOriginator: "codex_vscode"},
|
||||||
|
{name: "official ua fallback to codex_cli_rs", userAgent: "Codex Desktop/1.2.3", wantOriginator: "codex_cli_rs"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||||
|
if tt.userAgent != "" {
|
||||||
|
c.Request.Header.Set("User-Agent", tt.userAgent)
|
||||||
|
}
|
||||||
|
if tt.originator != "" {
|
||||||
|
c.Request.Header.Set("originator", tt.originator)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Security.URLAllowlist.Enabled = false
|
||||||
|
cfg.Security.URLAllowlist.AllowInsecureHTTP = true
|
||||||
|
cfg.Gateway.OpenAIWS.Enabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||||
|
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||||
|
cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
|
||||||
|
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
|
||||||
|
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
|
||||||
|
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
|
||||||
|
|
||||||
|
captureConn := &openAIWSCaptureConn{
|
||||||
|
events: [][]byte{
|
||||||
|
[]byte(`{"type":"response.completed","response":{"id":"resp_oauth_originator","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
captureDialer := &openAIWSCaptureDialer{conn: captureConn}
|
||||||
|
pool := newOpenAIWSConnPool(cfg)
|
||||||
|
pool.setClientDialerForTest(captureDialer)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: cfg,
|
||||||
|
httpUpstream: &httpUpstreamRecorder{},
|
||||||
|
cache: &stubGatewayCache{},
|
||||||
|
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
|
||||||
|
toolCorrector: NewCodexToolCorrector(),
|
||||||
|
openaiWSPool: pool,
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
ID: 129,
|
||||||
|
Name: "openai-oauth",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "oauth-token-1",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{
|
||||||
|
"responses_websockets_v2_enabled": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, tt.wantOriginator, captureDialer.lastHeaders.Get("originator"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
|
func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
|||||||
|
|
||||||
isCodexCLI := false
|
isCodexCLI := false
|
||||||
if c != nil {
|
if c != nil {
|
||||||
isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
isCodexCLI = openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator"))
|
||||||
}
|
}
|
||||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||||
isCodexCLI = true
|
isCodexCLI = true
|
||||||
|
|||||||
Reference in New Issue
Block a user