refactor: Centralize stream handling and helper functions in relay package

This commit is contained in:
1808837298@qq.com
2025-03-05 19:47:41 +08:00
parent 37bb34b4b0
commit 37a83ecc33
20 changed files with 228 additions and 195 deletions

112
relay/helper/common.go Normal file
View File

@@ -0,0 +1,112 @@
package helper
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
)
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
return StringData(c, string(jsonData))
}
func Done(c *gin.Context) {
_ = StringData(c, "[DONE]")
}
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str))
}
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),
Error: &openaiError,
}
_ = WssObject(c, ws, errorObj)
}
func GetResponseID(c *gin.Context) string {
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID)
}
func GetLocalRealtimeID(c *gin.Context) string {
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("evt_%s", logID)
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
FinishReason: &finishReason,
},
},
}
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}

View File

@@ -0,0 +1,85 @@
package helper
import (
"bufio"
"context"
"io"
"net/http"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for thinking model
streamingTimeout *= 2
}
var (
stopChan = make(chan bool, 2)
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
)
defer func() {
ticker.Stop()
close(stopChan)
}()
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = context.WithValue(ctx, "stop_chan", stopChan)
common.RelayCtxGo(ctx, func() {
for scanner.Scan() {
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 {
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
data = data[5:]
data = strings.TrimLeft(data, " ")
data = strings.TrimSuffix(data, "\"")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
success := dataHandler(data)
if !success {
break
}
}
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
}
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
}