feat(openai): 增加 OAuth 透传开关
- 仅对 Codex CLI 且账号开启时走原样透传(只替换认证) - 透传模式禁用工具修正/模型替换,并旁路解析 usage 用于计费 - 管理后台增加开关与文案,ops upstream error 记录 passthrough 标记 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -696,6 +696,25 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsOpenAIOAuthPassthroughEnabled 返回 OpenAI OAuth 账号是否启用“原样透传(仅替换认证)”。
|
||||
//
|
||||
// 存储位置:accounts.extra.openai_oauth_passthrough。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
if a == nil || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
v, ok := a.Extra["openai_oauth_passthrough"]
|
||||
if !ok || v == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := v.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// WindowCostSchedulability 窗口费用调度状态
|
||||
type WindowCostSchedulability int
|
||||
|
||||
|
||||
@@ -744,6 +744,8 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
originalBody := body
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
@@ -764,6 +766,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -983,6 +991,372 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOAuthPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
reqModel string,
|
||||
reasoningEffort *string,
|
||||
reqStream bool,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
log.Printf("[OpenAI 透传] 已启用:account=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOAuthPassthrough(ctx, c, account, body, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
c.Set("openai_passthrough", true)
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
// 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。
|
||||
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = result.usage
|
||||
firstTokenMs = result.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequestOAuthPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 透传客户端请求头(尽可能原样),并做安全剔除。
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lower := strings.ToLower(key)
|
||||
if isOpenAIPassthroughBlockedRequestHeader(lower) {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 覆盖入站鉴权残留,并注入上游认证
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// ChatGPT internal Codex API 必要头
|
||||
req.Host = "chatgpt.com"
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
if req.Header.Get("OpenAI-Beta") == "" {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
}
|
||||
if req.Header.Get("originator") == "" {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
}
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) error {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
UpstreamResponseBody: upstreamDetail,
|
||||
})
|
||||
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
|
||||
switch lowerKey {
|
||||
// hop-by-hop
|
||||
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
|
||||
return true
|
||||
// 入站鉴权与潜在泄露
|
||||
case "authorization", "x-api-key", "x-goog-api-key", "cookie":
|
||||
return true
|
||||
// 由 HTTP 库管理
|
||||
case "host", "content-length":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type openaiStreamingResultPassthrough struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
startTime time.Time,
|
||||
) (*openaiStreamingResultPassthrough, error) {
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
|
||||
// SSE headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
// 客户端断开时停止写入
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
usageParsed := false
|
||||
if len(body) > 0 {
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal(body, &response) == nil {
|
||||
usage.InputTokens = response.Usage.InputTokens
|
||||
usage.OutputTokens = response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
|
||||
usageParsed = true
|
||||
}
|
||||
}
|
||||
if !usageParsed {
|
||||
// 兜底:尝试从 SSE 文本中解析 usage
|
||||
usage = s.parseSSEUsageFromBody(string(body))
|
||||
}
|
||||
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
|
||||
} else {
|
||||
// 兜底:尽量保留最基础的 content-type
|
||||
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
|
||||
dst.Set("Content-Type", v)
|
||||
}
|
||||
}
|
||||
// 透传模式强制放行 x-codex-* 响应头(若上游返回)。
|
||||
// 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应,
|
||||
// 这里用 EqualFold 做一次大小写不敏感的查找。
|
||||
getCaseInsensitiveValues := func(h http.Header, want string) []string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
for k, vals := range h {
|
||||
if strings.EqualFold(k, want) {
|
||||
return vals
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, rawKey := range []string{
|
||||
"x-codex-primary-used-percent",
|
||||
"x-codex-primary-reset-after-seconds",
|
||||
"x-codex-primary-window-minutes",
|
||||
"x-codex-secondary-used-percent",
|
||||
"x-codex-secondary-reset-after-seconds",
|
||||
"x-codex-secondary-window-minutes",
|
||||
"x-codex-primary-over-secondary-limit-percent",
|
||||
} {
|
||||
vals := getCaseInsensitiveValues(src, rawKey)
|
||||
if len(vals) == 0 {
|
||||
continue
|
||||
}
|
||||
key := http.CanonicalHeaderKey(rawKey)
|
||||
dst.Del(key)
|
||||
for _, v := range vals {
|
||||
dst.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
@@ -1904,6 +2278,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
if s == nil || s.accountRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
|
||||
364
backend/internal/service/openai_oauth_passthrough_test.go
Normal file
364
backend/internal/service/openai_oauth_passthrough_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func f64p(v float64) *float64 { return &v }
|
||||
|
||||
type httpUpstreamRecorder struct {
|
||||
lastReq *http.Request
|
||||
lastBody []byte
|
||||
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
u.lastReq = req
|
||||
if req != nil && req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
u.lastBody = b
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewReader(b))
|
||||
}
|
||||
if u.err != nil {
|
||||
return nil, u.err
|
||||
}
|
||||
return u.resp, nil
|
||||
}
|
||||
|
||||
func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchanged(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
c.Request.Header.Set("Authorization", "Bearer inbound-should-not-forward")
|
||||
c.Request.Header.Set("Cookie", "secret=1")
|
||||
c.Request.Header.Set("X-Api-Key", "sk-inbound")
|
||||
c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamSSE)),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
openAITokenProvider: &OpenAITokenProvider{ // minimal: will be bypassed by nil cache/service, but GetAccessToken uses provider only if non-nil
|
||||
accountRepo: nil,
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
// Use the gateway method that reads token from credentials when provider is nil.
|
||||
svc.openAITokenProvider = nil
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
|
||||
// 1) upstream body is exactly unchanged
|
||||
require.Equal(t, originalBody, upstream.lastBody)
|
||||
|
||||
// 2) only auth is replaced; inbound auth/cookie are not forwarded
|
||||
require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("Cookie"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key"))
|
||||
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
|
||||
|
||||
// 3) required OAuth headers are present
|
||||
require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
|
||||
require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
|
||||
|
||||
// 4) downstream SSE keeps tool name (no toolCorrector)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "apply_patch")
|
||||
require.NotContains(t, body, "\"name\":\"edit\"")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
// store=true + stream=false should be forced to store=false + stream=true by applyCodexOAuthTransform (OAuth legacy path)
|
||||
inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": false},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
// legacy path rewrites request body (not byte-equal)
|
||||
require.NotEqual(t, inputBody, upstream.lastBody)
|
||||
require.Contains(t, string(upstream.lastBody), `"store":false`)
|
||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/json")
|
||||
headers.Set("x-request-id", "rid")
|
||||
headers.Set("x-codex-primary-used-percent", "12")
|
||||
headers.Set("x-codex-secondary-used-percent", "34")
|
||||
headers.Set("x-codex-primary-window-minutes", "300")
|
||||
headers.Set("x-codex-secondary-window-minutes", "10080")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "1")
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: headers,
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "12", rec.Header().Get("x-codex-primary-used-percent"))
|
||||
require.Equal(t, "34", rec.Header().Get("x-codex-secondary-used-percent"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"bad"}}`)),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
|
||||
// should append an upstream error event with passthrough=true
|
||||
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
arr, ok := v.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, arr)
|
||||
require.True(t, arr[len(arr)-1].Passthrough)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_RequiresCodexUAOrForceFlag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
// Non-Codex UA
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
|
||||
inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||
require.NoError(t, err)
|
||||
// not codex, not forced => legacy transform should run
|
||||
require.Contains(t, string(upstream.lastBody), `"store":false`)
|
||||
require.Contains(t, string(upstream.lastBody), `"stream":true`)
|
||||
|
||||
// now enable force flag => should passthrough and keep bytes
|
||||
upstream2 := &httpUpstreamRecorder{resp: resp}
|
||||
svc2 := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}, httpUpstream: upstream2}
|
||||
_, err = svc2.Forward(context.Background(), c, account, inputBody)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, inputBody, upstream2.lastBody)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_text.delta","delta":"h"}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamSSE)),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
// sanity: duration after start
|
||||
require.GreaterOrEqual(t, time.Since(start), time.Duration(0))
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.GreaterOrEqual(t, *result.FirstTokenMs, 0)
|
||||
}
|
||||
@@ -42,6 +42,10 @@ func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage
|
||||
type OpsUpstreamErrorEvent struct {
|
||||
AtUnixMs int64 `json:"at_unix_ms,omitempty"`
|
||||
|
||||
// Passthrough 表示本次请求是否命中“原样透传(仅替换认证)”分支。
|
||||
// 该字段用于排障与灰度评估;存入 JSON,不涉及 DB schema 变更。
|
||||
Passthrough bool `json:"passthrough,omitempty"`
|
||||
|
||||
// Context
|
||||
Platform string `json:"platform,omitempty"`
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user