diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index abb98f42..5724e17c 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -23,6 +23,76 @@ const ( DefaultPingInterval = 10 * time.Second ) +type DoRequestFunc func(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) + +// Optional SSE Ping keep-alive mechanism +// +// Used to solve the problem of the connection with the client timing out due to no data being sent when the upstream +// channel response time is long (e.g., thinking model). +// When enabled, it will send ping data packets to the client via SSE at the specified interval to maintain the connection. +func DoStreamRequestWithPinger(doRequest DoRequestFunc, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + SetEventStreamHeaders(c) + + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + + pingerCtx, stopPinger := context.WithCancel(c.Request.Context()) + var pingerWg sync.WaitGroup + var doRequestErr error + var resp any + + if pingEnabled { + pingerWg.Add(1) + + gopool.Go(func() { + defer pingerWg.Done() + + if pingInterval <= 0 { + pingInterval = DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + var pingMutex sync.Mutex + + if common.DebugEnabled { + println("SSE ping goroutine started.") + } + + for { + select { + case <-ticker.C: + pingMutex.Lock() + err := PingData(c) + pingMutex.Unlock() + if err != nil { + common.LogError(c, "SSE ping error: "+err.Error()) + return + } + if common.DebugEnabled { + println("SSE ping data sent.") + } + case <-pingerCtx.Done(): + if common.DebugEnabled { + println("SSE ping goroutine stopped.") + } + return + } + } + }) + } + + resp, doRequestErr = doRequest(c, info, requestBody) + + stopPinger() + if pingEnabled { + pingerWg.Wait() + } + + return resp, doRequestErr +} + func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { if resp == nil || dataHandler == nil { @@ -41,26 +111,11 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon stopChan = make(chan bool, 2) scanner = bufio.NewScanner(resp.Body) ticker = time.NewTicker(streamingTimeout) - pingTicker *time.Ticker writeMutex sync.Mutex // Mutex to protect concurrent writes ) - generalSettings := operation_setting.GetGeneralSetting() - pingEnabled := generalSettings.PingIntervalEnabled - pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second - if pingInterval <= 0 { - pingInterval = DefaultPingInterval - } - - if pingEnabled { - pingTicker = time.NewTicker(pingInterval) - } - defer func() { ticker.Stop() - if pingTicker != nil { - pingTicker.Stop() - } close(stopChan) }() scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) @@ -72,33 +127,6 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon ctx = context.WithValue(ctx, "stop_chan", stopChan) - // Handle ping data sending - if pingEnabled && pingTicker != nil { - gopool.Go(func() { - for { - select { - case <-pingTicker.C: - writeMutex.Lock() // Lock before writing - err := PingData(c) - writeMutex.Unlock() // Unlock after writing - if err != nil { - common.LogError(c, "ping data error: "+err.Error()) - common.SafeSendBool(stopChan, true) - return - } - if common.DebugEnabled { - println("ping data sent") - } - case <-ctx.Done(): - if common.DebugEnabled { - println("ping data goroutine stopped") - } - return - } - } - }) - } - common.RelayCtxGo(ctx, func() { for scanner.Scan() { ticker.Reset(streamingTimeout) diff --git a/relay/relay-text.go b/relay/relay-text.go index 4fdd435d..0147de8d 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -192,7 +192,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } var httpResp *http.Response - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + var resp any + + if relayInfo.IsStream { + // Streaming requests can use SSE ping to keep alive and avoid connection timeout + // The judgment of whether ping is enabled will be made within the function + resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody) + } else { + resp, err = adaptor.DoRequest(c, relayInfo, requestBody) + } + if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) }