fix(gateway): 恢复 Anthropic 透传流数据间隔超时保护并补充回归测试
This commit is contained in:
@@ -352,7 +352,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
|||||||
}, "\n"))),
|
}, "\n"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now())
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.NotNil(t, result.usage)
|
require.NotNil(t, result.usage)
|
||||||
@@ -602,12 +602,117 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testin
|
|||||||
Body: io.NopCloser(strings.NewReader(longLine)),
|
Body: io.NopCloser(strings.NewReader(longLine)),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now())
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.ErrorIs(t, err, bufio.ErrTooLong)
|
require.ErrorIs(t, err, bufio.ErrTooLong)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 1,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: pr,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
_ = pw.Close()
|
||||||
|
_ = pr.Close()
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream data interval timeout")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: &streamReadCloser{
|
||||||
|
err: io.ErrUnexpectedEOF,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "stream read error")
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.False(t, result.clientDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 1,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
Body: pr,
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
_, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n"))
|
||||||
|
// 保持上游连接静默,触发数据间隔超时分支。
|
||||||
|
time.Sleep(1500 * time.Millisecond)
|
||||||
|
_ = pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
|
_ = pr.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.True(t, result.clientDisconnect)
|
||||||
|
require.Equal(t, 9, result.usage.InputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -630,7 +735,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now())
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
@@ -660,7 +765,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now())
|
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.clientDisconnect)
|
require.True(t, result.clientDisconnect)
|
||||||
|
|||||||
@@ -3679,7 +3679,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
|||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if reqStream {
|
if reqStream {
|
||||||
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime)
|
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -3764,6 +3764,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
account *Account,
|
account *Account,
|
||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
|
model string,
|
||||||
) (*streamingResult, error) {
|
) (*streamingResult, error) {
|
||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
@@ -3804,10 +3805,80 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
}
|
}
|
||||||
scanBuf := getSSEScannerBuf64K()
|
scanBuf := getSSEScannerBuf64K()
|
||||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||||
defer putSSEScannerBuf64K(scanBuf)
|
|
||||||
|
|
||||||
|
type scanEvent struct {
|
||||||
|
line string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
events := make(chan scanEvent, 16)
|
||||||
|
done := make(chan struct{})
|
||||||
|
sendEvent := func(ev scanEvent) bool {
|
||||||
|
select {
|
||||||
|
case events <- ev:
|
||||||
|
return true
|
||||||
|
case <-done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var lastReadAt int64
|
||||||
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
|
go func(scanBuf *sseScannerBuf64K) {
|
||||||
|
defer putSSEScannerBuf64K(scanBuf)
|
||||||
|
defer close(events)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||||
|
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
_ = sendEvent(scanEvent{err: err})
|
||||||
|
}
|
||||||
|
}(scanBuf)
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
streamInterval := time.Duration(0)
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||||
|
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||||
|
}
|
||||||
|
var intervalTicker *time.Ticker
|
||||||
|
if streamInterval > 0 {
|
||||||
|
intervalTicker = time.NewTicker(streamInterval)
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
}
|
||||||
|
var intervalCh <-chan time.Time
|
||||||
|
if intervalTicker != nil {
|
||||||
|
intervalCh = intervalTicker.C
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev, ok := <-events:
|
||||||
|
if !ok {
|
||||||
|
if !clientDisconnected {
|
||||||
|
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||||
|
}
|
||||||
|
if ev.err != nil {
|
||||||
|
if clientDisconnected {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
||||||
|
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
|
}
|
||||||
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||||
|
}
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
line := ev.line
|
||||||
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
||||||
trimmed := strings.TrimSpace(data)
|
trimmed := strings.TrimSpace(data)
|
||||||
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
||||||
@@ -3829,30 +3900,23 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if !clientDisconnected {
|
|
||||||
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
case <-intervalCh:
|
||||||
|
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||||
|
if time.Since(lastRead) < streamInterval {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if clientDisconnected {
|
if clientDisconnected {
|
||||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||||
}
|
}
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
if s.rateLimitService != nil {
|
||||||
account.ID, resp.Header.Get("x-request-id"), err, ctx.Err())
|
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
|
||||||
}
|
}
|
||||||
if errors.Is(err, bufio.ErrTooLong) {
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
|
||||||
}
|
}
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractAnthropicSSEDataLine(line string) (string, bool) {
|
func extractAnthropicSSEDataLine(line string) (string, bool) {
|
||||||
|
|||||||
Reference in New Issue
Block a user