From 28c13e5a0fd8c1aa0d2b426918da3cf3a9283a75 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Tue, 11 Feb 2025 23:47:15 +0800 Subject: [PATCH] feat: Add support for VolcEngine (Doubao) channel #313 #734 --- common/constants.go | 8 +- model/cache.go | 100 -------- relay/channel/volcengine/adaptor.go | 76 ++++++ relay/channel/volcengine/constants.go | 13 + relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + web/src/components/ChannelsTable.js | 4 +- web/src/constants/channel.constants.js | 53 ++-- web/src/i18n/locales/en.json | 3 +- web/src/pages/Channel/EditChannel.js | 319 +++++++++++++------------ 10 files changed, 285 insertions(+), 297 deletions(-) create mode 100644 relay/channel/volcengine/adaptor.go create mode 100644 relay/channel/volcengine/constants.go diff --git a/common/constants.go b/common/constants.go index 3c8d262a..f823cd3d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -231,8 +231,9 @@ const ( ChannelTypeVertexAi = 41 ChannelTypeMistral = 42 ChannelTypeDeepSeek = 43 - ChannelTypeMokaAI = 47 - ChannelTypeDummy // this one is only for count, do not add any channel after this + ChannelTypeMokaAI = 47 + ChannelTypeVolcEngine = 48 + ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -281,5 +282,6 @@ var ChannelBaseURLs = []string{ "", //41 "https://api.mistral.ai", //42 "https://api.deepseek.com", //43 - "https://api.moka.ai", //43 + "https://api.moka.ai", //43 + "https://ark.cn-beijing.volces.com", //44 } diff --git a/model/cache.go b/model/cache.go index b6102200..bda1ed57 100644 --- a/model/cache.go +++ b/model/cache.go @@ -11,106 +11,6 @@ import ( "time" ) -//func CacheGetUserGroup(id int) (group string, err error) { -// if !common.RedisEnabled { -// return GetUserGroup(id) -// } -// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) -// if err != nil { -// group, err = GetUserGroup(id) -// if err != nil { -// return "", err -// } -// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second) -// if err != nil { -// common.SysError("Redis set user group error: " + err.Error()) -// } -// } -// return group, err -//} -// -//func CacheGetUsername(id int) (username string, err error) { -// if !common.RedisEnabled { -// return GetUsernameById(id) -// } -// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id)) -// if err != nil { -// username, err = GetUsernameById(id) -// if err != nil { -// return "", err -// } -// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second) -// if err != nil { -// common.SysError("Redis set user group error: " + err.Error()) -// } -// } -// return username, err -//} -// -//func CacheGetUserQuota(id int) (quota int, err error) { -// if !common.RedisEnabled { -// return GetUserQuota(id) -// } -// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) -// if err != nil { -// quota, err = GetUserQuota(id) -// if err != nil { -// return 0, err -// } -// return quota, nil -// } -// quota, err = strconv.Atoi(quotaString) -// return quota, nil -//} -// -//func CacheUpdateUserQuota(id int) error { -// if !common.RedisEnabled { -// return nil -// } -// quota, err := GetUserQuota(id) -// if err != nil { -// return err -// } -// return cacheSetUserQuota(id, quota) -//} -// -//func cacheSetUserQuota(id int, quota int) error { -// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second) -// return err -//} -// -//func CacheDecreaseUserQuota(id int, quota int) error { -// if !common.RedisEnabled { -// return nil -// } -// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) -// return err -//} -// -//func CacheIsUserEnabled(userId int) (bool, error) { -// if !common.RedisEnabled { -// return IsUserEnabled(userId) -// } -// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) -// if err == nil { -// return enabled == "1", nil -// } -// -// userEnabled, err := IsUserEnabled(userId) -// if err != nil { -// return false, err -// } -// enabled = "0" -// if userEnabled { -// enabled = "1" -// } -// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second) -// if err != nil { -// common.SysError("Redis set user enabled error: " + err.Error()) -// } -// return userEnabled, err -//} - var group2model2channels map[string]map[string][]*Channel var channelsIDM map[int]*Channel var channelSyncLock sync.RWMutex diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go new file mode 100644 index 00000000..0be421f3 --- /dev/null +++ b/relay/channel/volcengine/adaptor.go @@ -0,0 +1,76 @@ +package volcengine + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/channel/openai" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, nil +} + +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go new file mode 100644 index 00000000..30cc902e --- /dev/null +++ b/relay/channel/volcengine/constants.go @@ -0,0 +1,13 @@ +package volcengine + +var ModelList = []string{ + "Doubao-pro-128k", + "Doubao-pro-32k", + "Doubao-pro-4k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-lite-4k", + "Doubao-embedding", +} + +var ChannelName = "volcengine" diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 1a40a6ee..3ff9e233 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -28,6 +28,7 @@ const ( APITypeMistral APITypeDeepSeek APITypeMokaAI + APITypeVolcEngine APITypeDummy // this one is only for count, do not add any channel after this ) @@ -80,6 +81,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeDeepSeek case common.ChannelTypeMokaAI: apiType = APITypeMokaAI + case common.ChannelTypeVolcEngine: + apiType = APITypeVolcEngine } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 9304bd6d..60baa45b 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -23,6 +23,7 @@ import ( "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" + "one-api/relay/channel/volcengine" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" @@ -77,6 +78,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &deepseek.Adaptor{} case constant.APITypeMokaAI: return &mokaai.Adaptor{} + case constant.APITypeVolcEngine: + return &volcengine.Adaptor{} } return nil } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index d62c2f13..605103ae 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -53,11 +53,11 @@ const ChannelsTable = () => { for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; } - type2label[0] = { value: 0, text: t('未知类型'), color: 'grey' }; + type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' }; } return ( - {type2label[type]?.text} + {type2label[type]?.label} ); }; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index c1be95b7..03628c7c 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,134 +1,123 @@ export const CHANNEL_OPTIONS = [ - { key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI' }, + { key: 1, value: 1, color: 'green', label: 'OpenAI' }, { key: 2, - text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy' }, { key: 5, - text: 'Midjourney Proxy Plus', value: 5, color: 'blue', label: 'Midjourney Proxy Plus' }, { key: 36, - text: 'Suno API', value: 36, color: 'purple', label: 'Suno API' }, - { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' }, + { key: 4, value: 4, color: 'grey', label: 'Ollama' }, { key: 14, - text: 'Anthropic Claude', value: 14, color: 'indigo', label: 'Anthropic Claude' }, { key: 33, - text: 'AWS Claude', value: 33, color: 'indigo', label: 'AWS Claude' }, - { key: 41, text: 'Vertex AI', value: 41, color: 'blue', label: 'Vertex AI' }, + { key: 41, value: 41, color: 'blue', label: 'Vertex AI' }, { key: 3, - text: 'Azure OpenAI', value: 3, color: 'teal', label: 'Azure OpenAI' }, { key: 34, - text: 'Cohere', value: 34, color: 'purple', label: 'Cohere' }, - { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' }, - { key: 43, text: 'DeepSeek', value: 43, color: 'blue', label: 'DeepSeek' }, + { key: 39, value: 39, color: 'grey', label: 'Cloudflare' }, + { key: 43, value: 43, color: 'blue', label: 'DeepSeek' }, { key: 15, - text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆' }, { key: 17, - text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问' }, { key: 18, - text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知' }, { key: 16, - text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM' }, { key: 26, - text: '智谱 GLM-4V', value: 26, color: 'purple', label: '智谱 GLM-4V' }, { key: 24, - text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini' }, { key: 11, - text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2' }, - { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, - { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, - { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, - { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' }, - { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' }, - { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' }, - { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' }, - { key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' }, - { key: 42, text: 'Mistral AI', value: 42, color: 'blue', label: 'Mistral AI' }, - { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' }, + { + key: 48, + value: 48, + color: 'blue', + label: '火山方舟(豆包)' + }, + { key: 25, value: 25, color: 'green', label: 'Moonshot' }, + { key: 19, value: 19, color: 'blue', label: '360 智脑' }, + { key: 23, value: 23, color: 'teal', label: '腾讯混元' }, + { key: 31, value: 31, color: 'green', label: '零一万物' }, + { key: 35, value: 35, color: 'green', label: 'MiniMax' }, + { key: 37, value: 37, color: 'teal', label: 'Dify' }, + { key: 38, value: 38, color: 'blue', label: 'Jina' }, + { key: 40, value: 40, color: 'purple', label: 'SiliconCloud' }, + { key: 42, value: 42, color: 'blue', label: 'Mistral AI' }, + { key: 8, value: 8, color: 'pink', label: '自定义渠道' }, { key: 22, - text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT' }, { key: 21, - text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy' }, { key: 47, - text: '嵌入模型:MokaAI M3E', value: 47, color: 'purple', label: '嵌入模型:MokaAI M3E' diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index ec2769c2..3943a19d 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -498,8 +498,7 @@ "请输入用户名": "Please enter username", "请输入显示名称": "Please enter display name", "请输入密码": "Please enter password", - "模型部署名称必须和模型名称保持一致": "The model deployment name must be consistent with the model name", - ",因为 One API 会把请求体中的 model": ", because One API will take the model in the request body", + "注意,模型部署名称必须和模型名称保持一致": "Note that the model deployment name must be consistent with the model name", "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "Homepage URL 填": "Fill in the Homepage URL", diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 3e99b7da..b80bf93b 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -438,13 +438,16 @@ const EditChannel = (props) => { value={inputs.type} onChange={(value) => handleInputChange('type', value)} style={{ width: '50%' }} + filter + searchPosition='dropdown' + placeholder={t('请选择渠道类型')} /> {inputs.type === 3 && ( <>
@@ -501,6 +504,19 @@ const EditChannel = (props) => { /> )} +
+ {t('名称')}: +
+ { + handleInputChange('name', value); + }} + value={inputs.name} + autoComplete="new-password" + /> {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && ( <>
@@ -518,6 +534,77 @@ const EditChannel = (props) => { /> )} +
+ {t('密钥')}: +
+ {batch ? ( +