fix(gateway): 恢复 Anthropic 透传流数据间隔超时保护并补充回归测试

This commit is contained in:
yangjianbo
2026-02-21 16:54:44 +08:00
parent fdfc739b72
commit 1985be26b2
2 changed files with 215 additions and 46 deletions

View File

@@ -3679,7 +3679,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime)
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel)
if err != nil {
return nil, err
}
@@ -3764,6 +3764,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
if s.rateLimitService != nil {
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3804,55 +3805,118 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
}
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
line := scanner.Text()
if data, ok := extractAnthropicSSEDataLine(line); ok {
trimmed := strings.TrimSpace(data)
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsagePassthrough(data, usage)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
if !clientDisconnected {
if _, err := io.WriteString(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if line == "" {
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
flusher.Flush()
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
if !clientDisconnected {
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
flusher.Flush()
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
if err := scanner.Err(); err != nil {
if clientDisconnected {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
account.ID, resp.Header.Get("x-request-id"), err, ctx.Err())
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
line := ev.line
if data, ok := extractAnthropicSSEDataLine(line); ok {
trimmed := strings.TrimSpace(data)
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsagePassthrough(data, usage)
}
if !clientDisconnected {
if _, err := io.WriteString(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if line == "" {
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
func extractAnthropicSSEDataLine(line string) (string, bool) {