add coze request
This commit is contained in:
@@ -240,6 +240,7 @@ const (
|
|||||||
ChannelTypeBaiduV2 = 46
|
ChannelTypeBaiduV2 = 46
|
||||||
ChannelTypeXinference = 47
|
ChannelTypeXinference = 47
|
||||||
ChannelTypeXai = 48
|
ChannelTypeXai = 48
|
||||||
|
ChannelTypeCoze = 49
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://qianfan.baidubce.com", //46
|
"https://qianfan.baidubce.com", //46
|
||||||
"", //47
|
"", //47
|
||||||
"https://api.x.ai", //48
|
"https://api.x.ai", //48
|
||||||
|
"https://api.coze.cn", //49
|
||||||
}
|
}
|
||||||
|
|||||||
125
relay/channel/coze/adaptor.go
Normal file
125
relay/channel/coze/adaptor.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package coze
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
"one-api/relay/common"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertAudioRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertClaudeRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertEmbeddingRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertImageRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertOpenAIRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
return convertCozeChatRequest(*request), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertOpenAIResponsesRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertRerankRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
|
// 首先发送创建消息请求,成功后再发送获取消息请求
|
||||||
|
// 发送创建消息请求
|
||||||
|
resp, err := channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 解析 resp
|
||||||
|
var cozeResponse CozeChatResponse
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(respBody, &cozeResponse)
|
||||||
|
if cozeResponse.Code != 0 {
|
||||||
|
return nil, errors.New(cozeResponse.Msg)
|
||||||
|
}
|
||||||
|
c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
|
||||||
|
c.Set("coze_chat_id", cozeResponse.Data.Id)
|
||||||
|
// 轮询检查消息是否完成
|
||||||
|
for {
|
||||||
|
err, isComplete := checkIfChatComplete(a, c, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
if isComplete {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 1)
|
||||||
|
}
|
||||||
|
// 发送获取消息请求
|
||||||
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
err, usage = cozeChatHandler(c, resp, info)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetChannelName implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelList implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestURL implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
|
||||||
|
return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) Init(info *common.RelayInfo) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupRequestHeader implements channel.Adaptor.
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
|
||||||
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
8
relay/channel/coze/constants.go
Normal file
8
relay/channel/coze/constants.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package coze
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
// TODO: 完整列表
|
||||||
|
"deepseek-v3",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelName = "coze"
|
||||||
81
relay/channel/coze/dto.go
Normal file
81
relay/channel/coze/dto.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package coze
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
// type CozeResponse struct {
|
||||||
|
// Code int `json:"code"`
|
||||||
|
// Message string `json:"message"`
|
||||||
|
// Data CozeConversationData `json:"data"`
|
||||||
|
// Detail CozeConversationData `json:"detail"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type CozeConversationData struct {
|
||||||
|
// Id string `json:"id"`
|
||||||
|
// CreatedAt int64 `json:"created_at"`
|
||||||
|
// MetaData json.RawMessage `json:"meta_data"`
|
||||||
|
// LastSectionId string `json:"last_section_id"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type CozeResponseDetail struct {
|
||||||
|
// Logid string `json:"logid"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
type CozeError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// type CozeErrorWithStatusCode struct {
|
||||||
|
// Error CozeError `json:"error"`
|
||||||
|
// StatusCode int
|
||||||
|
// LocalError bool
|
||||||
|
// }
|
||||||
|
|
||||||
|
type CozeRequest struct {
|
||||||
|
BotId string `json:"bot_id,omitempty"`
|
||||||
|
MetaData json.RawMessage `json:"meta_data,omitempty"`
|
||||||
|
Messages []CozeEnterMessage `json:"messages,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CozeEnterMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Content json.RawMessage `json:"content,omitempty"`
|
||||||
|
MetaData json.RawMessage `json:"meta_data,omitempty"`
|
||||||
|
ContentType string `json:"content_type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CozeChatRequest struct {
|
||||||
|
BotId string `json:"bot_id"`
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
|
||||||
|
AutoSaveHistory bool `json:"auto_save_history,omitempty"`
|
||||||
|
MetaData json.RawMessage `json:"meta_data,omitempty"`
|
||||||
|
ExtraParams json.RawMessage `json:"extra_params,omitempty"`
|
||||||
|
ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
|
||||||
|
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CozeChatResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Data CozeChatResponseData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CozeChatResponseData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
ConversationId string `json:"conversation_id"`
|
||||||
|
BotId string `json:"bot_id"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
LastError CozeError `json:"last_error"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Usage CozeChatUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CozeChatUsage struct {
|
||||||
|
TokenCount int `json:"token_count"`
|
||||||
|
OutputCount int `json:"output_count"`
|
||||||
|
InputCount int `json:"input_count"`
|
||||||
|
}
|
||||||
121
relay/channel/coze/relay-coze.go
Normal file
121
relay/channel/coze/relay-coze.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package coze
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/common"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest {
|
||||||
|
var messages []CozeEnterMessage
|
||||||
|
// 将 request的messages的role为user的content转换为CozeMessage
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
if message.Role == "user" {
|
||||||
|
messages = append(messages, CozeEnterMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message.Content,
|
||||||
|
// TODO: support more content type
|
||||||
|
ContentType: "text",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cozeRequest := &CozeRequest{
|
||||||
|
// TODO: model to botid
|
||||||
|
BotId: "1",
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
return cozeRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
// convert coze response to openai response
|
||||||
|
var response dto.TextResponse
|
||||||
|
var cozeResponse CozeChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
response.Model = info.UpstreamModelName
|
||||||
|
// TODO: 处理 cozeResponse
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
||||||
|
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
||||||
|
|
||||||
|
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
||||||
|
// 将 conversationId和chatId作为参数发送get请求
|
||||||
|
req, err := http.NewRequest("GET", requestURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doRequest(req, info) // 调用 doRequest
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
|
||||||
|
return fmt.Errorf("resp is nil"), false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close() // 确保响应体被关闭
|
||||||
|
|
||||||
|
// 解析 resp 到 CozeChatResponse
|
||||||
|
var cozeResponse CozeChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read response body failed: %w", err), false
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unmarshal response body failed: %w", err), false
|
||||||
|
}
|
||||||
|
if cozeResponse.Data.Status == "completed" {
|
||||||
|
// 在上下文设置 usage
|
||||||
|
c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
|
||||||
|
c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
|
||||||
|
c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
|
||||||
|
return nil, true
|
||||||
|
} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
|
||||||
|
return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
|
||||||
|
} else {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||||
|
var client *http.Client
|
||||||
|
var err error // 声明 err 变量
|
||||||
|
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||||
|
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client = service.GetHttpClient()
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil { // 增加对 client.Do(req) 返回错误的检查
|
||||||
|
return nil, fmt.Errorf("client.Do failed: %w", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
@@ -33,6 +33,7 @@ const (
|
|||||||
APITypeOpenRouter
|
APITypeOpenRouter
|
||||||
APITypeXinference
|
APITypeXinference
|
||||||
APITypeXai
|
APITypeXai
|
||||||
|
APITypeCoze
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = APITypeXinference
|
apiType = APITypeXinference
|
||||||
case common.ChannelTypeXai:
|
case common.ChannelTypeXai:
|
||||||
apiType = APITypeXai
|
apiType = APITypeXai
|
||||||
|
case common.ChannelTypeCoze:
|
||||||
|
apiType = APITypeCoze
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return APITypeOpenAI, false
|
return APITypeOpenAI, false
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
"one-api/relay/channel/cloudflare"
|
"one-api/relay/channel/cloudflare"
|
||||||
"one-api/relay/channel/cohere"
|
"one-api/relay/channel/cohere"
|
||||||
|
"one-api/relay/channel/coze"
|
||||||
"one-api/relay/channel/deepseek"
|
"one-api/relay/channel/deepseek"
|
||||||
"one-api/relay/channel/dify"
|
"one-api/relay/channel/dify"
|
||||||
"one-api/relay/channel/gemini"
|
"one-api/relay/channel/gemini"
|
||||||
@@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &openai.Adaptor{}
|
return &openai.Adaptor{}
|
||||||
case constant.APITypeXai:
|
case constant.APITypeXai:
|
||||||
return &xai.Adaptor{}
|
return &xai.Adaptor{}
|
||||||
|
case constant.APITypeCoze:
|
||||||
|
return &coze.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user