refactor: extract FlushWriter function for improved stream flushing

This commit is contained in:
CaIon
2025-08-17 15:30:31 +08:00
parent 998305fd00
commit c18414cbe4
2 changed files with 17 additions and 32 deletions

View File

@@ -2,9 +2,6 @@ package openai
import (
"encoding/json"
"errors"
"github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
@@ -15,6 +12,8 @@ import (
"one-api/types"
"strings"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
)
@@ -71,11 +70,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
// send gemini format response
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
_ = helper.FlushWriter(c)
return nil
}
@@ -253,9 +248,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
// 发送最终的 Gemini 响应
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
_ = helper.FlushWriter(c)
}
}

View File

@@ -14,6 +14,14 @@ import (
"github.com/gorilla/websocket"
)
func FlushWriter(c *gin.Context) error {
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
return nil
}
return errors.New("streaming error: flusher not found")
}
func SetEventStreamHeaders(c *gin.Context) {
// 检查是否已经设置过头部
if _, exists := c.Get("event_stream_headers_set"); exists {
@@ -38,49 +46,33 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
_ = FlushWriter(c)
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
_ = FlushWriter(c)
}
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
_ = FlushWriter(c)
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
_ = FlushWriter(c)
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
_ = FlushWriter(c)
return nil
}