fix(openai): 透传OAuth强制store/stream并修复Codex识别

This commit is contained in:
yangjianbo
2026-02-12 21:02:52 +08:00
parent d411cf4472
commit 2f190d812a
4 changed files with 135 additions and 5 deletions

View File

@@ -1,5 +1,7 @@
package openai package openai
import "strings"
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{ var CodexCLIUserAgentPrefixes = []string{
@@ -9,8 +11,17 @@ var CodexCLIUserAgentPrefixes = []string{
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool { func IsCodexCLIRequest(userAgent string) bool {
ua := strings.ToLower(strings.TrimSpace(userAgent))
if ua == "" {
return false
}
for _, prefix := range CodexCLIUserAgentPrefixes { for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { normalizedPrefix := strings.ToLower(strings.TrimSpace(prefix))
if normalizedPrefix == "" {
continue
}
// 优先前缀匹配;若 UA 被网关/代理拼接为复合字符串时,退化为包含匹配。
if strings.HasPrefix(ua, normalizedPrefix) || strings.Contains(ua, normalizedPrefix) {
return true return true
} }
} }

View File

@@ -0,0 +1,28 @@
package openai
import "testing"
func TestIsCodexCLIRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexCLIRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}

View File

@@ -1027,6 +1027,17 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqStream bool, reqStream bool,
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
if account != nil && account.Type == AccountTypeOAuth {
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
if err != nil {
return nil, err
}
if normalized {
body = normalizedBody
reqStream = true
}
}
logger.LegacyPrintf("service.openai_gateway", logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID, account.ID,
@@ -2572,6 +2583,37 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
return model, stream, promptCacheKey return model, stream, promptCacheKey
} }
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
// 1) store=false 2) stream=true
func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
}
normalized := body
changed := false
if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False {
next, err := sjson.SetBytes(normalized, "store", false)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
}
normalized = next
changed = true
}
if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True {
next, err := sjson.SetBytes(normalized, "stream", true)
if err != nil {
return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
}
normalized = next
changed = true
}
return normalized, changed, nil
}
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if reasoningEffort == "" { if reasoningEffort == "" {

View File

@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
) )
func f64p(v float64) *float64 { return &v } func f64p(v float64) *float64 { return &v }
@@ -119,7 +120,7 @@ func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) {
} }
} }
func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchanged(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@@ -178,8 +179,12 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.Stream) require.True(t, result.Stream)
// 1) upstream body is exactly unchanged // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
require.Equal(t, originalBody, upstream.lastBody) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
// 其余关键字段保持原值。
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
// 2) only auth is replaced; inbound auth/cookie are not forwarded // 2) only auth is replaced; inbound auth/cookie are not forwarded
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization"))
@@ -246,6 +251,49 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
require.Contains(t, string(upstream.lastBody), `"stream":true`) require.Contains(t, string(upstream.lastBody), `"stream":true`)
} }
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
// 复合 UA前缀不是 codex_cli_rs历史实现会误判为非 Codex 并走 opencode。
c.Request.Header.Set("User-Agent", "Mozilla/5.0 codex_cli_rs/0.1.0")
inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
}
upstream := &httpUpstreamRecorder{resp: resp}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": false},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
_, err := svc.Forward(context.Background(), c, account, inputBody)
require.NoError(t, err)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator"))
require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator"))
}
func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@@ -382,7 +430,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
_, err := svc.Forward(context.Background(), c, account, inputBody) _, err := svc.Forward(context.Background(), c, account, inputBody)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, inputBody, upstream.lastBody) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
} }