feat(backend): 提交后端审计修复与配套测试改动
This commit is contained in:
@@ -308,6 +308,12 @@ type GatewayConfig struct {
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||
// 代理探测响应体读取上限(字节)
|
||||
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
@@ -1059,6 +1065,9 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
|
||||
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", parsed["type"])
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||
// 适用于 ACL / 风控等安全敏感场景。
|
||||
func GetTrustedClientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
require.NoError(t, r.SetTrustedProxies(nil))
|
||||
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
c.String(200, GetTrustedClientIP(c))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||
}
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
maxResponseBytes int64
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxResponseBytes := s.maxResponseBytes
|
||||
if maxResponseBytes <= 0 {
|
||||
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBytes {
|
||||
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
|
||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
||||
if err := r.SetTrustedProxies(nil); err != nil {
|
||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||
}
|
||||
if cfg.Server.Mode == "release" {
|
||||
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
|
||||
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Register)
|
||||
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login)
|
||||
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build integration
|
||||
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startAuthRouteRedis(t, ctx)
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
const path = "/api/v1/auth/register"
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "198.51.100.10:23456"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if i <= 5 {
|
||||
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||
continue
|
||||
}
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureAuthRouteDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if authRouteDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||
}
|
||||
|
||||
func authRouteDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func authRouteUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
RegisterAuthRoutes(
|
||||
v1,
|
||||
&handler.Handlers{
|
||||
Auth: &handler.AuthHandler{},
|
||||
Setting: &handler.SettingHandler{},
|
||||
},
|
||||
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
paths := []string{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/login/2fa",
|
||||
"/api/v1/auth/send-verify-code",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "203.0.113.10:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||
}
|
||||
}
|
||||
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if isOAuth {
|
||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||
if uwErr == nil {
|
||||
respBody = unwrappedBody
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
} else {
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -3,10 +3,15 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
GeminiDebugResponseHeaders: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-RateLimit-Limit": []string{"60"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||
}
|
||||
|
||||
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
@@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
||||
return normalized, changed, nil
|
||||
}
|
||||
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
@@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||
|
||||
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||
|
||||
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||
}
|
||||
return defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, errors.New("response body is nil")
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||
t.Run("use default when config missing", func(t *testing.T) {
|
||||
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||
})
|
||||
|
||||
t.Run("use configured value", func(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||
t.Run("within limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ok"), body)
|
||||
})
|
||||
|
||||
t.Run("exceeds limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||
require.Nil(t, body)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user