diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 8b2ca889..db5d4f44 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -1,16 +1,23 @@ package channel import ( + "context" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" + "sync" + "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { @@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } else { client = service.GetHttpClient() } + // 流式请求 ping 保活 + var stopPinger func() + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + var pingerWg sync.WaitGroup + if info.IsStream { + helper.SetEventStreamHeaders(c) + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + var pingerCtx context.Context + pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) + + if pingEnabled { + pingerWg.Add(1) + gopool.Go(func() { + defer pingerWg.Done() + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + case <-ticker.C: + pingMutex.Lock() + err2 := helper.PingData(c) + pingMutex.Unlock() + if err2 != nil { + common2.LogError(c, "SSE ping error: "+err.Error()) + return + } + if common2.DebugEnabled { + println("SSE ping data sent.") + } + case <-pingerCtx.Done(): + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + return + } + } + }) + } + } + resp, err := client.Do(req) + // request结束后停止ping + if info.IsStream && pingEnabled { + stopPinger() + pingerWg.Wait() + } if err != nil { return nil, err } diff --git a/relay/helper/common.go b/relay/helper/common.go index 0a3aba1e..35d983f7 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -12,11 +12,19 @@ import ( ) func SetEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index 2738ce2a..c1bc0d6e 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -3,7 +3,6 @@ package helper import ( "bufio" "context" - "github.com/bytedance/gopkg/util/gopool" "io" "net/http" "one-api/common" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) diff --git a/relay/relay-text.go b/relay/relay-text.go index e0b6ad0e..8d5cd384 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -194,6 +194,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) }