From 10699eeb34e0e9b602a50d43a1be434d4751d5af Mon Sep 17 00:00:00 2001 From: erio Date: Thu, 16 Apr 2026 01:53:22 +0800 Subject: [PATCH] refactor: extract ReadUpstreamResponseBody to deduplicate upstream response read + too-large error handling Consolidates 9 call sites of resolveUpstreamResponseReadLimit + readUpstreamResponseBodyLimited + ErrUpstreamResponseBodyTooLarge error handling into a single ReadUpstreamResponseBody function with TooLargeWriter callback for API-format-specific error responses (Anthropic, OpenAI, countTokens). --- backend/internal/service/gateway_service.go | 74 +++++-------------- .../service/gemini_messages_compat_service.go | 12 +-- .../service/openai_gateway_service.go | 24 +----- .../service/upstream_response_limit.go | 43 +++++++++++ .../service/upstream_response_limit_test.go | 43 +++++++++++ 5 files changed, 107 insertions(+), 89 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index c65e828a..4b4fc0bf 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5120,19 +5120,8 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -5498,19 +5487,8 @@ func (s *GatewayService) handleBedrockNonStreamingResponse( c *gin.Context, account *Account, ) (*ClaudeUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -7175,19 +7153,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -8300,16 +8267,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -8323,15 +8289,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if retryErr == nil { resp = retryResp - respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } } @@ -8426,16 +8389,15 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex return fmt.Errorf("upstream request failed: %w", err) } - maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 5a9490f3..7a24071b 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2424,18 +2424,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ef97daad..064191bd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -3010,18 +3010,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( resp *http.Response, c *gin.Context, ) (*OpenAIUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -3919,18 +3909,8 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index aecf69a3..a0444d52 100644 --- a/backend/internal/service/upstream_response_limit.go +++ b/backend/internal/service/upstream_response_limit.go @@ -4,8 +4,10 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" ) var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") @@ -36,3 +38,44 @@ func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, } return body, nil } + +// TooLargeWriter 在响应超限时向客户端写格式化的错误响应。 +type TooLargeWriter func(c *gin.Context) + +// ReadUpstreamResponseBody 读取上游非流式响应体。 +// 超限时自动记录 ops error 并调用 onTooLarge 向客户端写错误。 +func ReadUpstreamResponseBody(reader io.Reader, cfg *config.Config, c *gin.Context, onTooLarge TooLargeWriter) ([]byte, error) { + maxBytes := resolveUpstreamResponseReadLimit(cfg) + body, err := readUpstreamResponseBodyLimited(reader, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + if onTooLarge != nil { + onTooLarge(c) + } + } + return nil, err + } + return body, nil +} + +// anthropicTooLargeError 以 Anthropic Messages API 格式写入超限错误。 +func anthropicTooLargeError(c *gin.Context) { + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) +} + +// openAITooLargeError 以 OpenAI / Gemini 格式写入超限错误。 +func openAITooLargeError(c *gin.Context) { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) +} diff --git a/backend/internal/service/upstream_response_limit_test.go b/backend/internal/service/upstream_response_limit_test.go index b9e5cc6d..09283189 100644 --- a/backend/internal/service/upstream_response_limit_test.go +++ b/backend/internal/service/upstream_response_limit_test.go @@ -4,8 +4,10 @@ import ( "bytes" "errors" "testing" + "testing/iotest" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -35,3 +37,44 @@ func TestReadUpstreamResponseBodyLimited(t *testing.T) { require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) }) } + +func TestReadUpstreamResponseBody(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("ok")), nil, nil, nil) + require.NoError(t, err) + require.Equal(t, []byte("ok"), body) + }) + + t.Run("exceeds limit calls onTooLarge", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + called := false + onTooLarge := func(_ *gin.Context) { called = true } + + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, onTooLarge) + require.Nil(t, body) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + require.True(t, called) + }) + + t.Run("nil onTooLarge does not panic", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, nil) + require.Nil(t, body) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + }) + + t.Run("io error does not call onTooLarge", func(t *testing.T) { + called := false + onTooLarge := func(_ *gin.Context) { called = true } + + body, err := ReadUpstreamResponseBody(iotest.ErrReader(errors.New("disk failure")), nil, nil, onTooLarge) + require.Nil(t, body) + require.Error(t, err) + require.False(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + require.False(t, called) + }) +}