From e379ee8f66c1d3f85c89a26994b88227564ffa10 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 16 May 2025 10:27:07 +0800 Subject: [PATCH] coze stream --- relay/channel/coze/adaptor.go | 9 ++- relay/channel/coze/relay-coze.go | 124 ++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 3 deletions(-) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 34931cc6..80441a51 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt // DoRequest implements channel.Adaptor. func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + if info.IsStream { + return channel.DoApiRequest(a, c, info, requestBody) + } // 首先发送创建消息请求,成功后再发送获取消息请求 // 发送创建消息请求 resp, err := channel.DoApiRequest(a, c, info, requestBody) @@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody // DoResponse implements channel.Adaptor. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - err, usage = cozeChatHandler(c, resp, info) + if info.IsStream { + err, usage = cozeChatStreamHandler(c, resp, info) + } else { + err, usage = cozeChatHandler(c, resp, info) + } return } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 1ebdb7c1..6db40213 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -1,16 +1,18 @@ package coze import ( + "bufio" "encoding/json" "errors" "fmt" "io" "net/http" + "one-api/common" "one-api/dto" - "one-api/relay/common" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) @@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return nil, &usage } +func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + var responseText string + + var currentEvent string + var currentData string + var usage dto.Usage + + for scanner.Scan() { + line := scanner.Text() + + if line == "" { + if currentEvent != "" && currentData != "" { + // handle last event + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + currentEvent = "" + currentData = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(line[6:]) + continue + } + + if strings.HasPrefix(line, "data:") { + currentData = strings.TrimSpace(line[5:]) + continue + } + } + + // Last event + if currentEvent != "" && currentData != "" { + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + } + + if err := scanner.Err(); err != nil { + return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil + } + helper.Done(c) + + if usage.TotalTokens == 0 { + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + return nil, &usage +} + +func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { + switch event { + case "conversation.chat.completed": + // 将 data 解析为 CozeChatResponseData + var chatData CozeChatResponseData + err := json.Unmarshal([]byte(data), &chatData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + usage.PromptTokens = chatData.Usage.InputCount + usage.CompletionTokens = chatData.Usage.OutputCount + usage.TotalTokens = chatData.Usage.TokenCount + + finishReason := "stop" + stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason) + helper.ObjectData(c, stopResponse) + + case "conversation.message.delta": + // 将 data 解析为 CozeChatV3MessageDetail + var messageData CozeChatV3MessageDetail + err := json.Unmarshal([]byte(data), &messageData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + var content string + err = json.Unmarshal(messageData.Content, &content) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + *responseText += content + + openaiResponse := dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + } + + choice := dto.ChatCompletionsStreamResponseChoice{ + Index: 0, + } + choice.Delta.SetContentString(content) + openaiResponse.Choices = append(openaiResponse.Choices, choice) + + helper.ObjectData(c, openaiResponse) + + case "error": + var errorData CozeError + err := json.Unmarshal([]byte(data), &errorData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + } +} + func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) @@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht return resp, nil } -func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { +func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 if proxyURL, ok := info.ChannelSetting["proxy"]; ok {