diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index da8d4e14..1d733bd4 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return targetConn, nil } +func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc { + pingerCtx, stopPinger := context.WithCancel(context.Background()) + + gopool.Go(func() { + defer func() { + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + }() + + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + // 退出时清理 ticker + defer ticker.Stop() + + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + // 发送 ping 数据 + case <-ticker.C: + if err := sendPingData(c, &pingMutex); err != nil { + return + } + // 收到退出信号 + case <-pingerCtx.Done(): + return + // request 结束 + case <-c.Request.Context().Done(): + return + } + } + }) + + return stopPinger +} + +func sendPingData(c *gin.Context, mutex *sync.Mutex) error { + mutex.Lock() + defer mutex.Unlock() + + err := helper.PingData(c) + if err != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return err + } + + if common2.DebugEnabled { + println("SSE ping data sent.") + } + return nil +} + func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error @@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } else { client = service.GetHttpClient() } - // 流式请求 ping 保活 - var stopPinger func() - generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled - var pingerWg sync.WaitGroup + if info.IsStream { helper.SetEventStreamHeaders(c) - if pingEnabled { + // 处理流式请求的 ping 保活 + generalSettings := operation_setting.GetGeneralSetting() + if generalSettings.PingIntervalEnabled { pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second - var pingerCtx context.Context - pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) - // 退出时清理 pingerCtx 防止泄露 + stopPinger := startPingKeepAlive(c, pingInterval) defer stopPinger() - pingerWg.Add(1) - gopool.Go(func() { - defer pingerWg.Done() - if pingInterval <= 0 { - pingInterval = helper.DefaultPingInterval - } - - ticker := time.NewTicker(pingInterval) - defer ticker.Stop() - var pingMutex sync.Mutex - if common2.DebugEnabled { - println("SSE ping goroutine started") - } - - for { - select { - case <-ticker.C: - pingMutex.Lock() - err2 := helper.PingData(c) - pingMutex.Unlock() - if err2 != nil { - common2.LogError(c, "SSE ping error: "+err.Error()) - return - } - if common2.DebugEnabled { - println("SSE ping data sent.") - } - case <-pingerCtx.Done(): - if common2.DebugEnabled { - println("SSE ping goroutine stopped.") - } - return - } - } - }) } } resp, err := client.Do(req) - // request结束后等待 ping goroutine 完成 - if info.IsStream && pingEnabled { - pingerWg.Wait() - } + if err != nil { return nil, err } if resp == nil { return nil, errors.New("resp is nil") } + _ = req.Body.Close() _ = c.Request.Body.Close() return resp, nil