diff --git a/common/constants.go b/common/constants.go
index 3c8d262a..f823cd3d 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -231,8 +231,9 @@ const (
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
- ChannelTypeMokaAI = 47
- ChannelTypeDummy // this one is only for count, do not add any channel after this
+ ChannelTypeMokaAI = 47
+ ChannelTypeVolcEngine = 48
+ ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -281,5 +282,6 @@ var ChannelBaseURLs = []string{
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
- "https://api.moka.ai", //43
+ "https://api.moka.ai", //43
+ "https://ark.cn-beijing.volces.com", //44
}
diff --git a/model/cache.go b/model/cache.go
index b6102200..bda1ed57 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -11,106 +11,6 @@ import (
"time"
)
-//func CacheGetUserGroup(id int) (group string, err error) {
-// if !common.RedisEnabled {
-// return GetUserGroup(id)
-// }
-// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
-// if err != nil {
-// group, err = GetUserGroup(id)
-// if err != nil {
-// return "", err
-// }
-// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
-// if err != nil {
-// common.SysError("Redis set user group error: " + err.Error())
-// }
-// }
-// return group, err
-//}
-//
-//func CacheGetUsername(id int) (username string, err error) {
-// if !common.RedisEnabled {
-// return GetUsernameById(id)
-// }
-// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
-// if err != nil {
-// username, err = GetUsernameById(id)
-// if err != nil {
-// return "", err
-// }
-// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
-// if err != nil {
-// common.SysError("Redis set user group error: " + err.Error())
-// }
-// }
-// return username, err
-//}
-//
-//func CacheGetUserQuota(id int) (quota int, err error) {
-// if !common.RedisEnabled {
-// return GetUserQuota(id)
-// }
-// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
-// if err != nil {
-// quota, err = GetUserQuota(id)
-// if err != nil {
-// return 0, err
-// }
-// return quota, nil
-// }
-// quota, err = strconv.Atoi(quotaString)
-// return quota, nil
-//}
-//
-//func CacheUpdateUserQuota(id int) error {
-// if !common.RedisEnabled {
-// return nil
-// }
-// quota, err := GetUserQuota(id)
-// if err != nil {
-// return err
-// }
-// return cacheSetUserQuota(id, quota)
-//}
-//
-//func cacheSetUserQuota(id int, quota int) error {
-// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
-// return err
-//}
-//
-//func CacheDecreaseUserQuota(id int, quota int) error {
-// if !common.RedisEnabled {
-// return nil
-// }
-// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
-// return err
-//}
-//
-//func CacheIsUserEnabled(userId int) (bool, error) {
-// if !common.RedisEnabled {
-// return IsUserEnabled(userId)
-// }
-// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
-// if err == nil {
-// return enabled == "1", nil
-// }
-//
-// userEnabled, err := IsUserEnabled(userId)
-// if err != nil {
-// return false, err
-// }
-// enabled = "0"
-// if userEnabled {
-// enabled = "1"
-// }
-// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
-// if err != nil {
-// common.SysError("Redis set user enabled error: " + err.Error())
-// }
-// return userEnabled, err
-//}
-
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
new file mode 100644
index 00000000..0be421f3
--- /dev/null
+++ b/relay/channel/volcengine/adaptor.go
@@ -0,0 +1,76 @@
+package volcengine
+
+import (
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+ if info.IsStream {
+ err, usage = openai.OaiStreamHandler(c, resp, info)
+ } else {
+ err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go
new file mode 100644
index 00000000..30cc902e
--- /dev/null
+++ b/relay/channel/volcengine/constants.go
@@ -0,0 +1,13 @@
+package volcengine
+
+var ModelList = []string{
+ "Doubao-pro-128k",
+ "Doubao-pro-32k",
+ "Doubao-pro-4k",
+ "Doubao-lite-128k",
+ "Doubao-lite-32k",
+ "Doubao-lite-4k",
+ "Doubao-embedding",
+}
+
+var ChannelName = "volcengine"
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
index 1a40a6ee..3ff9e233 100644
--- a/relay/constant/api_type.go
+++ b/relay/constant/api_type.go
@@ -28,6 +28,7 @@ const (
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
+ APITypeVolcEngine
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -80,6 +81,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeDeepSeek
case common.ChannelTypeMokaAI:
apiType = APITypeMokaAI
+ case common.ChannelTypeVolcEngine:
+ apiType = APITypeVolcEngine
}
if apiType == -1 {
return APITypeOpenAI, false
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 9304bd6d..60baa45b 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -23,6 +23,7 @@ import (
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
+ "one-api/relay/channel/volcengine"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
@@ -77,6 +78,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &deepseek.Adaptor{}
case constant.APITypeMokaAI:
return &mokaai.Adaptor{}
+ case constant.APITypeVolcEngine:
+ return &volcengine.Adaptor{}
}
return nil
}
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index d62c2f13..605103ae 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -53,11 +53,11 @@ const ChannelsTable = () => {
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i];
}
- type2label[0] = { value: 0, text: t('未知类型'), color: 'grey' };
+ type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' };
}
return (