diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 53fc1278..eb763bbe 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, configConfig) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache) @@ -136,9 +136,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index f073ee1a..7f3cecd0 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -45,6 +45,7 @@ type Config struct { RateLimit RateLimitConfig `mapstructure:"rate_limit"` Pricing PricingConfig `mapstructure:"pricing"` Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" @@ -155,6 +156,11 @@ type CircuitBreakerConfig struct { HalfOpenRequests int `mapstructure:"half_open_requests"` } +type ConcurrencyConfig struct { + // PingInterval: 并发等待期间的 SSE ping 间隔(秒) + PingInterval int `mapstructure:"ping_interval"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -187,6 +193,13 @@ type GatewayConfig struct { // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用 + StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` + // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 + StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` + // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) + MaxLineSize int `mapstructure:"max_line_size"` + // 是否记录上游错误响应体摘要(避免输出请求内容) LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` // 上游错误响应体记录最大字节数(超过会截断) @@ -475,7 +488,7 @@ func setDefaults() { viper.SetDefault("timezone", "Asia/Shanghai") // Gateway - viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) viper.SetDefault("gateway.inject_beta_for_apikey", false) @@ -486,16 +499,20 @@ func setDefaults() { viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) - viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒) + viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) - viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) + viper.SetDefault("gateway.stream_data_interval_timeout", 180) + viper.SetDefault("gateway.stream_keepalive_interval", 10) + viper.SetDefault("gateway.max_line_size", 10*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second) viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) + viper.SetDefault("concurrency.ping_interval", 10) // TokenRefresh viper.SetDefault("token_refresh.enabled", true) @@ -604,6 +621,9 @@ func (c *Config) Validate() error { if c.Gateway.IdleConnTimeoutSeconds <= 0 { return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } + if c.Gateway.IdleConnTimeoutSeconds > 180 { + log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) + } if c.Gateway.MaxUpstreamClients <= 0 { return fmt.Errorf("gateway.max_upstream_clients must be positive") } @@ -613,6 +633,26 @@ func (c *Config) Validate() error { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") } + if c.Gateway.StreamDataIntervalTimeout < 0 { + return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative") + } + if c.Gateway.StreamDataIntervalTimeout != 0 && + (c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) { + return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds") + } + if c.Gateway.StreamKeepaliveInterval < 0 { + return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative") + } + if c.Gateway.StreamKeepaliveInterval != 0 && + (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { + return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") + } + if c.Gateway.MaxLineSize < 0 { + return fmt.Errorf("gateway.max_line_size must be non-negative") + } + if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { + return fmt.Errorf("gateway.max_line_size must be at least 1MB") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } @@ -628,6 +668,9 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.SlotCleanupInterval < 0 { return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") } + if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 { + return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds") + } return nil } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 21a3af56..0be81ae2 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -39,14 +40,19 @@ func NewGatewayHandler( userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + cfg *config.Config, ) *GatewayHandler { + pingInterval := time.Duration(0) + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + } return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude), + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), } } @@ -122,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -222,6 +230,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult @@ -346,6 +357,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 转发请求 - 根据账号平台分流 var result *service.ForwardResult diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 9d2e4a9d..2eb3ac72 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -26,8 +27,8 @@ import ( const ( // maxConcurrencyWait 等待并发槽位的最大时间 maxConcurrencyWait = 30 * time.Second - // pingInterval 流式响应等待时发送 ping 的间隔 - pingInterval = 15 * time.Second + // defaultPingInterval 流式响应等待时发送 ping 的默认间隔 + defaultPingInterval = 10 * time.Second // initialBackoff 初始退避时间 initialBackoff = 100 * time.Millisecond // backoffMultiplier 退避时间乘数(指数退避) @@ -44,6 +45,8 @@ const ( SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n" // SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec) SSEPingFormatNone SSEPingFormat = "" + // SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients + SSEPingFormatComment SSEPingFormat = ":\n\n" ) // ConcurrencyError represents a concurrency limit error with context @@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string { type ConcurrencyHelper struct { concurrencyService *service.ConcurrencyService pingFormat SSEPingFormat + pingInterval time.Duration } // NewConcurrencyHelper creates a new ConcurrencyHelper -func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper { +func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } return &ConcurrencyHelper{ concurrencyService: concurrencyService, pingFormat: pingFormat, + pingInterval: pingInterval, } } +// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. +// 用于避免客户端断开或上游超时导致的并发槽位泄漏。 +func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { + if releaseFunc == nil { + return nil + } + var once sync.Once + wrapped := func() { + once.Do(releaseFunc) + } + go func() { + <-ctx.Done() + wrapped() + }() + return wrapped +} + // IncrementWaitCount increments the wait count for a user func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait) @@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // Only create ping ticker if ping is needed var pingCh <-chan time.Time if needPing { - pingTicker := time.NewTicker(pingInterval) + pingTicker := time.NewTicker(h.pingInterval) defer pingTicker.Stop() pingCh = pingTicker.C } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index e6a59473..df6b98bd 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { subscription, _ := middleware.GetSubscriptionFromContext(c) // For Gemini native API, do not send Claude-style ping frames. - geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) + geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) // 0) wait queue check maxWait := service.CalculateMaxWait(authSubject.Concurrency) @@ -185,6 +185,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusTooManyRequests, err.Error()) return } + // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -261,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // 5) forward (根据平台分流) var result *service.ForwardResult diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3f110a3e..0a7602c6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + cfg *config.Config, ) *OpenAIGatewayHandler { + pingInterval := time.Duration(0) + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + } return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone), + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), } } @@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleConcurrencyError(c, err, "user", streamStarted) return } + // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -202,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { log.Printf("Bind sticky session failed: %v", err) } } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease) // Forward request result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index f76e375d..3cd5a592 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -34,7 +34,7 @@ import ( const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 - defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 + defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) ) // Options 定义共享 HTTP 客户端的构建参数 diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 39d98839..3c84ab1d 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -30,9 +30,9 @@ const ( // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) // 达到上限后新请求会等待,而非无限创建连接 defaultMaxConnsPerHost = 240 - // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟) - // 超时后连接会被关闭,释放系统资源 - defaultIdleConnTimeout = 300 * time.Second + // defaultIdleConnTimeout: 默认空闲连接超时时间(90秒) + // 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时) + defaultIdleConnTimeout = 90 * time.Second // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) // LLM 请求可能排队较久,需要较长超时 defaultResponseHeaderTimeout = 300 * time.Second diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index e4843f1b..e2719cf6 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -64,6 +65,7 @@ type AntigravityGatewayService struct { tokenProvider *AntigravityTokenProvider rateLimitService *RateLimitService httpUpstream HTTPUpstream + cfg *config.Config } func NewAntigravityGatewayService( @@ -72,12 +74,14 @@ func NewAntigravityGatewayService( tokenProvider *AntigravityTokenProvider, rateLimitService *RateLimitService, httpUpstream HTTPUpstream, + cfg *config.Config, ) *AntigravityGatewayService { return &AntigravityGatewayService{ accountRepo: accountRepo, tokenProvider: tokenProvider, rateLimitService: rateLimitService, httpUpstream: httpUpstream, + cfg: cfg, } } @@ -674,57 +678,147 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return nil, errors.New("streaming not supported") } - reader := bufio.NewReader(resp.Body) + // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 + events := make(chan scanEvent, 1) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + // 上游数据间隔超时保护(防止上游挂起长期占用连接) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTimer *time.Timer + if streamInterval > 0 { + intervalTimer = time.NewTimer(streamInterval) + defer intervalTimer.Stop() + } + var intervalCh <-chan time.Time + if intervalTimer != nil { + intervalCh = intervalTimer.C + } + resetInterval := func() { + if intervalTimer == nil { + return + } + if !intervalTimer.Stop() { + select { + case <-intervalTimer.C: + default: + } + } + intervalTimer.Reset(streamInterval) + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱 + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + for { - line, err := reader.ReadString('\n') - if len(line) > 0 { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return nil, ev.err + } + + resetInterval() + line := ev.line trimmed := strings.TrimRight(line, "\r\n") if strings.HasPrefix(trimmed, "data:") { payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - _, _ = io.WriteString(c.Writer, line) - flusher.Flush() - } else { - // 解包 v1internal 响应 - inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) - if parseErr == nil && inner != nil { - payload = string(inner) - } - - // 解析 usage - var parsed map[string]any - if json.Unmarshal(inner, &parsed) == nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } - } - - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload) + if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } flusher.Flush() + continue } - } else { - _, _ = io.WriteString(c.Writer, line) - flusher.Flush() - } - } - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, err + // 解包 v1internal 响应 + inner, parseErr := s.unwrapV1InternalResponse([]byte(payload)) + if parseErr == nil && inner != nil { + payload = string(inner) + } + + // 解析 usage + var parsed map[string]any + if json.Unmarshal(inner, &parsed) == nil { + if u := extractGeminiUsage(parsed); u != nil { + usage = u + } + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + continue + } + + if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + case <-intervalCh: + log.Printf("Stream data interval timeout (antigravity)") + sendErrorEvent("stream_timeout") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil } func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) { @@ -863,7 +957,9 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context processor := antigravity.NewStreamingProcessor(originalModel) var firstTokenMs *int - reader := bufio.NewReader(resp.Body) + // 使用 Scanner 并限制单行大小,避免 ReadString 无上限导致 OOM + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), defaultMaxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -878,13 +974,95 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } } - for { - line, err := reader.ReadString('\n') - if err != nil && !errors.Is(err, io.EOF) { - return nil, fmt.Errorf("stream read error: %w", err) + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响超时处理 + events := make(chan scanEvent, 1) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) - if len(line) > 0 { + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTimer *time.Timer + if streamInterval > 0 { + intervalTimer = time.NewTimer(streamInterval) + defer intervalTimer.Stop() + } + var intervalCh <-chan time.Time + if intervalTimer != nil { + intervalCh = intervalTimer.C + } + resetInterval := func() { + if intervalTimer == nil { + return + } + if !intervalTimer.Stop() { + select { + case <-intervalTimer.C: + default: + } + } + intervalTimer.Reset(streamInterval) + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱 + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + + for { + select { + case ev, ok := <-events: + if !ok { + // 发送结束事件 + finalEvents, agUsage := processor.Finish() + if len(finalEvents) > 0 { + _, _ = c.Writer.Write(finalEvents) + flusher.Flush() + } + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long (antigravity): max_size=%d error=%v", defaultMaxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return nil, fmt.Errorf("stream read error: %w", ev.err) + } + + resetInterval() + line := ev.line // 处理 SSE 行,转换为 Claude 格式 claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) @@ -899,23 +1077,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if len(finalEvents) > 0 { _, _ = c.Writer.Write(finalEvents) } + sendErrorEvent("write_failed") return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr } flusher.Flush() } - } - if errors.Is(err, io.EOF) { - break + case <-intervalCh: + log.Printf("Stream data interval timeout (antigravity)") + sendErrorEvent("stream_timeout") + return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - // 发送结束事件 - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() - } - - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8f1bf756..e5282101 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -32,6 +32,7 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL + defaultMaxLineSize = 10 * 1024 * 1024 ) // sseDataRe matches SSE data lines with optional whitespace after colon. @@ -1448,53 +1449,143 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) // 设置更大的buffer以处理长行 - scanner.Buffer(make([]byte, 64*1024), 1024*1024) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 + events := make(chan scanEvent, 1) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,避免上游挂起占用资源 + var intervalTimer *time.Timer + if streamInterval > 0 { + intervalTimer = time.NewTimer(streamInterval) + defer intervalTimer.Stop() + } + + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } needModelReplace := originalModel != mappedModel - for scanner.Scan() { - line := scanner.Text() - if line == "event: error" { - return nil, errors.New("have error in stream") + for { + select { + case ev, ok := <-events: + if !ok { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + line := ev.line + if intervalTimer != nil { + resetTimer(intervalTimer, streamInterval) + } + if line == "event: error" { + return nil, errors.New("have error in stream") + } + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if sseDataRe.MatchString(line) { + data := sseDataRe.ReplaceAllString(line, "") + + // 如果有模型映射,替换响应中的model字段 + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + // 转发行 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + // 记录首字时间:第一个有效的 content_block_delta 或 message_start + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } else { + // 非 data 行直接转发 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + } + + case <-func() <-chan time.Time { + if intervalTimer != nil { + return intervalTimer.C + } + return nil + }(): + log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + sendErrorEvent("stream_timeout") + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if sseDataRe.MatchString(line) { - data := sseDataRe.ReplaceAllString(line, "") - - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - // 转发行 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - - // 记录首字时间:第一个有效的 content_block_delta 或 message_start - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // 非 data 行直接转发 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - } - } - - if err := scanner.Err(); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } +func resetTimer(timer *time.Timer, interval time.Duration) { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(interval) +} + // replaceModelInSSELine 替换SSE数据行中的model字段 func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { if !sseDataRe.MatchString(line) { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 061aeff5..6589df2a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -775,47 +775,154 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp usage := &OpenAIUsage{} var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 64*1024), 1024*1024) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 1) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + // 仅监控上游数据间隔超时,不被下游 keepalive 影响 + var intervalTimer *time.Timer + if streamInterval > 0 { + intervalTimer = time.NewTimer(streamInterval) + defer intervalTimer.Stop() + } + var intervalCh <-chan time.Time + if intervalTimer != nil { + intervalCh = intervalTimer.C + } + + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + // 下游 keepalive 仅用于防止代理空闲断开 + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 + lastDataAt := time.Now() + + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } needModelReplace := originalModel != mappedModel - for scanner.Scan() { - line := scanner.Text() - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") - - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) + for { + select { + case ev, ok := <-events: + if !ok { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } - // Forward line - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + line := ev.line + lastDataAt = time.Now() + if intervalTimer != nil { + resetTimer(intervalTimer, streamInterval) } - flusher.Flush() - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + // Extract data from SSE line (supports both "data: " and "data:" formats) + if openaiSSEDataRe.MatchString(line) { + data := openaiSSEDataRe.ReplaceAllString(line, "") + + // Replace model in response if needed + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + // Forward line + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } else { + // Forward non-data lines as-is + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + + case <-intervalCh: + log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + sendErrorEvent("stream_timeout") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + if _, err := fmt.Fprint(w, ":\n\n"); err != nil { return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err } flusher.Flush() } } - if err := scanner.Err(); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) - } - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go new file mode 100644 index 00000000..dd8ca6b6 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_test.go @@ -0,0 +1,90 @@ +package service + +import ( + "bufio" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +func TestOpenAIStreamingTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + start := time.Now() + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model") + _ = pw.Close() + _ = pr.Close() + + if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { + t.Fatalf("expected stream timeout error, got %v", err) + } + if !strings.Contains(rec.Body.String(), "stream_timeout") { + t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: 64 * 1024, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer pw.Close() + // 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong + payload := "data: " + strings.Repeat("a", 128*1024) + "\n" + _, _ = pw.Write([]byte(payload)) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model") + _ = pr.Close() + + if !errors.Is(err, bufio.ErrTooLong) { + t.Fatalf("expected ErrTooLong, got %v", err) + } + if !strings.Contains(rec.Body.String(), "response_too_large") { + t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) + } +} diff --git a/deploy/Caddyfile b/deploy/Caddyfile index eaba462b..3aeef51a 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -25,8 +25,8 @@ timeouts { read_body 30s read_header 10s - write 60s - idle 120s + write 300s + idle 300s } } } @@ -77,7 +77,10 @@ example.com { write_buffer 16KB compression off } - + + # SSE/流式传输优化:禁用响应缓冲,立即刷新数据给客户端 + flush_interval -1 + # 故障转移 fail_duration 30s max_fails 3 @@ -92,6 +95,10 @@ example.com { gzip 6 minimum_length 256 match { + # SSE 请求通常会带 Accept: text/event-stream,需排除压缩 + not header Accept text/event-stream* + # 排除已知 SSE 路径(即便 Accept 缺失) + not path /v1/messages /v1/responses /responses /antigravity/v1/messages /v1beta/models/* /antigravity/v1beta/models/* header Content-Type text/* header Content-Type application/json* header Content-Type application/javascript* @@ -179,6 +186,3 @@ example.com { # ============================================================================= # HTTP 重定向到 HTTPS (Caddy 默认自动处理,此处显式声明) # ============================================================================= -; http://example.com { -; redir https://{host}{uri} permanent -; } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 19210953..7b2c7d39 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -70,7 +70,7 @@ security: # ============================================================================= gateway: # 等待上游响应头超时时间(秒) - response_header_timeout: 300 + response_header_timeout: 600 # 请求体最大字节数(默认 100MB) max_body_size: 104857600 # 连接池隔离策略: @@ -82,14 +82,20 @@ gateway: max_idle_conns: 240 max_idle_conns_per_host: 120 max_conns_per_host: 240 - idle_conn_timeout_seconds: 300 + idle_conn_timeout_seconds: 90 # 上游连接池客户端缓存配置 # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 max_upstream_clients: 5000 client_idle_ttl_seconds: 900 # 并发槽位过期时间(分钟) - concurrency_slot_ttl_minutes: 15 + concurrency_slot_ttl_minutes: 30 + # 流数据间隔超时(秒),0=禁用 + stream_data_interval_timeout: 180 + # 流式 keepalive 间隔(秒),0=禁用 + stream_keepalive_interval: 10 + # SSE 单行最大字节数(默认 10MB) + max_line_size: 10485760 # Log upstream error response body summary (safe/truncated; does not log request content) log_upstream_error_body: false # Max bytes to log from upstream error body @@ -99,6 +105,13 @@ gateway: # Allow failover on selected 400 errors (default off) failover_on_400: false +# ============================================================================= +# 并发等待配置 +# ============================================================================= +concurrency: + # 并发等待期间的 SSE ping 间隔(秒) + ping_interval: 10 + # ============================================================================= # Database Configuration (PostgreSQL) # =============================================================================