refactor: Centralize stream handling and helper functions in relay package
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
@@ -15,16 +17,10 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
@@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
if !forceFormat && !thinkToContent {
|
||||
return service.StringData(c, data)
|
||||
return helper.StringData(c, data)
|
||||
}
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
@@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
if !thinkToContent {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
hasThinkingContent := false
|
||||
for _, choice := range lastStreamResponse.Choices {
|
||||
if len(choice.Delta.GetReasoningContent()) > 0 {
|
||||
hasThinkingContent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Handle think to content conversion
|
||||
if info.IsFirstResponse {
|
||||
response := lastStreamResponse.Copy()
|
||||
for i := range response.Choices {
|
||||
response.Choices[i].Delta.SetContentString("<think>\n")
|
||||
response.Choices[i].Delta.SetReasoningContent("")
|
||||
if info.ThinkingContentInfo.IsFirstThinkingContent {
|
||||
if hasThinkingContent {
|
||||
response := lastStreamResponse.Copy()
|
||||
for i := range response.Choices {
|
||||
response.Choices[i].Delta.SetContentString("<think>\n")
|
||||
response.Choices[i].Delta.SetReasoningContent("")
|
||||
}
|
||||
info.ThinkingContentInfo.IsFirstThinkingContent = false
|
||||
return helper.ObjectData(c, response)
|
||||
} else {
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
|
||||
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
// Process each choice
|
||||
for i, choice := range lastStreamResponse.Choices {
|
||||
// Handle transition from thinking to content
|
||||
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
|
||||
if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
|
||||
response := lastStreamResponse.Copy()
|
||||
for j := range response.Choices {
|
||||
response.Choices[j].Delta.SetContentString("\n</think>")
|
||||
response.Choices[j].Delta.SetContentString("\n</think>\n\n")
|
||||
response.Choices[j].Delta.SetReasoningContent("")
|
||||
}
|
||||
info.SendLastReasoningResponse = true
|
||||
service.ObjectData(c, response)
|
||||
info.ThinkingContentInfo.SendLastThinkingContent = true
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
|
||||
// Convert reasoning content to regular content
|
||||
@@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
}
|
||||
|
||||
return service.ObjectData(c, lastStreamResponse)
|
||||
return helper.ObjectData(c, lastStreamResponse)
|
||||
}
|
||||
|
||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
@@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
|
||||
// twice timeout for o1 model
|
||||
streamingTimeout *= 2
|
||||
}
|
||||
ticker := time.NewTicker(streamingTimeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
stopChan := make(chan bool, 2)
|
||||
defer close(stopChan)
|
||||
var (
|
||||
lastStreamData string
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
ctx := context.WithValue(context.Background(), "stop_chan", stopChan)
|
||||
|
||||
common.CtxGo(ctx, func() {
|
||||
for scanner.Scan() {
|
||||
//info.SetFirstResponseTime()
|
||||
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
println(data)
|
||||
}
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
mu.Lock()
|
||||
data = data[5:]
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
if lastStreamData != "" {
|
||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.LogError(c, "streaming error: "+err.Error())
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
common.LogError(c, "scanner error: "+err.Error())
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if err != nil {
|
||||
common.LogError(c, "streaming error: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
common.SafeSendBool(stopChan, true)
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
return true
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
}
|
||||
|
||||
shouldSendLastResp := true
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
||||
@@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
service.ObjectData(c, response)
|
||||
helper.ObjectData(c, response)
|
||||
}
|
||||
|
||||
service.Done(c)
|
||||
helper.Done(c)
|
||||
|
||||
resp.Body.Close()
|
||||
return nil, usage
|
||||
@@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = service.WssString(c, targetConn, string(message))
|
||||
err = helper.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
@@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = service.WssString(c, clientConn, string(message))
|
||||
err = helper.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user