From 7489da49cbf20b0bda80d598b5e913a596c6ef87 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 4 Jan 2026 20:19:07 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E6=B5=81=E5=BC=8F):=20=E4=BB=A5=E4=B8=8A?= =?UTF-8?q?=E6=B8=B8=E8=AF=BB=E5=8F=96=E5=88=A4=E5=AE=9A=E8=B6=85=E6=97=B6?= =?UTF-8?q?=E5=B9=B6=E8=B0=83=E5=A4=A7=E4=BA=8B=E4=BB=B6=E7=BC=93=E5=86=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 以读取时间戳判定流式间隔超时,避免下游阻塞误判 - antigravity 流式读取使用 MaxLineSize 配置 - 事件通道缓冲提升到 16 测试: go test ./... --- .../service/antigravity_gateway_service.go | 81 +++++++++---------- backend/internal/service/gateway_service.go | 40 ++++----- .../service/openai_gateway_service.go | 25 +++--- 3 files changed, 71 insertions(+), 75 deletions(-) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index e2719cf6..8f97598f 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -11,6 +11,7 @@ import ( "log" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -680,7 +681,11 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -689,7 +694,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context err error } // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 - events := make(chan scanEvent, 1) + events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { @@ -699,9 +704,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return false } } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -717,26 +725,14 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } - var intervalTimer *time.Timer + var intervalTicker *time.Ticker if streamInterval > 0 { - intervalTimer = time.NewTimer(streamInterval) - defer intervalTimer.Stop() + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() } var intervalCh <-chan time.Time - if intervalTimer != nil { - intervalCh = intervalTimer.C - } - resetInterval := func() { - if intervalTimer == nil { - return - } - if !intervalTimer.Stop() { - select { - case <-intervalTimer.C: - default: - } - } - intervalTimer.Reset(streamInterval) + if intervalTicker != nil { + intervalCh = intervalTicker.C } // 仅发送一次错误事件,避免多次写入导致协议混乱 @@ -758,7 +754,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err) + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -766,7 +762,6 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return nil, ev.err } - resetInterval() line := ev.line trimmed := strings.TrimRight(line, "\r\n") if strings.HasPrefix(trimmed, "data:") { @@ -814,6 +809,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context flusher.Flush() case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } log.Printf("Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") @@ -959,7 +958,11 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context var firstTokenMs *int // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -979,7 +982,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context err error } // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 - events := make(chan scanEvent, 1) + events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { @@ -989,9 +992,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return false } } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -1006,26 +1012,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } - var intervalTimer *time.Timer + var intervalTicker *time.Ticker if streamInterval > 0 { - intervalTimer = time.NewTimer(streamInterval) - defer intervalTimer.Stop() + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() } var intervalCh <-chan time.Time - if intervalTimer != nil { - intervalCh = intervalTimer.C - } - resetInterval := func() { - if intervalTimer == nil { - return - } - if !intervalTimer.Stop() { - select { - case <-intervalTimer.C: - default: - } - } - intervalTimer.Reset(streamInterval) + if intervalTicker != nil { + intervalCh = intervalTicker.C } // 仅发送一次错误事件,避免多次写入导致协议混乱 @@ -1053,7 +1047,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err) + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err } @@ -1061,7 +1055,6 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } - resetInterval() line := ev.line // 处理 SSE 行,转换为 Claude 格式 claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) @@ -1084,6 +1077,10 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } log.Printf("Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e5282101..47c136df 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -15,6 +15,7 @@ import ( "regexp" "sort" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -1460,7 +1461,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http err error } // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 - events := make(chan scanEvent, 1) + events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { @@ -1470,9 +1471,12 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return false } } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -1487,11 +1491,15 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } - // 仅监控上游数据间隔超时,避免上游挂起占用资源 - var intervalTimer *time.Timer + // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判 + var intervalTicker *time.Ticker if streamInterval > 0 { - intervalTimer = time.NewTimer(streamInterval) - defer intervalTimer.Stop() + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C } // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) @@ -1523,9 +1531,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line - if intervalTimer != nil { - resetTimer(intervalTimer, streamInterval) - } if line == "event: error" { return nil, errors.New("have error in stream") } @@ -1561,12 +1566,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http flusher.Flush() } - case <-func() <-chan time.Time { - if intervalTimer != nil { - return intervalTimer.C + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue } - return nil - }(): log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") @@ -1576,16 +1580,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } -func resetTimer(timer *time.Timer, interval time.Duration) { - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(interval) -} - // replaceModelInSSELine 替换SSE数据行中的model字段 func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { if !sseDataRe.MatchString(line) { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6589df2a..ab5c3d89 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -16,6 +16,7 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -786,7 +787,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp err error } // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 - events := make(chan scanEvent, 1) + events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { @@ -796,9 +797,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return false } } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } @@ -813,15 +817,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } - // 仅监控上游数据间隔超时,不被下游 keepalive 影响 - var intervalTimer *time.Timer + // 仅监控上游数据间隔超时,不被下游写入阻塞影响 + var intervalTicker *time.Ticker if streamInterval > 0 { - intervalTimer = time.NewTimer(streamInterval) - defer intervalTimer.Stop() + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() } var intervalCh <-chan time.Time - if intervalTimer != nil { - intervalCh = intervalTimer.C + if intervalTicker != nil { + intervalCh = intervalTicker.C } keepaliveInterval := time.Duration(0) @@ -872,9 +876,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp line := ev.line lastDataAt = time.Now() - if intervalTimer != nil { - resetTimer(intervalTimer, streamInterval) - } // Extract data from SSE line (supports both "data: " and "data:" formats) if openaiSSEDataRe.MatchString(line) { @@ -908,6 +909,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) sendErrorEvent("stream_timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")