fix(gateway): 默认过滤OpenAI透传超时头并补充断流告警
This commit is contained in:
@@ -1020,6 +1020,23 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
reqModel,
|
||||
reqStream,
|
||||
)
|
||||
if reqStream && c != nil && c.Request != nil {
|
||||
if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 {
|
||||
if s.isOpenAIPassthroughTimeoutHeadersAllowed() {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流: account=%d headers=%s",
|
||||
account.ID,
|
||||
strings.Join(timeoutHeaders, ", "),
|
||||
)
|
||||
} else {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 检测到超时相关请求头,将按配置过滤以降低断流风险: account=%d headers=%s",
|
||||
account.ID,
|
||||
strings.Join(timeoutHeaders, ", "),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
@@ -1135,12 +1152,16 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
||||
}
|
||||
|
||||
// 透传客户端请求头(尽可能原样),并做安全剔除。
|
||||
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lower := strings.ToLower(key)
|
||||
if isOpenAIPassthroughBlockedRequestHeader(lower) {
|
||||
continue
|
||||
}
|
||||
if !allowTimeoutHeaders && isOpenAIPassthroughTimeoutHeader(lower) {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
@@ -1233,6 +1254,38 @@ func isOpenAIPassthroughBlockedRequestHeader(lowerKey string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
|
||||
switch lowerKey {
|
||||
case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool {
|
||||
return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders
|
||||
}
|
||||
|
||||
func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
var matched []string
|
||||
for key, values := range h {
|
||||
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
||||
if isOpenAIPassthroughTimeoutHeader(lowerKey) {
|
||||
entry := lowerKey
|
||||
if len(values) > 0 {
|
||||
entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|"))
|
||||
}
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
sort.Strings(matched)
|
||||
return matched
|
||||
}
|
||||
|
||||
type openaiStreamingResultPassthrough struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
@@ -1265,6 +1318,8 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -1278,7 +1333,11 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||
trimmedData := strings.TrimSpace(data)
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
@@ -1300,14 +1359,34 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
ctx.Err(),
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
log.Printf("[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
log.Printf(
|
||||
"[WARN] [OpenAI passthrough] 上游流在未收到 [DONE] 时结束,疑似断流: account=%d request_id=%s",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
)
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user