fix(openai): 修复 OAuth 透传流式断开与压缩头问题
- 透传流式在客户端断开后继续 drain 上游并解析 usage,避免计费信息丢失 - 阻断透传 accept-encoding,避免压缩响应影响 SSE/usage 解析 - 阻断 proxy-authorization,避免透传代理鉴权信息 - 补充回归测试:请求头阻断与断流后 usage 采集
This commit is contained in:
@@ -1183,7 +1183,10 @@ func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
|
||||
case "connection", "transfer-encoding", "keep-alive", "proxy-connection", "upgrade", "te", "trailer":
|
||||
return true
|
||||
// 入站鉴权与潜在泄露
|
||||
case "authorization", "x-api-key", "x-goog-api-key", "cookie":
|
||||
case "authorization", "x-api-key", "x-goog-api-key", "cookie", "proxy-authorization":
|
||||
return true
|
||||
// 由 Go http client 自动协商压缩;透传模式需避免上游返回压缩体影响 SSE/usage 解析
|
||||
case "accept-encoding":
|
||||
return true
|
||||
// 由 HTTP 库管理
|
||||
case "host", "content-length":
|
||||
@@ -1224,6 +1227,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -1245,13 +1249,20 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
// 客户端断开时停止写入
|
||||
break
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintln(w, line); err != nil {
|
||||
clientDisconnected = true
|
||||
log.Printf("[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
log.Printf("[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
@@ -54,6 +54,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
|
||||
c.Request.Header.Set("Cookie", "secret=1")
|
||||
c.Request.Header.Set("X-Api-Key", "sk-inbound")
|
||||
c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound")
|
||||
c.Request.Header.Set("Accept-Encoding", "gzip")
|
||||
c.Request.Header.Set("Proxy-Authorization", "Basic abc")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
@@ -108,6 +110,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyUnchang
|
||||
require.Empty(t, upstream.lastReq.Header.Get("Cookie"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding"))
|
||||
require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization"))
|
||||
require.Equal(t, "keep", upstream.lastReq.Header.Get("X-Test"))
|
||||
|
||||
// 3) required OAuth headers are present
|
||||
@@ -362,3 +366,58 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.GreaterOrEqual(t, *result.FirstTokenMs, 0)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||
// 首次写入成功,后续写入失败,模拟客户端中途断开。
|
||||
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 1}
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_text.delta","delta":"h"}`,
|
||||
"",
|
||||
`data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamSSE)),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_oauth_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.Stream)
|
||||
require.NotNil(t, result.FirstTokenMs)
|
||||
require.Equal(t, 11, result.Usage.InputTokens)
|
||||
require.Equal(t, 7, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user