feat(openai): 增加 OAuth 透传开关
- 仅对 Codex CLI 且账号开启时走原样透传(只替换认证) - 透传模式禁用工具修正/模型替换,并旁路解析 usage 用于计费 - 管理后台增加开关与文案,ops upstream error 记录 passthrough 标记 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -744,6 +744,8 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
originalBody := body
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
@@ -764,6 +766,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -983,6 +991,372 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) forwardOAuthPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
reqModel string,
|
||||
reasoningEffort *string,
|
||||
reqStream bool,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
log.Printf("[OpenAI 透传] 已启用:account=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOAuthPassthrough(ctx, c, account, body, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
c.Set("openai_passthrough", true)
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
// 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。
|
||||
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = result.usage
|
||||
firstTokenMs = result.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
ReasoningEffort: reasoningEffort,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequestOAuthPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 透传客户端请求头(尽可能原样),并做安全剔除。
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lower := strings.ToLower(key)
|
||||
if isOpenAIPassthroughBlockedRequestHeader(lower) {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 覆盖入站鉴权残留,并注入上游认证
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// ChatGPT internal Codex API 必要头
|
||||
req.Host = "chatgpt.com"
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
if req.Header.Get("OpenAI-Beta") == "" {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
}
|
||||
if req.Header.Get("originator") == "" {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
}
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) error {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
UpstreamResponseBody: upstreamDetail,
|
||||
})
|
||||
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
|
||||
switch lowerKey {
|
||||
// hop-by-hop
|
||||
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
|
||||
return true
|
||||
// 入站鉴权与潜在泄露
|
||||
case "authorization", "x-api-key", "x-goog-api-key", "cookie":
|
||||
return true
|
||||
// 由 HTTP 库管理
|
||||
case "host", "content-length":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type openaiStreamingResultPassthrough struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
startTime time.Time,
|
||||
) (*openaiStreamingResultPassthrough, error) {
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
|
||||
// SSE headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
// 客户端断开时停止写入
|
||||
break
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
usageParsed := false
|
||||
if len(body) > 0 {
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal(body, &response) == nil {
|
||||
usage.InputTokens = response.Usage.InputTokens
|
||||
usage.OutputTokens = response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
|
||||
usageParsed = true
|
||||
}
|
||||
}
|
||||
if !usageParsed {
|
||||
// 兜底:尝试从 SSE 文本中解析 usage
|
||||
usage = s.parseSSEUsageFromBody(string(body))
|
||||
}
|
||||
|
||||
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
|
||||
} else {
|
||||
// 兜底:尽量保留最基础的 content-type
|
||||
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
|
||||
dst.Set("Content-Type", v)
|
||||
}
|
||||
}
|
||||
// 透传模式强制放行 x-codex-* 响应头(若上游返回)。
|
||||
// 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应,
|
||||
// 这里用 EqualFold 做一次大小写不敏感的查找。
|
||||
getCaseInsensitiveValues := func(h http.Header, want string) []string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
for k, vals := range h {
|
||||
if strings.EqualFold(k, want) {
|
||||
return vals
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, rawKey := range []string{
|
||||
"x-codex-primary-used-percent",
|
||||
"x-codex-primary-reset-after-seconds",
|
||||
"x-codex-primary-window-minutes",
|
||||
"x-codex-secondary-used-percent",
|
||||
"x-codex-secondary-reset-after-seconds",
|
||||
"x-codex-secondary-window-minutes",
|
||||
"x-codex-primary-over-secondary-limit-percent",
|
||||
} {
|
||||
vals := getCaseInsensitiveValues(src, rawKey)
|
||||
if len(vals) == 0 {
|
||||
continue
|
||||
}
|
||||
key := http.CanonicalHeaderKey(rawKey)
|
||||
dst.Del(key)
|
||||
for _, v := range vals {
|
||||
dst.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
@@ -1904,6 +2278,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
if snapshot == nil {
|
||||
return
|
||||
}
|
||||
if s == nil || s.accountRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
|
||||
Reference in New Issue
Block a user