diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index 5b049ddc..f6c3489a 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -1,5 +1,7 @@ package openai +import "strings" + // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" var CodexCLIUserAgentPrefixes = []string{ @@ -9,8 +11,17 @@ var CodexCLIUserAgentPrefixes = []string{ // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { + ua := strings.ToLower(strings.TrimSpace(userAgent)) + if ua == "" { + return false + } for _, prefix := range CodexCLIUserAgentPrefixes { - if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { + normalizedPrefix := strings.ToLower(strings.TrimSpace(prefix)) + if normalizedPrefix == "" { + continue + } + // 优先前缀匹配;若 UA 被网关/代理拼接为复合字符串时,退化为包含匹配。 + if strings.HasPrefix(ua, normalizedPrefix) || strings.Contains(ua, normalizedPrefix) { return true } } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go new file mode 100644 index 00000000..729321ff --- /dev/null +++ b/backend/internal/pkg/openai/request_test.go @@ -0,0 +1,28 @@ +package openai + +import "testing" + +func TestIsCodexCLIRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true}, + {name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true}, + {name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true}, + {name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexCLIRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 63fb233e..06b41996 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1027,6 +1027,17 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( reqStream bool, startTime time.Time, ) (*OpenAIForwardResult, error) { + if account != nil && account.Type == AccountTypeOAuth { + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + reqStream = true + } + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -2572,6 +2583,37 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p return model, stream, promptCacheKey } +// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: +// 1) store=false 2) stream=true +func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := body + changed := false + + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true + } + + return normalized, changed, nil +} + func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) if reasoningEffort == "" { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 1722952c..e7bcc0bb 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func f64p(v float64) *float64 { return &v } @@ -119,7 +120,7 @@ func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) { } } -func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchanged(t *testing.T) { +func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -178,8 +179,12 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang require.NotNil(t, result) require.True(t, result.Stream) - // 1) upstream body is exactly unchanged - require.Equal(t, originalBody, upstream.lastBody) + // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。 + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + // 其余关键字段保持原值。 + require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) // 2) only auth is replaced; inbound auth/cookie are not forwarded require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) @@ -246,6 +251,49 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te require.Contains(t, string(upstream.lastBody), `"stream":true`) } +func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // 复合 UA(前缀不是 codex_cli_rs),历史实现会误判为非 Codex 并走 opencode。 + c.Request.Header.Set("User-Agent", "Mozilla/5.0 codex_cli_rs/0.1.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator")) + require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator")) +} + func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) { gin.SetMode(gin.TestMode) @@ -382,7 +430,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te _, err := svc.Forward(context.Background(), c, account, inputBody) require.NoError(t, err) - require.Equal(t, inputBody, upstream.lastBody) + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) }