refactor: Centralize stream handling and helper functions in relay package

This commit is contained in:
1808837298@qq.com
2025-03-05 19:47:41 +08:00
parent 37bb34b4b0
commit 37a83ecc33
20 changed files with 228 additions and 195 deletions

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
)
@@ -153,7 +154,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -12,6 +12,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -203,13 +204,13 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
helper.Done(c)
if resp != nil {
err = resp.Body.Close()
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -138,7 +139,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -1,7 +1,6 @@
package claude
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -9,6 +8,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -443,28 +443,18 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
@@ -481,9 +471,9 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
return true
} else {
continue
return true
}
}
//response.Id = responseId
@@ -491,11 +481,12 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return true
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
@@ -508,13 +499,13 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
helper.Done(c)
resp.Body.Close()
return nil, usage
}

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -28,8 +29,8 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
var responseText string
isFirst := true
@@ -57,7 +58,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
@@ -72,13 +73,13 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {
@@ -109,7 +110,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -103,7 +104,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -10,6 +10,7 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
)
@@ -66,7 +67,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
@@ -92,7 +93,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
@@ -100,7 +101,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,7 +1,6 @@
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"io"
@@ -10,6 +9,7 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
@@ -429,10 +429,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false
isStop := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true
isStop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
@@ -482,9 +482,8 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = choices
return &response, is_stop
return &response, isStop
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -492,27 +491,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
continue
return false
}
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
@@ -521,15 +509,16 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
}
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
}
if is_stop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
if isStop {
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
helper.ObjectData(c, response)
}
}
return true
})
var response *dto.ChatCompletionsStreamResponse
@@ -538,13 +527,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
helper.Done(c)
resp.Body.Close()
return nil, usage
}

View File

@@ -1,11 +1,13 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"io"
"math"
"mime/multipart"
@@ -15,16 +17,10 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"os"
"strings"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
@@ -33,7 +29,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
if !forceFormat && !thinkToContent {
return service.StringData(c, data)
return helper.StringData(c, data)
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
@@ -42,34 +38,47 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
if !thinkToContent {
return service.ObjectData(c, lastStreamResponse)
return helper.ObjectData(c, lastStreamResponse)
}
hasThinkingContent := false
for _, choice := range lastStreamResponse.Choices {
if len(choice.Delta.GetReasoningContent()) > 0 {
hasThinkingContent = true
break
}
}
// Handle think to content conversion
if info.IsFirstResponse {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
if info.ThinkingContentInfo.IsFirstThinkingContent {
if hasThinkingContent {
response := lastStreamResponse.Copy()
for i := range response.Choices {
response.Choices[i].Delta.SetContentString("<think>\n")
response.Choices[i].Delta.SetReasoningContent("")
}
info.ThinkingContentInfo.IsFirstThinkingContent = false
return helper.ObjectData(c, response)
} else {
return helper.ObjectData(c, lastStreamResponse)
}
service.ObjectData(c, response)
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return service.ObjectData(c, lastStreamResponse)
return helper.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
if len(choice.Delta.GetContentString()) > 0 && !info.ThinkingContentInfo.SendLastThinkingContent {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>")
response.Choices[j].Delta.SetContentString("\n</think>\n\n")
response.Choices[j].Delta.SetReasoningContent("")
}
info.SendLastReasoningResponse = true
service.ObjectData(c, response)
info.ThinkingContentInfo.SendLastThinkingContent = true
helper.ObjectData(c, response)
}
// Convert reasoning content to regular content
@@ -79,7 +88,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
}
return service.ObjectData(c, lastStreamResponse)
return helper.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -109,75 +118,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
// twice timeout for o1 model
streamingTimeout *= 2
}
ticker := time.NewTicker(streamingTimeout)
defer ticker.Stop()
stopChan := make(chan bool, 2)
defer close(stopChan)
var (
lastStreamData string
mu sync.Mutex
)
ctx := context.WithValue(context.Background(), "stop_chan", stopChan)
common.CtxGo(ctx, func() {
for scanner.Scan() {
//info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if common.DebugEnabled {
println(data)
}
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" && data[:6] != "[DONE]" {
continue
}
mu.Lock()
data = data[5:]
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "[DONE]") {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
}
mu.Unlock()
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
common.LogError(c, "scanner error: "+err.Error())
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
}
common.SafeSendBool(stopChan, true)
lastStreamData = data
streamItems = append(streamItems, data)
return true
})
select {
case <-ticker.C:
// 超时处理逻辑
common.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
}
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
@@ -285,12 +242,12 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
helper.ObjectData(c, response)
}
service.Done(c)
helper.Done(c)
resp.Body.Close()
return nil, usage
@@ -523,7 +480,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
err = helper.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
@@ -629,7 +586,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = service.WssString(c, clientConn, string(message))
err = helper.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
)
@@ -112,7 +113,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
dataChan <- string(jsonResponse)
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -14,6 +14,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strconv"
"strings"
@@ -91,7 +92,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
@@ -112,7 +113,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
err = helper.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
@@ -122,7 +123,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
helper.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
@@ -132,7 +133,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
var usage dto.Usage
c.Stream(func(w io.Writer) bool {
select {

View File

@@ -10,6 +10,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -177,7 +178,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"one-api/service"
"strings"
"sync"
@@ -197,7 +198,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan: