diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 75c758da..67d607fa 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -13,13 +13,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) // OpenAIGatewayHandler handles OpenAI API gateway requests @@ -118,22 +116,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } reqStream := streamResult.Bool() - userAgent := c.GetHeader("User-Agent") - isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI) - if !isCodexCLI { - existingInstructions := gjson.GetBytes(body, "instructions").String() - if strings.TrimSpace(existingInstructions) == "" { - if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { - newBody, err := sjson.SetBytes(body, "instructions", instructions) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - body = newBody - } - } - } - setOpsRequestContext(c, reqModel, reqStream, body) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index ba3a3081..96ff5ca3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -696,23 +696,27 @@ func (a *Account) IsMixedSchedulingEnabled() bool { return false } -// IsOpenAIOAuthPassthroughEnabled 返回 OpenAI OAuth 账号是否启用“原样透传(仅替换认证)”。 +// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 // -// 存储位置:accounts.extra.openai_oauth_passthrough。 +// 新字段:accounts.extra.openai_passthrough。 +// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 // 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsOpenAIPassthroughEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if enabled, ok := a.Extra["openai_passthrough"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_oauth_passthrough"].(bool); ok { + return enabled + } + return false +} + +// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { - if a == nil || a.Extra == nil { - return false - } - v, ok := a.Extra["openai_oauth_passthrough"] - if !ok || v == nil { - return false - } - enabled, ok := v.(bool) - if !ok { - return false - } - return enabled + return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() } // WindowCostSchedulability 窗口费用调度状态 diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go new file mode 100644 index 00000000..f7a0d1cf --- /dev/null +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -0,0 +1,72 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { + t.Run("新字段开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("兼容旧字段", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("非OpenAI账号始终关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("空额外配置默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) +} + +func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) { + t.Run("仅OAuth类型允许返回开启", func(t *testing.T) { + oauthAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled()) + + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled()) + }) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 3ff20978..0e9a2580 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -747,11 +747,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) - passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI + passthroughEnabled := account.IsOpenAIPassthroughEnabled() if passthroughEnabled { // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) - return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) } reqBody, err := getOpenAIRequestBodyMap(c, body) @@ -775,6 +775,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Track if body needs re-serialization bodyModified := false + // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 + if !isCodexCLI && isInstructionsEmpty(reqBody) { + if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + bodyModified = true + } + } + // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { @@ -994,7 +1002,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) forwardOAuthPassthrough( +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( ctx context.Context, c *gin.Context, account *Account, @@ -1012,7 +1020,7 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough( return nil, err } - upstreamReq, err := s.buildUpstreamRequestOAuthPassthrough(ctx, c, account, body, token) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) if err != nil { return nil, err } @@ -1092,14 +1100,29 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough( }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequestOAuthPassthrough( +func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( ctx context.Context, c *gin.Context, account *Account, body []byte, token string, ) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexURL, bytes.NewReader(body)) + targetURL := openaiPlatformAPIURL + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) if err != nil { return nil, err } @@ -1123,16 +1146,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOAuthPassthrough( req.Header.Del("x-goog-api-key") req.Header.Set("authorization", "Bearer "+token) - // ChatGPT internal Codex API 必要头 - req.Host = "chatgpt.com" - if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { - req.Header.Set("chatgpt-account-id", chatgptAccountID) - } - if req.Header.Get("OpenAI-Beta") == "" { - req.Header.Set("OpenAI-Beta", "responses=experimental") - } - if req.Header.Get("originator") == "" { - req.Header.Set("originator", "codex_cli_rs") + // OAuth 透传到 ChatGPT internal API 时补齐必要头。 + if account.Type == AccountTypeOAuth { + req.Host = "chatgpt.com" + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + if req.Header.Get("OpenAI-Beta") == "" { + req.Header.Set("OpenAI-Beta", "responses=experimental") + } + if req.Header.Get("originator") == "" { + req.Header.Set("originator", "codex_cli_rs") + } } if req.Header.Get("content-type") == "" { @@ -1389,7 +1414,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } - targetURL = validatedURL + "/responses" + targetURL = buildOpenAIResponsesURL(validatedURL) } default: targetURL = openaiPlatformAPIURL @@ -2084,6 +2109,21 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro return normalized, nil } +// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 +// - base 以 /v1 结尾:追加 /responses +// - base 已是 /responses:原样返回 +// - 其他情况:追加 /v1/responses +func buildOpenAIResponsesURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/responses" + } + return normalized + "/v1/responses" +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 96805123..f6932469 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -88,7 +88,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -107,6 +107,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang // 2) only auth is replaced; inbound auth/cookie are not forwarded require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "codex_cli_rs/0.1.0", upstream.lastReq.Header.Get("User-Agent")) require.Empty(t, upstream.lastReq.Header.Get("Cookie")) require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key")) require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) @@ -154,7 +155,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *te Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": false}, + Extra: map[string]any{"openai_passthrough": false}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -207,7 +208,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *tes Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -249,7 +250,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -267,7 +268,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, arr[len(arr)-1].Passthrough) } -func TestOpenAIGatewayService_OAuthPassthrough_RequiresCodexUAOrForceFlag(t *testing.T) { +func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAStillPassthroughWhenEnabled(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -297,7 +298,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_RequiresCodexUAOrForceFlag(t *tes Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -305,16 +306,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_RequiresCodexUAOrForceFlag(t *tes _, err := svc.Forward(context.Background(), c, account, inputBody) require.NoError(t, err) - // not codex, not forced => legacy transform should run - require.Contains(t, string(upstream.lastBody), `"store":false`) - require.Contains(t, string(upstream.lastBody), `"stream":true`) - - // now enable force flag => should passthrough and keep bytes - upstream2 := &httpUpstreamRecorder{resp: resp} - svc2 := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}, httpUpstream: upstream2} - _, err = svc2.Forward(context.Background(), c, account, inputBody) - require.NoError(t, err) - require.Equal(t, inputBody, upstream2.lastBody) + require.Equal(t, inputBody, upstream.lastBody) + require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) } func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { @@ -352,7 +345,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -406,7 +399,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollec Type: AccountTypeOAuth, Concurrency: 1, Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, - Extra: map[string]any{"openai_oauth_passthrough": true}, + Extra: map[string]any{"openai_passthrough": true}, Status: StatusActive, Schedulable: true, RateMultiplier: f64p(1), @@ -421,3 +414,48 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollec require.Equal(t, 7, result.Usage.OutputTokens) require.Equal(t, 3, result.Usage.CacheReadInputTokens) } + +func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 456, + Name: "apikey-acc", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, originalBody, upstream.lastBody) + require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test")) +} diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index a7290cbf..339044e5 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -866,77 +866,30 @@
- {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} - {{ - t('admin.accounts.supportsAllModels') - }} +
+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }}
+
+
+
+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
+ {{
+ t('admin.accounts.supportsAllModels')
+ }}
+
+ {{ t('admin.accounts.mapRequestModels') }}
+ {{ t('admin.accounts.expiresAtHint') }}
+ {{ t('admin.accounts.openai.oauthPassthroughDesc') }}
+
- {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
- {{
- t('admin.accounts.supportsAllModels')
- }}
+
+ {{ t('admin.accounts.openai.modelRestrictionDisabledByPassthrough') }}
+
+
+
+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
+ {{
+ t('admin.accounts.supportsAllModels')
+ }}
+
+ {{ t('admin.accounts.mapRequestModels') }}
+ {{ t('admin.accounts.expiresAtHint') }}