diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 4ca32829..aa24c60a 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1587,8 +1587,9 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp var firstTokenRecorded bool scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 1024*1024) + scanBuf := getSSEScannerBuf64K() + defer putSSEScannerBuf64K(scanBuf) + scanner.Buffer(scanBuf[:0], 1024*1024) for scanner.Scan() { line := scanner.Bytes() @@ -2120,7 +2121,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2141,7 +2143,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2152,7 +2155,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2277,7 +2280,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + 
scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2305,7 +2309,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2316,7 +2321,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2728,7 +2733,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) var firstTokenMs *int var last map[string]any @@ -2754,7 +2760,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2765,7 +2772,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2908,7 +2915,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - 
scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -2940,7 +2948,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2951,7 +2960,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 32a591ef..ab8448eb 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -7,7 +7,9 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" @@ -190,3 +192,37 @@ func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) { got := antigravityMaxRetriesForModel("gemini-2.5-flash", true) require.Equal(t, 5, got) } + +func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: 
{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n")) + _, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n")) + }() + + svc := &AntigravityGatewayService{} + start := time.Now().Add(-10 * time.Millisecond) + usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start) + _ = pr.Close() + + require.NotNil(t, usage) + require.Equal(t, 1, usage.InputTokens) + // The second event overrides output_tokens + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) + require.Equal(t, 4, usage.CacheCreationInputTokens) + + if firstTokenMs == nil { + t.Fatalf("expected firstTokenMs to be set") + } + // Make sure the stream was passed through to the response writer + require.True(t, strings.Contains(writer.Body.String(), "data:")) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index bbfb1723..66325404 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4031,7 +4031,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -4050,7 +4051,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -4061,7 +4063,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) diff 
--git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go new file mode 100644 index 00000000..48667f58 --- /dev/null +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + + svc := &GatewayService{ + cfg: cfg, + rateLimitService: &RateLimitService{}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // Minimal SSE events to trigger parseSSEUsage: message_start carries input_tokens, message_delta carries output_tokens, then [DONE] + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", nil, false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 3, result.usage.InputTokens) + require.Equal(t, 7, result.usage.OutputTokens) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 564ffa4d..69c8aa9f 100644 --- a/backend/internal/service/openai_gateway_service.go 
+++ b/backend/internal/service/openai_gateway_service.go @@ -1209,7 +1209,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -1228,7 +1229,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -1239,7 +1241,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ae69a986..5d4355fd 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) type stubOpenAIAccountRepo struct { @@ -1066,6 +1067,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { } } +func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + require.Equal(t, 2, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) +} + func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ diff --git a/backend/internal/service/sse_scanner_buffer_pool.go b/backend/internal/service/sse_scanner_buffer_pool.go new file mode 100644 index 00000000..7475547f --- /dev/null +++ b/backend/internal/service/sse_scanner_buffer_pool.go @@ -0,0 +1,24 @@ +package service + +import "sync" + +const sseScannerBuf64KSize = 64 * 1024 + +type sseScannerBuf64K [sseScannerBuf64KSize]byte + +var sseScannerBuf64KPool = sync.Pool{ + New: func() any { + return new(sseScannerBuf64K) + }, +} + +func getSSEScannerBuf64K() *sseScannerBuf64K { + return sseScannerBuf64KPool.Get().(*sseScannerBuf64K) +} + +func putSSEScannerBuf64K(buf *sseScannerBuf64K) { + if buf == nil { + return + } + sseScannerBuf64KPool.Put(buf) +} diff --git a/backend/internal/service/sse_scanner_buffer_pool_test.go b/backend/internal/service/sse_scanner_buffer_pool_test.go new file mode 100644 index 00000000..09b8ad21 --- /dev/null +++ b/backend/internal/service/sse_scanner_buffer_pool_test.go @@ -0,0 +1,19 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func 
TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) { + buf := getSSEScannerBuf64K() + require.NotNil(t, buf) + require.Equal(t, sseScannerBuf64KSize, len(buf[:])) + + buf[0] = 1 + putSSEScannerBuf64K(buf) + + // Passing nil is allowed; make sure it does not panic + putSSEScannerBuf64K(nil) +}