refactor: extract ReadUpstreamResponseBody to deduplicate upstream response read + too-large error handling
Consolidates 9 call sites of resolveUpstreamResponseReadLimit + readUpstreamResponseBodyLimited + ErrUpstreamResponseBodyTooLarge error handling into a single ReadUpstreamResponseBody function with TooLargeWriter callback for API-format-specific error responses (Anthropic, OpenAI, countTokens).
This commit is contained in:
@@ -5120,19 +5120,8 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
|
|||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5498,19 +5487,8 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
|
|||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
account *Account,
|
account *Account,
|
||||||
) (*ClaudeUsage, error) {
|
) (*ClaudeUsage, error) {
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -7175,19 +7153,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -8300,16 +8267,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
countTokensTooLarge := func(c *gin.Context) {
|
||||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
}
|
||||||
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -8323,15 +8289,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
resp = retryResp
|
resp = retryResp
|
||||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -8426,16 +8389,15 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
|||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
countTokensTooLarge := func(c *gin.Context) {
|
||||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
}
|
||||||
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2424,18 +2424,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
|||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||||
}
|
}
|
||||||
|
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3010,18 +3010,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
) (*OpenAIUsage, error) {
|
) (*OpenAIUsage, error) {
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3919,18 +3909,8 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
|
||||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{
|
|
||||||
"error": gin.H{
|
|
||||||
"type": "upstream_error",
|
|
||||||
"message": "Upstream response too large",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||||
@@ -36,3 +38,44 @@ func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte,
|
|||||||
}
|
}
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TooLargeWriter 在响应超限时向客户端写格式化的错误响应。
|
||||||
|
type TooLargeWriter func(c *gin.Context)
|
||||||
|
|
||||||
|
// ReadUpstreamResponseBody 读取上游非流式响应体。
|
||||||
|
// 超限时自动记录 ops error 并调用 onTooLarge 向客户端写错误。
|
||||||
|
func ReadUpstreamResponseBody(reader io.Reader, cfg *config.Config, c *gin.Context, onTooLarge TooLargeWriter) ([]byte, error) {
|
||||||
|
maxBytes := resolveUpstreamResponseReadLimit(cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(reader, maxBytes)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
if onTooLarge != nil {
|
||||||
|
onTooLarge(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// anthropicTooLargeError 以 Anthropic Messages API 格式写入超限错误。
|
||||||
|
func anthropicTooLargeError(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAITooLargeError 以 OpenAI / Gemini 格式写入超限错误。
|
||||||
|
func openAITooLargeError(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,3 +37,44 @@ func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
|||||||
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadUpstreamResponseBody(t *testing.T) {
|
||||||
|
t.Run("within limit", func(t *testing.T) {
|
||||||
|
body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("ok")), nil, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("ok"), body)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exceeds limit calls onTooLarge", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||||
|
|
||||||
|
called := false
|
||||||
|
onTooLarge := func(_ *gin.Context) { called = true }
|
||||||
|
|
||||||
|
body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, onTooLarge)
|
||||||
|
require.Nil(t, body)
|
||||||
|
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
|
require.True(t, called)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil onTooLarge does not panic", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||||
|
|
||||||
|
body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, nil)
|
||||||
|
require.Nil(t, body)
|
||||||
|
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("io error does not call onTooLarge", func(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
onTooLarge := func(_ *gin.Context) { called = true }
|
||||||
|
|
||||||
|
body, err := ReadUpstreamResponseBody(iotest.ErrReader(errors.New("disk failure")), nil, nil, onTooLarge)
|
||||||
|
require.Nil(t, body)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.False(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
|
require.False(t, called)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user