coze stream
This commit is contained in:
@@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
|||||||
|
|
||||||
// DoRequest implements channel.Adaptor.
|
// DoRequest implements channel.Adaptor.
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
|
if info.IsStream {
|
||||||
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
// 首先发送创建消息请求,成功后再发送获取消息请求
|
// 首先发送创建消息请求,成功后再发送获取消息请求
|
||||||
// 发送创建消息请求
|
// 发送创建消息请求
|
||||||
resp, err := channel.DoApiRequest(a, c, info, requestBody)
|
resp, err := channel.DoApiRequest(a, c, info, requestBody)
|
||||||
@@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
|
|
||||||
// DoResponse implements channel.Adaptor.
|
// DoResponse implements channel.Adaptor.
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
err, usage = cozeChatHandler(c, resp, info)
|
if info.IsStream {
|
||||||
|
err, usage = cozeChatStreamHandler(c, resp, info)
|
||||||
|
} else {
|
||||||
|
err, usage = cozeChatHandler(c, resp, info)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
package coze
|
package coze
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/common"
|
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(bufio.ScanLines)
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
id := helper.GetResponseID(c)
|
||||||
|
var responseText string
|
||||||
|
|
||||||
|
var currentEvent string
|
||||||
|
var currentData string
|
||||||
|
var usage dto.Usage
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
|
||||||
|
if line == "" {
|
||||||
|
if currentEvent != "" && currentData != "" {
|
||||||
|
// handle last event
|
||||||
|
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||||
|
currentEvent = ""
|
||||||
|
currentData = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "event:") {
|
||||||
|
currentEvent = strings.TrimSpace(line[6:])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(line, "data:") {
|
||||||
|
currentData = strings.TrimSpace(line[5:])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last event
|
||||||
|
if currentEvent != "" && currentData != "" {
|
||||||
|
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
helper.Done(c)
|
||||||
|
|
||||||
|
if usage.TotalTokens == 0 {
|
||||||
|
usage.PromptTokens = info.PromptTokens
|
||||||
|
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||||
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||||
|
switch event {
|
||||||
|
case "conversation.chat.completed":
|
||||||
|
// 将 data 解析为 CozeChatResponseData
|
||||||
|
var chatData CozeChatResponseData
|
||||||
|
err := json.Unmarshal([]byte(data), &chatData)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage.PromptTokens = chatData.Usage.InputCount
|
||||||
|
usage.CompletionTokens = chatData.Usage.OutputCount
|
||||||
|
usage.TotalTokens = chatData.Usage.TokenCount
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
|
||||||
|
helper.ObjectData(c, stopResponse)
|
||||||
|
|
||||||
|
case "conversation.message.delta":
|
||||||
|
// 将 data 解析为 CozeChatV3MessageDetail
|
||||||
|
var messageData CozeChatV3MessageDetail
|
||||||
|
err := json.Unmarshal([]byte(data), &messageData)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var content string
|
||||||
|
err = json.Unmarshal(messageData.Content, &content)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
*responseText += content
|
||||||
|
|
||||||
|
openaiResponse := dto.ChatCompletionsStreamResponse{
|
||||||
|
Id: id,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: info.UpstreamModelName,
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
}
|
||||||
|
choice.Delta.SetContentString(content)
|
||||||
|
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||||
|
|
||||||
|
helper.ObjectData(c, openaiResponse)
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
var errorData CozeError
|
||||||
|
err := json.Unmarshal([]byte(data), &errorData)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
||||||
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
||||||
|
|
||||||
@@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
|
||||||
var client *http.Client
|
var client *http.Client
|
||||||
var err error // 声明 err 变量
|
var err error // 声明 err 变量
|
||||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user