diff --git a/README_CN.md b/README_CN.md index bec7fe86..9dd69226 100644 --- a/README_CN.md +++ b/README_CN.md @@ -404,6 +404,14 @@ gateway: - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `turnstile.required` 在 release 模式强制启用 Turnstile +**网关防御纵深建议(重点)** + +- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。 +- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。 +- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。 +- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。 +- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。 + **⚠️ 安全警告:HTTP URL 配置** 当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置: diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 861351de..b9f31ba9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -308,6 +308,12 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` + // 代理探测响应体读取上限(字节) + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` + // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 @@ -1059,6 +1065,9 @@ func setDefaults() { viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) viper.SetDefault("gateway.sora_request_timeout_seconds", 180) @@ -1465,6 +1474,12 @@ func (c *Config) Validate() error { if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") + } + if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") + } if c.Gateway.SoraMaxBodySize < 0 { return fmt.Errorf("gateway.sora_max_body_size must be non-negative") } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 598eb4b3..ca5ee9d7 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go new file mode 100644 index 00000000..4fce5ec1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 470eab45..f5db385b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) continue } - // Error response already handled in Forward, just log - reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("openai.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 65296da4..1ca52c2d 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { assert.Equal(t, "test error", errorObj["message"]) } +func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + // TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 func TestOpenAIHandler_GjsonExtraction(t *testing.T) { tests := []struct { diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 6ab2ff72..3f05ac41 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { return normalizeIP(c.ClientIP()) } +// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。 +// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。 +// 适用于 ACL / 风控等安全敏感场景。 +func GetTrustedClientIP(c *gin.Context) string { + if c == nil { + return "" + } + return normalizeIP(c.ClientIP()) +} + // normalizeIP 规范化 IP 地址,去除端口号和空格。 func normalizeIP(ip string) string { ip = strings.TrimSpace(ip) diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go index c3c90c74..3839403c 100644 --- a/backend/internal/pkg/ip/ip_test.go +++ b/backend/internal/pkg/ip/ip_test.go @@ -3,8 +3,10 @@ package ip import ( + "net/http/httptest" "testing" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) { }) } } + +func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + require.NoError(t, r.SetTrustedProxies(nil)) + + r.GET("/t", func(c *gin.Context) { + c.String(200, GetTrustedClientIP(c)) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + r.ServeHTTP(w, req) + + require.Equal(t, 200, w.Code) + require.Equal(t, "9.9.9.9", w.Body.String()) +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929c..54de2897 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecure := false allowPrivate := false validateResolvedIP := true + maxResponseBytes := defaultProxyProbeResponseMaxBytes if cfg != nil { insecure = cfg.Security.ProxyProbe.InsecureSkipVerify allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts validateResolvedIP = cfg.Security.URLAllowlist.Enabled + if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 { + maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes + } } if insecure { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") @@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, + maxResponseBytes: maxResponseBytes, } } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) // probeURLs 按优先级排列的探测 URL 列表 @@ -52,6 +58,7 @@ type proxyProbeService struct { insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool + maxResponseBytes int64 } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + maxResponseBytes := s.maxResponseBytes + if maxResponseBytes <= 0 { + maxResponseBytes = defaultProxyProbeResponseMaxBytes + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) } + if int64(len(body)) > maxResponseBytes { + return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes) + } switch parser { case "ip-api": diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d2d8ed40..a8034e98 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -51,6 +51,9 @@ func ProvideRouter( if err := r.SetTrustedProxies(nil); err != nil { log.Printf("Failed to disable trusted proxies: %v", err) } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 8e03f785..7aad1699 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查 IP 限制(白名单/黑名单) // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { - clientIP := ip.GetClientIP(c) + clientIP := ip.GetTrustedClientIP(c) allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 3e33c7e3..f3a6f076 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "ACCESS_DENIED") +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 26d79605..c168820c 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -24,10 +24,19 @@ func RegisterAuthRoutes( // 公开接口 auth := v1.Group("/auth") { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/login/2fa", h.Auth.Login2FA) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) + auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Register) + auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login) + auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login2FA) + auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.SendVerifyCode) // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_integration_test.go b/backend/internal/server/routes/auth_rate_limit_integration_test.go new file mode 100644 index 00000000..8a0ef860 --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_integration_test.go @@ -0,0 +1,111 @@ +//go:build integration + +package routes + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const authRouteRedisImageTag = "redis:8.4-alpine" + +func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) { + ctx := context.Background() + rdb := startAuthRouteRedis(t, ctx) + + router := newAuthRoutesTestRouter(rdb) + const path = "/api/v1/auth/register" + + for i := 1; i <= 6; i++ { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "198.51.100.10:23456" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if i <= 5 { + require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i) + continue + } + require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流") + require.Contains(t, w.Body.String(), "rate limit exceeded") + } +} + +func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + ensureAuthRouteDockerAvailable(t) + + redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + require.NoError(t, rdb.Ping(ctx).Err()) + t.Cleanup(func() { + _ = rdb.Close() + }) + return rdb +} + +func ensureAuthRouteDockerAvailable(t *testing.T) { + t.Helper() + if authRouteDockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过认证限流集成测试") +} + +func authRouteDockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func authRouteUserHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go new file mode 100644 index 00000000..5ce8497c --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -0,0 +1,67 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + v1 := router.Group("/api/v1") + + RegisterAuthRoutes( + v1, + &handler.Handlers{ + Auth: &handler.AuthHandler{}, + Setting: &handler.SettingHandler{}, + }, + servermiddleware.JWTAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + redisClient, + ) + + return router +} + +func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + router := newAuthRoutesTestRouter(rdb) + paths := []string{ + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/auth/login/2fa", + "/api/v1/auth/send-verify-code", + } + + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.10:12345" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path) + require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 83cde19e..0502d352 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) - if account.Platform == PlatformGemini && resp.StatusCode < 400 { + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) @@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - respBody, err := io.ReadAll(resp.Body) + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp - respBody, err = io.ReadAll(resp.Body) + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2fe55137..8670f99a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct { } func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { - // Log response headers for debugging - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") - respBody, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } - var parsed map[string]any if isOAuth { unwrappedBody, uwErr := unwrapGeminiResponse(respBody) if uwErr == nil { respBody = unwrappedBody } - _ = json.Unmarshal(respBody, &parsed) - } else { - _ = json.Unmarshal(respBody, &parsed) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { - // Log response headers for debugging - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") if s.cfg != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index c5888d88..7560f480 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -3,10 +3,15 @@ package service import ( "encoding/json" "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + svc := &GeminiMessagesCompatService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + GeminiDebugResponseHeaders: false, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-RateLimit-Limit": []string{"60"}, + }, + Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)), + } + + usage, err := svc.handleNativeNonStreamingResponse(c, resp, false) + require.NoError(t, err) + require.NotNil(t, usage) + require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 157506a6..ac93aa0c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( resp *http.Response, c *gin.Context, ) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { return normalized, changed, nil } +func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) if reasoningEffort == "" { @@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { return "" } } -func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { - model := strings.ToLower(strings.TrimSpace(reqModel)) - if !strings.Contains(model, "codex") { - return "" - } - - instructions := gjson.GetBytes(body, "instructions") - if !instructions.Exists() { - return "instructions_missing" - } - if instructions.Type != gjson.String { - return "instructions_not_string" - } - if strings.TrimSpace(instructions.String()) == "" { - return "instructions_empty" - } - return "" -} - diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go new file mode 100644 index 00000000..aecf69a3 --- /dev/null +++ b/backend/internal/service/upstream_response_limit.go @@ -0,0 +1,38 @@ +package service + +import ( + "errors" + "fmt" + "io" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") + +const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 + +func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { + if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { + return cfg.Gateway.UpstreamResponseReadMaxBytes + } + return defaultUpstreamResponseReadMaxBytes +} + +func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) { + if reader == nil { + return nil, errors.New("response body is nil") + } + if maxBytes <= 0 { + maxBytes = defaultUpstreamResponseReadMaxBytes + } + + body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1)) + if err != nil { + return nil, err + } + if int64(len(body)) > maxBytes { + return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes) + } + return body, nil +} diff --git a/backend/internal/service/upstream_response_limit_test.go b/backend/internal/service/upstream_response_limit_test.go new file mode 100644 index 00000000..b9e5cc6d --- /dev/null +++ b/backend/internal/service/upstream_response_limit_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "bytes" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestResolveUpstreamResponseReadLimit(t *testing.T) { + t.Run("use default when config missing", func(t *testing.T) { + require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil)) + }) + + t.Run("use configured value", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 1234 + require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg)) + }) +} + +func TestReadUpstreamResponseBodyLimited(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2) + require.NoError(t, err) + require.Equal(t, []byte("ok"), body) + }) + + t.Run("exceeds limit", func(t *testing.T) { + body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3) + require.Nil(t, body) + require.Error(t, err) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + }) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index f016fd49..9fd2d391 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -146,6 +146,15 @@ gateway: # Max request body size in bytes (default: 100MB) # 请求体最大字节数(默认 100MB) max_body_size: 104857600 + # Max bytes to read for non-stream upstream responses (default: 8MB) + # 非流式上游响应体读取上限(默认 8MB) + upstream_response_read_max_bytes: 8388608 + # Max bytes to read for proxy probe responses (default: 1MB) + # 代理探测响应体读取上限(默认 1MB) + proxy_probe_response_read_max_bytes: 1048576 + # Enable Gemini upstream response header debug logs (default: false) + # 是否开启 Gemini 上游响应头调试日志(默认 false) + gemini_debug_response_headers: false # Sora max request body size in bytes (0=use max_body_size) # Sora 请求体最大字节数(0=使用 max_body_size) sora_max_body_size: 268435456