fix(openai): 透传OAuth强制store/stream并修复Codex识别
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||||
var CodexCLIUserAgentPrefixes = []string{
|
var CodexCLIUserAgentPrefixes = []string{
|
||||||
@@ -9,8 +11,17 @@ var CodexCLIUserAgentPrefixes = []string{
|
|||||||
|
|
||||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||||
func IsCodexCLIRequest(userAgent string) bool {
|
func IsCodexCLIRequest(userAgent string) bool {
|
||||||
|
ua := strings.ToLower(strings.TrimSpace(userAgent))
|
||||||
|
if ua == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
normalizedPrefix := strings.ToLower(strings.TrimSpace(prefix))
|
||||||
|
if normalizedPrefix == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 优先前缀匹配;若 UA 被网关/代理拼接为复合字符串时,退化为包含匹配。
|
||||||
|
if strings.HasPrefix(ua, normalizedPrefix) || strings.Contains(ua, normalizedPrefix) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
28
backend/internal/pkg/openai/request_test.go
Normal file
28
backend/internal/pkg/openai/request_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsCodexCLIRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ua string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
|
||||||
|
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
|
||||||
|
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
|
||||||
|
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
|
||||||
|
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
|
||||||
|
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||||
|
{name: "空字符串", ua: "", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := IsCodexCLIRequest(tt.ua)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1027,6 +1027,17 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
reqStream bool,
|
reqStream bool,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
|
if account != nil && account.Type == AccountTypeOAuth {
|
||||||
|
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if normalized {
|
||||||
|
body = normalizedBody
|
||||||
|
reqStream = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.LegacyPrintf("service.openai_gateway",
|
logger.LegacyPrintf("service.openai_gateway",
|
||||||
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
|
||||||
account.ID,
|
account.ID,
|
||||||
@@ -2572,6 +2583,37 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
|
|||||||
return model, stream, promptCacheKey
|
return model, stream, promptCacheKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
|
||||||
|
// 1) store=false 2) stream=true
|
||||||
|
func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := body
|
||||||
|
changed := false
|
||||||
|
|
||||||
|
if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False {
|
||||||
|
next, err := sjson.SetBytes(normalized, "store", false)
|
||||||
|
if err != nil {
|
||||||
|
return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err)
|
||||||
|
}
|
||||||
|
normalized = next
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True {
|
||||||
|
next, err := sjson.SetBytes(normalized, "stream", true)
|
||||||
|
if err != nil {
|
||||||
|
return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err)
|
||||||
|
}
|
||||||
|
normalized = next
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalized, changed, nil
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||||
if reasoningEffort == "" {
|
if reasoningEffort == "" {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func f64p(v float64) *float64 { return &v }
|
func f64p(v float64) *float64 { return &v }
|
||||||
@@ -119,7 +120,7 @@ func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchanged(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -178,8 +179,12 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Stream)
|
require.True(t, result.Stream)
|
||||||
|
|
||||||
// 1) upstream body is exactly unchanged
|
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
|
||||||
require.Equal(t, originalBody, upstream.lastBody)
|
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||||
|
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||||
|
// 其余关键字段保持原值。
|
||||||
|
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||||
|
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
|
||||||
|
|
||||||
// 2) only auth is replaced; inbound auth/cookie are not forwarded
|
// 2) only auth is replaced; inbound auth/cookie are not forwarded
|
||||||
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization"))
|
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization"))
|
||||||
@@ -246,6 +251,49 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te
|
|||||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||||
|
// 复合 UA(前缀不是 codex_cli_rs),历史实现会误判为非 Codex 并走 opencode。
|
||||||
|
c.Request.Header.Set("User-Agent", "Mozilla/5.0 codex_cli_rs/0.1.0")
|
||||||
|
|
||||||
|
inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
|
||||||
|
}
|
||||||
|
upstream := &httpUpstreamRecorder{resp: resp}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": false},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, upstream.lastReq)
|
||||||
|
require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator"))
|
||||||
|
require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@@ -382,7 +430,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
|
|||||||
|
|
||||||
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, inputBody, upstream.lastBody)
|
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||||
|
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||||
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
|
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user