From eceb6afcdd8caee5bb37671852ea5fba087c5599 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Wed, 12 Feb 2025 00:07:02 +0800 Subject: [PATCH] feat: Add Baidu Qianfan V2 channel support #725 - Update channel constants to include Baidu V2 channel - Create new Baidu V2 adaptor for relay - Add Baidu V2 models and channel configuration - Update relay adaptor to support Baidu V2 channel - Modify web channel constants to include Baidu V2 option --- common/constants.go | 10 ++-- relay/channel/baidu_v2/adaptor.go | 76 ++++++++++++++++++++++++++ relay/channel/baidu_v2/constants.go | 29 ++++++++++ relay/channel/deepseek/adaptor.go | 3 +- relay/channel/volcengine/adaptor.go | 24 ++++++-- relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + web/src/constants/channel.constants.js | 57 ++++++++----------- 8 files changed, 158 insertions(+), 47 deletions(-) create mode 100644 relay/channel/baidu_v2/adaptor.go create mode 100644 relay/channel/baidu_v2/constants.go diff --git a/common/constants.go b/common/constants.go index f823cd3d..f967d066 100644 --- a/common/constants.go +++ b/common/constants.go @@ -231,8 +231,9 @@ const ( ChannelTypeVertexAi = 41 ChannelTypeMistral = 42 ChannelTypeDeepSeek = 43 - ChannelTypeMokaAI = 47 - ChannelTypeVolcEngine = 48 + ChannelTypeMokaAI = 44 + ChannelTypeVolcEngine = 45 + ChannelTypeBaiduV2 = 46 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -282,6 +283,7 @@ var ChannelBaseURLs = []string{ "", //41 "https://api.mistral.ai", //42 "https://api.deepseek.com", //43 - "https://api.moka.ai", //43 - "https://ark.cn-beijing.volces.com", //44 + "https://api.moka.ai", //44 + "https://ark.cn-beijing.volces.com", //45 + "https://qianfan.baidubce.com", //46 } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go new file mode 100644 index 00000000..fd25ecc1 --- /dev/null +++ b/relay/channel/baidu_v2/adaptor.go @@ -0,0 +1,76 @@ +package baidu_v2 + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/baidu_v2/constants.go b/relay/channel/baidu_v2/constants.go new file mode 100644 index 00000000..a7cee248 --- /dev/null +++ b/relay/channel/baidu_v2/constants.go @@ -0,0 +1,29 @@ +package baidu_v2 + +var ModelList = []string{ + "ernie-4.0-8k-latest", + "ernie-4.0-8k-preview", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k-latest", + "ernie-4.0-turbo-8k-preview", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-128k", + "ernie-3.5-8k-preview", + "ernie-3.5-8k", + "ernie-3.5-128k", + "ernie-speed-8k", + "ernie-speed-128k", + "ernie-speed-pro-128k", + "ernie-lite-8k", + "ernie-lite-pro-128k", + "ernie-tiny-8k", + "ernie-char-8k", + "ernie-char-fiction-8k", + "ernie-novel-8k", + "deepseek-v3", + "deepseek-r1", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-14b", +} + +var ChannelName = "volcengine" diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 1682dc3f..14dd74f0 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -29,7 +29,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -54,7 +54,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } - func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 0be421f3..8cffebba 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" ) type Adaptor struct { @@ -29,7 +30,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + default: + } + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -50,8 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { @@ -59,9 +66,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = openai.OaiStreamHandler(c, resp, info) - } else { + switch info.RelayMode { + case constant.RelayModeChatCompletions: + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + case constant.RelayModeEmbeddings: err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 3ff9e233..f7a87536 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -29,6 +29,7 @@ const ( APITypeDeepSeek APITypeMokaAI APITypeVolcEngine + APITypeBaiduV2 APITypeDummy // this one is only for count, do not add any channel after this ) @@ -83,6 +84,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeMokaAI case common.ChannelTypeVolcEngine: apiType = APITypeVolcEngine + case common.ChannelTypeBaiduV2: + apiType = APITypeBaiduV2 } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 60baa45b..c9111106 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -6,6 +6,7 @@ import ( "one-api/relay/channel/ali" "one-api/relay/channel/aws" "one-api/relay/channel/baidu" + "one-api/relay/channel/baidu_v2" "one-api/relay/channel/claude" "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" @@ -80,6 +81,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &mokaai.Adaptor{} case constant.APITypeVolcEngine: return &volcengine.Adaptor{} + case constant.APITypeBaiduV2: + return &baidu_v2.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 03628c7c..dec74b06 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,124 +1,111 @@ export const CHANNEL_OPTIONS = [ - { key: 1, value: 1, color: 'green', label: 'OpenAI' }, + { value: 1, color: 'green', label: 'OpenAI' }, { - key: 2, value: 2, color: 'light-blue', label: 'Midjourney Proxy' }, { - key: 5, value: 5, color: 'blue', label: 'Midjourney Proxy Plus' }, { - key: 36, value: 36, color: 'purple', label: 'Suno API' }, - { key: 4, value: 4, color: 'grey', label: 'Ollama' }, + { value: 4, color: 'grey', label: 'Ollama' }, { - key: 14, value: 14, color: 'indigo', label: 'Anthropic Claude' }, { - key: 33, value: 33, color: 'indigo', label: 'AWS Claude' }, - { key: 41, value: 41, color: 'blue', label: 'Vertex AI' }, + { value: 41, color: 'blue', label: 'Vertex AI' }, { - key: 3, value: 3, color: 'teal', label: 'Azure OpenAI' }, { - key: 34, value: 34, color: 'purple', label: 'Cohere' }, - { key: 39, value: 39, color: 'grey', label: 'Cloudflare' }, - { key: 43, value: 43, color: 'blue', label: 'DeepSeek' }, + { value: 39, color: 'grey', label: 'Cloudflare' }, + { value: 43, color: 'blue', label: 'DeepSeek' }, { - key: 15, value: 15, color: 'blue', label: '百度文心千帆' }, { - key: 17, + value: 46, + color: 'blue', + label: '百度文心千帆V2' + }, + { value: 17, color: 'orange', label: '阿里通义千问' }, { - key: 18, value: 18, color: 'blue', label: '讯飞星火认知' }, { - key: 16, value: 16, color: 'violet', label: '智谱 ChatGLM' }, { - key: 26, value: 26, color: 'purple', label: '智谱 GLM-4V' }, { - key: 24, value: 24, color: 'orange', label: 'Google Gemini' }, { - key: 11, value: 11, color: 'orange', label: 'Google PaLM2' }, { - key: 48, - value: 48, + value: 45, color: 'blue', label: '火山方舟(豆包)' }, - { key: 25, value: 25, color: 'green', label: 'Moonshot' }, - { key: 19, value: 19, color: 'blue', label: '360 智脑' }, - { key: 23, value: 23, color: 'teal', label: '腾讯混元' }, - { key: 31, value: 31, color: 'green', label: '零一万物' }, - { key: 35, value: 35, color: 'green', label: 'MiniMax' }, - { key: 37, value: 37, color: 'teal', label: 'Dify' }, - { key: 38, value: 38, color: 'blue', label: 'Jina' }, - { key: 40, value: 40, color: 'purple', label: 'SiliconCloud' }, - { key: 42, value: 42, color: 'blue', label: 'Mistral AI' }, - { key: 8, value: 8, color: 'pink', label: '自定义渠道' }, + { value: 25, color: 'green', label: 'Moonshot' }, + { value: 19, color: 'blue', label: '360 智脑' }, + { value: 23, color: 'teal', label: '腾讯混元' }, + { value: 31, color: 'green', label: '零一万物' }, + { value: 35, color: 'green', label: 'MiniMax' }, + { value: 37, color: 'teal', label: 'Dify' }, + { value: 38, color: 'blue', label: 'Jina' }, + { value: 40, color: 'purple', label: 'SiliconCloud' }, + { value: 42, color: 'blue', label: 'Mistral AI' }, + { value: 8, color: 'pink', label: '自定义渠道' }, { - key: 22, value: 22, color: 'blue', label: '知识库:FastGPT' }, { - key: 21, value: 21, color: 'purple', label: '知识库:AI Proxy' }, { - key: 47, - value: 47, + value: 44, color: 'purple', label: '嵌入模型:MokaAI M3E' }