feat: 添加流模式下的SSE保活机制 #945

This commit is contained in:
CaIon
2025-04-14 19:40:23 +08:00
parent dcf7878772
commit 2f3acd9d22
8 changed files with 136 additions and 31 deletions

View File

@@ -141,7 +141,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)

View File

@@ -6,6 +6,7 @@ import (
"one-api/dto"
relayconstant "one-api/relay/constant"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -54,6 +55,7 @@ type RelayInfo struct {
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
responseMutex sync.Mutex // Add mutex for protecting concurrent access
//SendLastReasoningResponse bool
ApiType int
IsStream bool
@@ -212,12 +214,19 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
}
func (info *RelayInfo) SetFirstResponseTime() {
info.responseMutex.Lock()
defer info.responseMutex.Unlock()
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
info.isFirstResponse = false
}
}
func (info *RelayInfo) HasSendResponse() bool {
return info.FirstResponseTime.After(info.StartTime)
}
type TaskRelayInfo struct {
*RelayInfo
Action string

View File

@@ -55,6 +55,16 @@ func StringData(c *gin.Context, str string) error {
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")

View File

@@ -3,12 +3,15 @@ package helper
import (
"bufio"
"context"
"github.com/bytedance/gopkg/util/gopool"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
@@ -17,11 +20,12 @@ import (
const (
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
if resp == nil {
if resp == nil || dataHandler == nil {
return
}
@@ -34,13 +38,29 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
}
if pingEnabled {
pingTicker = time.NewTicker(pingInterval)
}
defer func() {
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -51,6 +71,34 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
// Handle ping data sending
if pingEnabled && pingTicker != nil {
gopool.Go(func() {
for {
select {
case <-pingTicker.C:
writeMutex.Lock() // Lock before writing
err := PingData(c)
writeMutex.Unlock() // Unlock after writing
if err != nil {
common.LogError(c, "ping data error: "+err.Error())
common.SafeSendBool(stopChan, true)
return
}
if common.DebugEnabled {
println("ping data sent")
}
case <-ctx.Done():
if common.DebugEnabled {
println("ping data goroutine stopped")
}
return
}
}
})
}
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
@@ -70,7 +118,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
writeMutex.Lock() // Lock before writing
success := dataHandler(data)
writeMutex.Unlock() // Unlock after writing
if !success {
break
}
@@ -90,7 +140,9 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
common.SafeSendBool(stopChan, true)
case <-stopChan:
// 正常结束
common.LogInfo(c, "streaming finished")
}
}