feat(openai): 增加 OAuth 透传开关

- 仅对 Codex CLI 且账号开启时走原样透传(只替换认证)

- 透传模式禁用工具修正/模型替换,并旁路解析 usage 用于计费

- 管理后台增加开关与文案,ops upstream error 记录 passthrough 标记

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-11 00:59:39 +08:00
parent 86f3124720
commit f1e884ce2b
7 changed files with 821 additions and 0 deletions

View File

@@ -744,6 +744,8 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
originalBody := body
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
@@ -764,6 +766,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -983,6 +991,372 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) forwardOAuthPassthrough(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
reqModel string,
reasoningEffort *string,
reqStream bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
log.Printf("[OpenAI 透传] 已启用account=%d name=%s", account.ID, account.Name)
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
upstreamReq, err := s.buildUpstreamRequestOAuthPassthrough(ctx, c, account, body, token)
if err != nil {
return nil, err
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
c.Set("openai_passthrough", true)
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Passthrough: true,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
// 透传模式不做 failover避免改变原始上游语义按上游原样返回错误响应。
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account)
}
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
if err != nil {
return nil, err
}
usage = result.usage
firstTokenMs = result.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
if err != nil {
return nil, err
}
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
if usage == nil {
usage = &OpenAIUsage{}
}
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequestOAuthPassthrough(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
token string,
) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 透传客户端请求头(尽可能原样),并做安全剔除。
if c != nil && c.Request != nil {
for key, values := range c.Request.Header {
lower := strings.ToLower(key)
if isOpenAIPassthroughBlockedRequestHeader(lower) {
continue
}
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// 覆盖入站鉴权残留,并注入上游认证
req.Header.Del("authorization")
req.Header.Del("x-api-key")
req.Header.Del("x-goog-api-key")
req.Header.Set("authorization", "Bearer "+token)
// ChatGPT internal Codex API 必要头
req.Host = "chatgpt.com"
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
if req.Header.Get("OpenAI-Beta") == "" {
req.Header.Set("OpenAI-Beta", "responses=experimental")
}
if req.Header.Get("originator") == "" {
req.Header.Set("originator", "codex_cli_rs")
}
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
return req, nil
}
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) error {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Passthrough: true,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
UpstreamResponseBody: upstreamDetail,
})
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, body)
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
switch lowerKey {
// hop-by-hop
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
return true
// 入站鉴权与潜在泄露
case "authorization", "x-api-key", "x-goog-api-key", "cookie":
return true
// 由 HTTP 库管理
case "host", "content-length":
return true
default:
return false
}
}
type openaiStreamingResultPassthrough struct {
usage *OpenAIUsage
firstTokenMs *int
}
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
) (*openaiStreamingResultPassthrough, error) {
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
// SSE headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
line := scanner.Text()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
if _, err := fmt.Fprintln(w, line); err != nil {
// 客户端断开时停止写入
break
}
flusher.Flush()
}
if err := scanner.Err(); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(err, bufio.ErrTooLong) {
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
usage := &OpenAIUsage{}
usageParsed := false
if len(body) > 0 {
var response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
}
if json.Unmarshal(body, &response) == nil {
usage.InputTokens = response.Usage.InputTokens
usage.OutputTokens = response.Usage.OutputTokens
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
usageParsed = true
}
}
if !usageParsed {
// 兜底:尝试从 SSE 文本中解析 usage
usage = s.parseSSEUsageFromBody(string(body))
}
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
if dst == nil || src == nil {
return
}
if cfg != nil {
responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
} else {
// 兜底:尽量保留最基础的 content-type
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
dst.Set("Content-Type", v)
}
}
// 透传模式强制放行 x-codex-* 响应头(若上游返回)。
// 注意:真实 http.Response.Header 的 key 一般会被 canonicalize但为了兼容测试/自建响应,
// 这里用 EqualFold 做一次大小写不敏感的查找。
getCaseInsensitiveValues := func(h http.Header, want string) []string {
if h == nil {
return nil
}
for k, vals := range h {
if strings.EqualFold(k, want) {
return vals
}
}
return nil
}
for _, rawKey := range []string{
"x-codex-primary-used-percent",
"x-codex-primary-reset-after-seconds",
"x-codex-primary-window-minutes",
"x-codex-secondary-used-percent",
"x-codex-secondary-reset-after-seconds",
"x-codex-secondary-window-minutes",
"x-codex-primary-over-secondary-limit-percent",
} {
vals := getCaseInsensitiveValues(src, rawKey)
if len(vals) == 0 {
continue
}
key := http.CanonicalHeaderKey(rawKey)
dst.Del(key)
for _, v := range vals {
dst.Add(key, v)
}
}
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
@@ -1904,6 +2278,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if snapshot == nil {
return
}
if s == nil || s.accountRepo == nil {
return
}
// Convert snapshot to map for merging into Extra
updates := make(map[string]any)