diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 8b2ca889..db5d4f44 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -1,16 +1,23 @@
 package channel
 
 import (
+	"context"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
 	"io"
 	"net/http"
 	common2 "one-api/common"
 	"one-api/relay/common"
 	"one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
+	"one-api/setting/operation_setting"
+	"sync"
+	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
+	// Stream requests: send SSE pings to the client while waiting on the upstream response.
+	var stopPinger func()
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	var pingerWg sync.WaitGroup
+	if info.IsStream {
+		helper.SetEventStreamHeaders(c)
+		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+		var pingerCtx context.Context
+		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
+
+		if pingEnabled {
+			pingerWg.Add(1)
+			gopool.Go(func() {
+				defer pingerWg.Done()
+				if pingInterval <= 0 {
+					pingInterval = helper.DefaultPingInterval
+				}
+
+				ticker := time.NewTicker(pingInterval)
+				defer ticker.Stop()
+				var pingMutex sync.Mutex
+				if common2.DebugEnabled {
+					println("SSE ping goroutine started")
+				}
+
+				for {
+					select {
+					case <-ticker.C:
+						pingMutex.Lock()
+						err2 := helper.PingData(c)
+						pingMutex.Unlock()
+						if err2 != nil {
+							common2.LogError(c, "SSE ping error: "+err2.Error())
+							return
+						}
+						if common2.DebugEnabled {
+							println("SSE ping data sent.")
+						}
+					case <-pingerCtx.Done():
+						if common2.DebugEnabled {
+							println("SSE ping goroutine stopped.")
+						}
+						return
+					}
+				}
+			})
+		}
+	}
+
 	resp, err := client.Do(req)
+	// Always release the pinger context once the request returns; Wait is a no-op when no pinger was started.
+	if stopPinger != nil {
+		stopPinger()
+		pingerWg.Wait()
+	}
 	if err != nil {
 		return nil, err
 	}
diff --git a/relay/helper/common.go b/relay/helper/common.go
index 0a3aba1e..35d983f7 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -12,11 +12,19 @@ import (
 )
 
 func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	// Skip if the event-stream headers were already set for this request.
+	if _, exists := c.Get("event_stream_headers_set"); exists {
+		return
+	}
+
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+
+	// Record on the gin context that the headers have been set.
+	c.Set("event_stream_headers_set", true)
 }
 
 func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index ce4d3a6d..c1bc0d6e 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -3,7 +3,6 @@ package helper
 import (
 	"bufio"
 	"context"
-	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -14,6 +13,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/bytedance/gopkg/util/gopool"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -23,76 +24,6 @@ const (
 	DefaultPingInterval = 10 * time.Second
 )
 
-type DoRequestFunc func(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
-
-// Optional SSE Ping keep-alive mechanism
-//
-// Used to solve the problem of the connection with the client timing out due to no data being sent when the upstream
-// channel response time is long (e.g., thinking model).
-// When enabled, it will send ping data packets to the client via SSE at the specified interval to maintain the connection.
-func DoStreamRequestWithPinger(doRequest DoRequestFunc, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	SetEventStreamHeaders(c)
-
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-
-	pingerCtx, stopPinger := context.WithCancel(c.Request.Context())
-	var pingerWg sync.WaitGroup
-	var doRequestErr error
-	var resp any
-
-	if pingEnabled {
-		pingerWg.Add(1)
-
-		gopool.Go(func() {
-			defer pingerWg.Done()
-
-			if pingInterval <= 0 {
-				pingInterval = DefaultPingInterval
-			}
-
-			ticker := time.NewTicker(pingInterval)
-			defer ticker.Stop()
-			var pingMutex sync.Mutex
-
-			if common.DebugEnabled {
-				println("SSE ping goroutine started.")
-			}
-
-			for {
-				select {
-				case <-ticker.C:
-					pingMutex.Lock()
-					err := PingData(c)
-					pingMutex.Unlock()
-					if err != nil {
-						common.LogError(c, "SSE ping error: "+err.Error())
-						return
-					}
-					if common.DebugEnabled {
-						println("SSE ping data sent.")
-					}
-				case <-pingerCtx.Done():
-					if common.DebugEnabled {
-						println("SSE ping goroutine stopped.")
-					}
-					return
-				}
-			}
-		})
-	}
-
-	resp, doRequestErr = doRequest(c, info, requestBody)
-
-	stopPinger()
-	if pingEnabled {
-		pingerWg.Wait()
-	}
-
-	return resp, doRequestErr
-}
-
 func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
 
 	if resp == nil || dataHandler == nil {
@@ -111,11 +42,26 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		stopChan   = make(chan bool, 2)
 		scanner    = bufio.NewScanner(resp.Body)
 		ticker     = time.NewTicker(streamingTimeout)
+		pingTicker *time.Ticker
 		writeMutex sync.Mutex // Mutex to protect concurrent writes
 	)
 
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+	if pingInterval <= 0 {
+		pingInterval = DefaultPingInterval
+	}
+
+	if pingEnabled {
+		pingTicker = time.NewTicker(pingInterval)
+	}
+
 	defer func() {
 		ticker.Stop()
+		if pingTicker != nil {
+			pingTicker.Stop()
+		}
 		close(stopChan)
 	}()
 	scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -127,6 +73,33 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 
 	ctx = context.WithValue(ctx, "stop_chan", stopChan)
 
+	// Handle ping data sending
+	if pingEnabled && pingTicker != nil {
+		gopool.Go(func() {
+			for {
+				select {
+				case <-pingTicker.C:
+					writeMutex.Lock() // Lock before writing
+					err := PingData(c)
+					writeMutex.Unlock() // Unlock after writing
+					if err != nil {
+						common.LogError(c, "ping data error: "+err.Error())
+						common.SafeSendBool(stopChan, true)
+						return
+					}
+					if common.DebugEnabled {
+						println("ping data sent")
+					}
+				case <-ctx.Done():
+					if common.DebugEnabled {
+						println("ping data goroutine stopped")
+					}
+					return
+				}
+			}
+		})
+	}
+
 	common.RelayCtxGo(ctx, func() {
 		for scanner.Scan() {
 			ticker.Reset(streamingTimeout)
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 69a48637..8d5cd384 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -193,15 +193,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	var httpResp *http.Response
-	var resp any
-
-	if relayInfo.IsStream {
-		// Streaming requests can use SSE ping to keep alive and avoid connection timeout
-		// The judgment of whether ping is enabled will be made within the function
-		resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody)
-	} else {
-		resp, err = adaptor.DoRequest(c, relayInfo, requestBody)
-	}
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)