From b2cad229520ab533f1981daefe9a478502ddb31f Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 12:52:22 +0800 Subject: [PATCH 1/6] add coze request --- common/constants.go | 2 + relay/channel/coze/adaptor.go | 125 +++++++++++++++++++++++++++++++ relay/channel/coze/constants.go | 8 ++ relay/channel/coze/dto.go | 81 ++++++++++++++++++++ relay/channel/coze/relay-coze.go | 121 ++++++++++++++++++++++++++++++ relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + 7 files changed, 343 insertions(+) create mode 100644 relay/channel/coze/adaptor.go create mode 100644 relay/channel/coze/constants.go create mode 100644 relay/channel/coze/dto.go create mode 100644 relay/channel/coze/relay-coze.go diff --git a/common/constants.go b/common/constants.go index dd4f3b04..bee00506 100644 --- a/common/constants.go +++ b/common/constants.go @@ -240,6 +240,7 @@ const ( ChannelTypeBaiduV2 = 46 ChannelTypeXinference = 47 ChannelTypeXai = 48 + ChannelTypeCoze = 49 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{ "https://qianfan.baidubce.com", //46 "", //47 "https://api.x.ai", //48 + "https://api.coze.cn", //49 } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go new file mode 100644 index 00000000..b14239a6 --- /dev/null +++ b/relay/channel/coze/adaptor.go @@ -0,0 +1,125 @@ +package coze + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/common" + "time" + + "github.com/gin-gonic/gin" +) + +type Adaptor struct { +} + +// ConvertAudioRequest implements channel.Adaptor. +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} + +// ConvertClaudeRequest implements channel.Adaptor. +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertEmbeddingRequest implements channel.Adaptor. +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements channel.Adaptor. +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertOpenAIRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return convertCozeChatRequest(*request), nil +} + +// ConvertOpenAIResponsesRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertRerankRequest implements channel.Adaptor. +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// DoRequest implements channel.Adaptor. +func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + // 首先发送创建消息请求,成功后再发送获取消息请求 + // 发送创建消息请求 + resp, err := channel.DoApiRequest(a, c, info, requestBody) + if err != nil { + return nil, err + } + // 解析 resp + var cozeResponse CozeChatResponse + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(respBody, &cozeResponse) + if cozeResponse.Code != 0 { + return nil, errors.New(cozeResponse.Msg) + } + c.Set("coze_conversation_id", cozeResponse.Data.ConversationId) + c.Set("coze_chat_id", cozeResponse.Data.Id) + // 轮询检查消息是否完成 + for { + err, isComplete := checkIfChatComplete(a, c, info) + if err != nil { + return nil, err + } else { + if isComplete { + break + } + } + time.Sleep(time.Second * 1) + } + // 发送获取消息请求 + return channel.DoApiRequest(a, c, info, requestBody) +} + +// DoResponse implements channel.Adaptor. +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + err, usage = cozeChatHandler(c, resp, info) + return +} + +// GetChannelName implements channel.Adaptor. +func (a *Adaptor) GetChannelName() string { + return ChannelName +} + +// GetModelList implements channel.Adaptor. +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +// GetRequestURL implements channel.Adaptor. +func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil +} + +// Init implements channel.Adaptor. +func (a *Adaptor) Init(info *common.RelayInfo) { + +} + +// SetupRequestHeader implements channel.Adaptor. +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go new file mode 100644 index 00000000..da28cb83 --- /dev/null +++ b/relay/channel/coze/constants.go @@ -0,0 +1,8 @@ +package coze + +var ModelList = []string{ + // TODO: 完整列表 + "deepseek-v3", +} + +var ChannelName = "coze" diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go new file mode 100644 index 00000000..fb92289a --- /dev/null +++ b/relay/channel/coze/dto.go @@ -0,0 +1,81 @@ +package coze + +import "encoding/json" + +// type CozeResponse struct { +// Code int `json:"code"` +// Message string `json:"message"` +// Data CozeConversationData `json:"data"` +// Detail CozeConversationData `json:"detail"` +// } + +// type CozeConversationData struct { +// Id string `json:"id"` +// CreatedAt int64 `json:"created_at"` +// MetaData json.RawMessage `json:"meta_data"` +// LastSectionId string `json:"last_section_id"` +// } + +// type CozeResponseDetail struct { +// Logid string `json:"logid"` +// } + +type CozeError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// type CozeErrorWithStatusCode struct { +// Error CozeError `json:"error"` +// StatusCode int +// LocalError bool +// } + +type CozeRequest struct { + BotId string `json:"bot_id,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + Messages []CozeEnterMessage `json:"messages,omitempty"` +} + +type CozeEnterMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ContentType string `json:"content_type,omitempty"` +} + +type CozeChatRequest struct { + BotId string `json:"bot_id"` + UserId string `json:"user_id"` + AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"` + Stream bool `json:"stream,omitempty"` + CustomVariables json.RawMessage `json:"custom_variables,omitempty"` + AutoSaveHistory bool `json:"auto_save_history,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ExtraParams json.RawMessage `json:"extra_params,omitempty"` + ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type CozeChatResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data CozeChatResponseData `json:"data"` +} + +type CozeChatResponseData struct { + Id string `json:"id"` + ConversationId string `json:"conversation_id"` + BotId string `json:"bot_id"` + CreatedAt int64 `json:"created_at"` + LastError CozeError `json:"last_error"` + Status string `json:"status"` + Usage CozeChatUsage `json:"usage"` +} + +type CozeChatUsage struct { + TokenCount int `json:"token_count"` + OutputCount int `json:"output_count"` + InputCount int `json:"input_count"` +} diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go new file mode 100644 index 00000000..49a3ac15 --- /dev/null +++ b/relay/channel/coze/relay-coze.go @@ -0,0 +1,121 @@ +package coze + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/common" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/gin-gonic/gin" +) + +func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { + var messages []CozeEnterMessage + // 将 request的messages的role为user的content转换为CozeMessage + for _, message := range request.Messages { + if message.Role == "user" { + messages = append(messages, CozeEnterMessage{ + Role: "user", + Content: message.Content, + // TODO: support more content type + ContentType: "text", + }) + } + } + cozeRequest := &CozeRequest{ + // TODO: model to botid + BotId: "1", + Messages: messages, + } + return cozeRequest +} + +func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // convert coze response to openai response + var response dto.TextResponse + var cozeResponse CozeChatResponse + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + // TODO: 处理 cozeResponse + return nil, nil +} + +func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + // 将 conversationId和chatId作为参数发送get请求 + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return err, false + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return err, false + } + + resp, err := doRequest(req, info) // 调用 doRequest + if err != nil { + return err, false + } + if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic + return fmt.Errorf("resp is nil"), false + } + defer resp.Body.Close() // 确保响应体被关闭 + + // 解析 resp 到 CozeChatResponse + var cozeResponse CozeChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body failed: %w", err), false + } + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return fmt.Errorf("unmarshal response body failed: %w", err), false + } + if cozeResponse.Data.Status == "completed" { + // 在上下文设置 usage + c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount) + c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount) + c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount) + return nil, true + } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" { + return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false + } else { + return nil, false + } +} + +func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { + var client *http.Client + var err error // 声明 err 变量 + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + resp, err := client.Do(req) + if err != nil { // 增加对 client.Do(req) 返回错误的检查 + return nil, fmt.Errorf("client.Do failed: %w", err) + } + _ = resp.Body.Close() + return resp, nil +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index fef38f23..3f1ecd78 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -33,6 +33,7 @@ const ( APITypeOpenRouter APITypeXinference APITypeXai + APITypeCoze APITypeDummy // this one is only for count, do not add any channel after this ) @@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeXinference case common.ChannelTypeXai: apiType = APITypeXai + case common.ChannelTypeCoze: + apiType = APITypeCoze } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8b4afcb3..7bf0da9f 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" + "one-api/relay/channel/coze" "one-api/relay/channel/deepseek" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &openai.Adaptor{} case constant.APITypeXai: return &xai.Adaptor{} + case constant.APITypeCoze: + return &coze.Adaptor{} } return nil } From b2499b0a7ed0d902ad7ae4653dd0d0ab7e81055a Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 21:13:34 +0800 Subject: [PATCH 2/6] DoRequest --- relay/channel/coze/adaptor.go | 6 +++--- relay/channel/coze/dto.go | 30 ------------------------------ relay/channel/coze/relay-coze.go | 29 +++++++++++++++++++++++++---- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index b14239a6..34931cc6 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, r if request == nil { return nil, errors.New("request is nil") } - return convertCozeChatRequest(*request), nil + return convertCozeChatRequest(c, *request), nil } // ConvertOpenAIResponsesRequest implements channel.Adaptor. @@ -88,7 +88,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody time.Sleep(time.Second * 1) } // 发送获取消息请求 - return channel.DoApiRequest(a, c, info, requestBody) + return getChatDetail(a, c, info) } // DoResponse implements channel.Adaptor. @@ -109,7 +109,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go index fb92289a..38fc2f16 100644 --- a/relay/channel/coze/dto.go +++ b/relay/channel/coze/dto.go @@ -2,41 +2,11 @@ package coze import "encoding/json" -// type CozeResponse struct { -// Code int `json:"code"` -// Message string `json:"message"` -// Data CozeConversationData `json:"data"` -// Detail CozeConversationData `json:"detail"` -// } - -// type CozeConversationData struct { -// Id string `json:"id"` -// CreatedAt int64 `json:"created_at"` -// MetaData json.RawMessage `json:"meta_data"` -// LastSectionId string `json:"last_section_id"` -// } - -// type CozeResponseDetail struct { -// Logid string `json:"logid"` -// } - type CozeError struct { Code int `json:"code"` Message string `json:"message"` } -// type CozeErrorWithStatusCode struct { -// Error CozeError `json:"error"` -// StatusCode int -// LocalError bool -// } - -type CozeRequest struct { - BotId string `json:"bot_id,omitempty"` - MetaData json.RawMessage `json:"meta_data,omitempty"` - Messages []CozeEnterMessage `json:"messages,omitempty"` -} - type CozeEnterMessage struct { Role string `json:"role"` Type string `json:"type,omitempty"` diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 49a3ac15..7c16763e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -13,7 +13,7 @@ import ( "github.com/gin-gonic/gin" ) -func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { +func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest { var messages []CozeEnterMessage // 将 request的messages的role为user的content转换为CozeMessage for _, message := range request.Messages { @@ -26,10 +26,12 @@ func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest { }) } } - cozeRequest := &CozeRequest{ + cozeRequest := &CozeChatRequest{ // TODO: model to botid - BotId: "1", - Messages: messages, + BotId: "1", + UserId: c.GetString("id"), + AdditionalMessages: messages, + Stream: request.Stream, } return cozeRequest } @@ -101,6 +103,25 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } } +func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(req, info) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 From 29c95c598e380dbe5ff80cd0690a1c4c3770f93d Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 22:01:12 +0800 Subject: [PATCH 3/6] cozeChatHelper --- relay/channel/coze/dto.go | 27 ++++++++++++++++++++ relay/channel/coze/relay-coze.go | 43 +++++++++++++++++++++++++++++--- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go index 38fc2f16..4e9afa23 100644 --- a/relay/channel/coze/dto.go +++ b/relay/channel/coze/dto.go @@ -49,3 +49,30 @@ type CozeChatUsage struct { OutputCount int `json:"output_count"` InputCount int `json:"input_count"` } + +type CozeChatDetailResponse struct { + Data []CozeChatV3MessageDetail `json:"data"` + Code int `json:"code"` + Msg string `json:"msg"` + Detail CozeResponseDetail `json:"detail"` +} + +type CozeChatV3MessageDetail struct { + Id string `json:"id"` + Role string `json:"role"` + Type string `json:"type"` + BotId string `json:"bot_id"` + ChatId string `json:"chat_id"` + Content json.RawMessage `json:"content"` + MetaData json.RawMessage `json:"meta_data"` + CreatedAt int64 `json:"created_at"` + SectionId string `json:"section_id"` + UpdatedAt int64 `json:"updated_at"` + ContentType string `json:"content_type"` + ConversationId string `json:"conversation_id"` + ReasoningContent string `json:"reasoning_content"` +} + +type CozeResponseDetail struct { + Logid string `json:"logid"` +} diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 7c16763e..fe630ef6 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -2,12 +2,14 @@ package coze import ( "encoding/json" + "errors" "fmt" "io" "net/http" "one-api/dto" "one-api/relay/common" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" "github.com/gin-gonic/gin" @@ -47,14 +49,47 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } // convert coze response to openai response var response dto.TextResponse - var cozeResponse CozeChatResponse + var cozeResponse CozeChatDetailResponse + response.Model = info.UpstreamModelName err = json.Unmarshal(responseBody, &cozeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - response.Model = info.UpstreamModelName - // TODO: 处理 cozeResponse - return nil, nil + if cozeResponse.Code != 0 { + return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil + } + // 从上下文获取 usage + var usage dto.Usage + usage.PromptTokens = c.GetInt("coze_input_count") + usage.CompletionTokens = c.GetInt("coze_output_count") + usage.TotalTokens = c.GetInt("coze_token_count") + response.Usage = usage + response.Id = helper.GetResponseID(c) + + var responseContent json.RawMessage + for _, data := range cozeResponse.Data { + if data.Type == "answer" { + responseContent = data.Content + response.Created = data.CreatedAt + } + } + // 添加 response.Choices + response.Choices = []dto.OpenAITextResponseChoice{ + { + Index: 0, + Message: dto.Message{Role: "assistant", Content: responseContent}, + FinishReason: "stop", + }, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + return nil, &usage } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { From 108b67be6cc269778c17e24d38b5bc1971d11919 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Tue, 13 May 2025 22:23:38 +0800 Subject: [PATCH 4/6] use channel bot id --- middleware/distributor.go | 2 ++ relay/channel/coze/relay-coze.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 34882381..fdda8dda 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -240,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeMokaAI: c.Set("api_version", channel.Other) + case common.ChannelTypeCoze: + c.Set("bot_id", channel.Other) } } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index fe630ef6..8e9b8e3e 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -30,7 +30,7 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C } cozeRequest := &CozeChatRequest{ // TODO: model to botid - BotId: "1", + BotId: c.GetString("bot_id"), UserId: c.GetString("id"), AdditionalMessages: messages, Stream: request.Stream, From 59aabb43119059bca2e26fd2059904294b6e0ce3 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Thu, 15 May 2025 20:00:59 +0800 Subject: [PATCH 5/6] add frontend display, more model --- relay/channel/coze/constants.go | 24 +++++++++++++++++++++++- relay/channel/coze/relay-coze.go | 9 ++++++--- web/src/constants/channel.constants.js | 9 +++++++-- web/src/pages/Channel/EditChannel.js | 16 ++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go index da28cb83..873ffe24 100644 --- a/relay/channel/coze/constants.go +++ b/relay/channel/coze/constants.go @@ -1,8 +1,30 @@ package coze var ModelList = []string{ - // TODO: 完整列表 + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "Baichuan4", + "abab6.5s-chat-pro", + "glm-4-0520", + "qwen-max", + "deepseek-r1", "deepseek-v3", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-7b", + "step-1v-8k", + "step-1.5v-mini", + "Doubao-pro-32k", + "Doubao-pro-256k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-vision-lite-32k", + "Doubao-vision-pro-32k", + "Doubao-1.5-pro-vision-32k", + "Doubao-1.5-lite-32k", + "Doubao-1.5-pro-32k", + "Doubao-1.5-thinking-pro", + "Doubao-1.5-pro-256k", } var ChannelName = "coze" diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 8e9b8e3e..1ebdb7c1 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -28,10 +28,13 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C }) } } + user := request.User + if user == "" { + user = helper.GetResponseID(c) + } cozeRequest := &CozeChatRequest{ - // TODO: model to botid BotId: c.GetString("bot_id"), - UserId: c.GetString("id"), + UserId: user, AdditionalMessages: messages, Stream: request.Stream, } @@ -172,6 +175,6 @@ func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error if err != nil { // 增加对 client.Do(req) 返回错误的检查 return nil, fmt.Errorf("client.Do failed: %w", err) } - _ = resp.Body.Close() + // _ = resp.Body.Close() return resp, nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index fa59bcce..054da535 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [ { value: 48, color: 'blue', - label: 'xAI' - } + label: 'xAI', + }, + { + value: 49, + color: 'blue', + label: 'Coze', + }, ]; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index fd96ffb6..f7fab057 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -838,6 +838,22 @@ const EditChannel = (props) => { /> )} + {inputs.type === 49 && ( + <> +
+ 智能体ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
{t('模型')}:
From e379ee8f66c1d3f85c89a26994b88227564ffa10 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 16 May 2025 10:27:07 +0800 Subject: [PATCH 6/6] coze stream --- relay/channel/coze/adaptor.go | 9 ++- relay/channel/coze/relay-coze.go | 124 ++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 3 deletions(-) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 34931cc6..80441a51 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt // DoRequest implements channel.Adaptor. func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + if info.IsStream { + return channel.DoApiRequest(a, c, info, requestBody) + } // 首先发送创建消息请求,成功后再发送获取消息请求 // 发送创建消息请求 resp, err := channel.DoApiRequest(a, c, info, requestBody) @@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody // DoResponse implements channel.Adaptor. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - err, usage = cozeChatHandler(c, resp, info) + if info.IsStream { + err, usage = cozeChatStreamHandler(c, resp, info) + } else { + err, usage = cozeChatHandler(c, resp, info) + } return } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 1ebdb7c1..6db40213 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -1,16 +1,18 @@ package coze import ( + "bufio" "encoding/json" "errors" "fmt" "io" "net/http" + "one-api/common" "one-api/dto" - "one-api/relay/common" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) @@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return nil, &usage } +func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + var responseText string + + var currentEvent string + var currentData string + var usage dto.Usage + + for scanner.Scan() { + line := scanner.Text() + + if line == "" { + if currentEvent != "" && currentData != "" { + // handle last event + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + currentEvent = "" + currentData = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(line[6:]) + continue + } + + if strings.HasPrefix(line, "data:") { + currentData = strings.TrimSpace(line[5:]) + continue + } + } + + // Last event + if currentEvent != "" && currentData != "" { + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + } + + if err := scanner.Err(); err != nil { + return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil + } + helper.Done(c) + + if usage.TotalTokens == 0 { + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + return nil, &usage +} + +func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { + switch event { + case "conversation.chat.completed": + // 将 data 解析为 CozeChatResponseData + var chatData CozeChatResponseData + err := json.Unmarshal([]byte(data), &chatData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + usage.PromptTokens = chatData.Usage.InputCount + usage.CompletionTokens = chatData.Usage.OutputCount + usage.TotalTokens = chatData.Usage.TokenCount + + finishReason := "stop" + stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason) + helper.ObjectData(c, stopResponse) + + case "conversation.message.delta": + // 将 data 解析为 CozeChatV3MessageDetail + var messageData CozeChatV3MessageDetail + err := json.Unmarshal([]byte(data), &messageData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + var content string + err = json.Unmarshal(messageData.Content, &content) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + *responseText += content + + openaiResponse := dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + } + + choice := dto.ChatCompletionsStreamResponseChoice{ + Index: 0, + } + choice.Delta.SetContentString(content) + openaiResponse.Choices = append(openaiResponse.Choices, choice) + + helper.ObjectData(c, openaiResponse) + + case "error": + var errorData CozeError + err := json.Unmarshal([]byte(data), &errorData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + } +} + func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) @@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht return resp, nil } -func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) { +func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { var client *http.Client var err error // 声明 err 变量 if proxyURL, ok := info.ChannelSetting["proxy"]; ok {