From 1985be26b2397efc7b177ea1fae8a510b429d679 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 21 Feb 2026 16:54:44 +0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E6=81=A2=E5=A4=8D=20Anthropic?= =?UTF-8?q?=20=E9=80=8F=E4=BC=A0=E6=B5=81=E6=95=B0=E6=8D=AE=E9=97=B4?= =?UTF-8?q?=E9=9A=94=E8=B6=85=E6=97=B6=E4=BF=9D=E6=8A=A4=E5=B9=B6=E8=A1=A5?= =?UTF-8?q?=E5=85=85=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...teway_anthropic_apikey_passthrough_test.go | 113 ++++++++++++- backend/internal/service/gateway_service.go | 148 +++++++++++++----- 2 files changed, 215 insertions(+), 46 deletions(-) diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 28641ca3..5183891b 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -352,7 +352,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf }, "\n"))), } - result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now()) + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.usage) @@ -602,12 +602,117 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testin Body: io.NopCloser(strings.NewReader(longLine)), } - result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now()) + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219") require.Error(t, err) require.ErrorIs(t, err, bufio.ErrTooLong) require.NotNil(t, result) } +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pw.Close() + _ = pr.Close() + + require.Error(t, err) + require.Contains(t, err.Error(), "stream data interval timeout") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n")) + // 保持上游连接静默,触发数据间隔超时分支。 + time.Sleep(1500 * time.Millisecond) + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 9, result.usage.InputTokens) +} + func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -630,7 +735,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t }, } - result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now()) + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") require.NoError(t, err) require.NotNil(t, result) require.True(t, result.clientDisconnect) @@ -660,7 +765,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft }, } - result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now()) + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") require.NoError(t, err) require.NotNil(t, result) require.True(t, result.clientDisconnect) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e146ded0..f16f685f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3679,7 +3679,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var firstTokenMs *int var clientDisconnect bool if reqStream { - streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime) + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) if err != nil { return nil, err } @@ -3764,6 +3764,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( c *gin.Context, account *Account, startTime time.Time, + model string, ) (*streamingResult, error) { if s.rateLimitService != nil { s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3804,55 +3805,118 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( } scanBuf := getSSEScannerBuf64K() scanner.Buffer(scanBuf[:0], maxLineSize) - defer putSSEScannerBuf64K(scanBuf) - for scanner.Scan() { - line := scanner.Text() - if data, ok := extractAnthropicSSEDataLine(line); ok { - trimmed := strings.TrimSpace(data) - if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsagePassthrough(data, usage) + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false } - - if !clientDisconnected { - if _, err := io.WriteString(w, line); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - } else if _, err := io.WriteString(w, "\n"); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - } else if line == "" { - // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 - flusher.Flush() + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return } } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } - if !clientDisconnected { - // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 - flusher.Flush() + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C } - if err := scanner.Err(); err != nil { - if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil - } - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, resp.Header.Get("x-request-id"), err, ctx.Err()) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil - } - if errors.Is(err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) - } + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } + + if !clientDisconnected { + if _, err := io.WriteString(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } } func extractAnthropicSSEDataLine(line string) (string, bool) {