fix(gateway): 默认过滤OpenAI透传超时头并补充断流告警

This commit is contained in:
yangjianbo
2026-02-12 14:16:18 +08:00
parent 114e172603
commit ed2eba9028
4 changed files with 278 additions and 1 deletions

View File

@@ -1020,6 +1020,23 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqModel,
reqStream,
)
if reqStream && c != nil && c.Request != nil {
if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
log.Printf(
"[WARN] [OpenAI passthrough] 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流: account=%d headers=%s",
account.ID,
strings.Join(timeoutHeaders, ", "),
)
} else {
log.Printf(
"[WARN] [OpenAI passthrough] 检测到超时相关请求头,将按配置过滤以降低断流风险: account=%d headers=%s",
account.ID,
strings.Join(timeoutHeaders, ", "),
)
}
}
}
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
@@ -1135,12 +1152,16 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
}
// 透传客户端请求头(尽可能原样),并做安全剔除。
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
if c != nil && c.Request != nil {
for key, values := range c.Request.Header {
lower := strings.ToLower(key)
if isOpenAIPassthroughBlockedRequestHeader(lower) {
continue
}
if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) {
continue
}
for _, v := range values {
req.Header.Add(key, v)
}
@@ -1233,6 +1254,38 @@ func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
}
}
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
switch lowerKey {
case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout":
return true
default:
return false
}
}
func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool {
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders
}
func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
if h == nil {
return nil
}
var matched []string
for key, values := range h {
lowerKey := strings.ToLower(strings.TrimSpace(key))
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
entry := lowerKey
if len(values) > 0 {
entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|"))
}
matched = append(matched, entry)
}
}
sort.Strings(matched)
return matched
}
type openaiStreamingResultPassthrough struct {
usage *OpenAIUsage
firstTokenMs *int
@@ -1265,6 +1318,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
usage := &OpenAIUsage{}
var firstTokenMs *int
clientDisconnected := false
sawDone := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -1278,7 +1333,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
if data, ok := extractOpenAISSEDataLine(line); ok {
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
trimmedData := strings.TrimSpace(data)
if trimmedData == "[DONE]" {
sawDone = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -1300,14 +1359,34 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Printf(
"[WARN] [OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
err,
ctx.Err(),
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(err, bufio.ErrTooLong) {
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
log.Printf(
"[WARN] [OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
upstreamRequestID,
err,
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if !clientDisconnected && !sawDone && ctx.Err() == nil {
log.Printf(
"[WARN] [OpenAI passthrough] 上游流在未收到 [DONE] 时结束,疑似断流: account=%d request_id=%s",
account.ID,
upstreamRequestID,
)
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}