refactor: extract ReadUpstreamResponseBody to deduplicate upstream response read + too-large error handling
Consolidates 9 call sites of resolveUpstreamResponseReadLimit + readUpstreamResponseBodyLimited + ErrUpstreamResponseBodyTooLarge error handling into a single ReadUpstreamResponseBody function with TooLargeWriter callback for API-format-specific error responses (Anthropic, OpenAI, countTokens).
This commit is contained in:
@@ -5120,19 +5120,8 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
}
|
||||
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5498,19 +5487,8 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
) (*ClaudeUsage, error) {
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -7175,19 +7153,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8300,16 +8267,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
countTokensTooLarge := func(c *gin.Context) {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
}
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8323,15 +8289,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -8426,16 +8389,15 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
countTokensTooLarge := func(c *gin.Context) {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
}
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user