🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels
This commit is contained in:
@@ -31,16 +31,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
|||||||
return string(bytes)
|
return string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StrToMap(str string) map[string]interface{} {
|
func StrToMap(str string) (map[string]interface{}, error) {
|
||||||
m := make(map[string]interface{})
|
m := make(map[string]interface{})
|
||||||
err := json.Unmarshal([]byte(str), &m)
|
err := UnmarshalJson([]byte(str), &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return m
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsJsonStr(str string) bool {
|
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||||
|
var js []interface{}
|
||||||
|
err := json.Unmarshal([]byte(str), &js)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return js, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonArray(str string) bool {
|
||||||
|
var js []interface{}
|
||||||
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonObject(str string) bool {
|
||||||
var js map[string]interface{}
|
var js map[string]interface{}
|
||||||
return json.Unmarshal([]byte(str), &js) == nil
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,11 +17,18 @@ const (
|
|||||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|
||||||
/* channel related keys */
|
/* channel related keys */
|
||||||
ContextKeyBaseUrl ContextKey = "base_url"
|
|
||||||
ContextKeyChannelType ContextKey = "channel_type"
|
|
||||||
ContextKeyChannelId ContextKey = "channel_id"
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
|
ContextKeyChannelName ContextKey = "channel_name"
|
||||||
|
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
|
||||||
|
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||||
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
ContextKeyParamOverride ContextKey = "param_override"
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
|
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||||
|
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||||
|
|
||||||
/* user related keys */
|
/* user related keys */
|
||||||
ContextKeyUserId ContextKey = "id"
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
|||||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type MultiKeyMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||||
|
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||||
|
)
|
||||||
@@ -379,9 +379,32 @@ func GetChannel(c *gin.Context) {
|
|||||||
|
|
||||||
type AddChannelRequest struct {
|
type AddChannelRequest struct {
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
Channel *model.Channel `json:"channel"`
|
Channel *model.Channel `json:"channel"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||||
|
if keys == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var keyArray []interface{}
|
||||||
|
err := common.UnmarshalJson([]byte(keys), &keyArray)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||||
|
}
|
||||||
|
cleanKeys := make([]string, 0, len(keyArray))
|
||||||
|
for _, key := range keyArray {
|
||||||
|
keyStr := fmt.Sprintf("%v", key)
|
||||||
|
if keyStr != "" {
|
||||||
|
cleanKeys = append(cleanKeys, strings.TrimSpace(keyStr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cleanKeys) == 0 {
|
||||||
|
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||||
|
}
|
||||||
|
return cleanKeys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func AddChannel(c *gin.Context) {
|
func AddChannel(c *gin.Context) {
|
||||||
addChannelRequest := AddChannelRequest{}
|
addChannelRequest := AddChannelRequest{}
|
||||||
err := c.ShouldBindJSON(&addChannelRequest)
|
err := c.ShouldBindJSON(&addChannelRequest)
|
||||||
@@ -418,9 +441,14 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if common.IsJsonStr(addChannelRequest.Channel.Other) {
|
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
|
||||||
// must have default
|
if err != nil {
|
||||||
regionMap := common.StrToMap(addChannelRequest.Channel.Other)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
if regionMap["default"] == nil {
|
if regionMap["default"] == nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -430,57 +458,46 @@ func AddChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||||
keys := make([]string, 0)
|
keys := make([]string, 0)
|
||||||
switch addChannelRequest.Mode {
|
switch addChannelRequest.Mode {
|
||||||
case "multi_to_single":
|
case "multi_to_single":
|
||||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true
|
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||||
|
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
if !common.IsJsonStr(addChannelRequest.Channel.Key) {
|
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
toMap := common.StrToMap(addChannelRequest.Channel.Key)
|
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||||
if toMap != nil {
|
|
||||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap)
|
|
||||||
} else {
|
|
||||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
cleanKeys := make([]string, 0)
|
cleanKeys := make([]string, 0)
|
||||||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
cleanKeys = append(cleanKeys, key)
|
cleanKeys = append(cleanKeys, key)
|
||||||
}
|
}
|
||||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
|
||||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||||
}
|
}
|
||||||
keys = []string{addChannelRequest.Channel.Key}
|
keys = []string{addChannelRequest.Channel.Key}
|
||||||
case "batch":
|
case "batch":
|
||||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
// multi json
|
// multi json
|
||||||
toMap := common.StrToMap(addChannelRequest.Channel.Key)
|
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
if toMap == nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入",
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keys = make([]string, 0, len(toMap))
|
|
||||||
for k := range toMap {
|
|
||||||
if k == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||||
}
|
}
|
||||||
@@ -694,9 +711,14 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if common.IsJsonStr(channel.Other) {
|
regionMap, err := common.StrToMap(channel.Other)
|
||||||
// must have default
|
if err != nil {
|
||||||
regionMap := common.StrToMap(channel.Other)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
if regionMap["default"] == nil {
|
if regionMap["default"] == nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -706,7 +728,6 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
err = channel.Update()
|
err = channel.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -57,18 +57,24 @@ func Playground(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
}
|
}
|
||||||
c.Set("token_name", "playground-"+group)
|
|
||||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
userId := c.GetInt("id")
|
||||||
|
//c.Set("token_name", "playground-"+group)
|
||||||
|
tempToken := &model.Token{
|
||||||
|
UserId: userId,
|
||||||
|
Name: fmt.Sprintf("playground-%s", group),
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
|
_, err = getChannel(c, group, playgroundRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError)
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
// Write user context to ensure acceptUnsetRatio is available
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
userId := c.GetInt("id")
|
|
||||||
userCache, err := model.GetUserCache(userId)
|
userCache, err := model.GetUserCache(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||||
|
|||||||
@@ -259,9 +259,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
AutoBan: &autoBanInt,
|
AutoBan: &autoBanInt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
if group == "auto" {
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
|
||||||
|
}
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
return channel, nil
|
return channel, nil
|
||||||
@@ -388,9 +391,10 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
@@ -398,7 +402,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, err := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@@ -233,6 +234,18 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
|
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
|
err = SetupContextForToken(c, token, parts...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||||
|
if token == nil {
|
||||||
|
return fmt.Errorf("token is nil")
|
||||||
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
c.Set("token_id", token.Id)
|
c.Set("token_id", token.Id)
|
||||||
c.Set("token_key", token.Key)
|
c.Set("token_key", token.Key)
|
||||||
@@ -254,9 +267,8 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Set("specific_channel_id", parts[1])
|
c.Set("specific_channel_id", parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return fmt.Errorf("普通用户不支持指定渠道")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Next()
|
return nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
Group string `json:"group,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
@@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||||
|
// playground chat completions
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||||
|
}
|
||||||
return &modelRequest, shouldSelectChannel, nil
|
return &modelRequest, shouldSelectChannel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("channel_id", channel.Id)
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||||
c.Set("channel_type", channel.Type)
|
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||||
c.Set("channel_create_time", channel.CreatedTime)
|
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||||
c.Set("channel_setting", channel.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||||
c.Set("param_override", channel.GetParamOverride())
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
|
|
||||||
}
|
}
|
||||||
c.Set("auto_ban", channel.GetAutoBan())
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||||
|
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ package model
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -43,20 +46,93 @@ type Channel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChannelInfo struct {
|
type ChannelInfo struct {
|
||||||
MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式
|
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量
|
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||||
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements driver.Valuer interface
|
// Value implements driver.Valuer interface
|
||||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
func (c *ChannelInfo) Value() (driver.Value, error) {
|
||||||
return json.Marshal(c)
|
return common.EncodeJson(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements sql.Scanner interface
|
// Scan implements sql.Scanner interface
|
||||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||||
bytesValue, _ := value.([]byte)
|
bytesValue, _ := value.([]byte)
|
||||||
return json.Unmarshal(bytesValue, c)
|
return common.UnmarshalJson(bytesValue, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) getKeys() []string {
|
||||||
|
if channel.Key == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
// use \n to split keys
|
||||||
|
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetNextEnabledKey() (string, error) {
|
||||||
|
// If not in multi-key mode, return the original key string directly.
|
||||||
|
if !channel.ChannelInfo.IsMultiKey {
|
||||||
|
return channel.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain all keys (split by \n)
|
||||||
|
keys := channel.getKeys()
|
||||||
|
if len(keys) == 0 {
|
||||||
|
// No keys available, return error, should disable the channel
|
||||||
|
return "", fmt.Errorf("no valid keys in channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||||
|
// helper to get key status, default to enabled when missing
|
||||||
|
getStatus := func(idx int) int {
|
||||||
|
if statusList == nil {
|
||||||
|
return common.ChannelStatusEnabled
|
||||||
|
}
|
||||||
|
if status, ok := statusList[idx]; ok {
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
return common.ChannelStatusEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect indexes of enabled keys
|
||||||
|
enabledIdx := make([]int, 0, len(keys))
|
||||||
|
for i := range keys {
|
||||||
|
if getStatus(i) == common.ChannelStatusEnabled {
|
||||||
|
enabledIdx = append(enabledIdx, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If no specific status list or none enabled, fall back to first key
|
||||||
|
if len(enabledIdx) == 0 {
|
||||||
|
return keys[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch channel.ChannelInfo.MultiKeyMode {
|
||||||
|
case constant.MultiKeyModeRandom:
|
||||||
|
// Randomly pick one enabled key
|
||||||
|
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
|
||||||
|
case constant.MultiKeyModePolling:
|
||||||
|
// Start from the saved polling index and look for the next enabled key
|
||||||
|
start := channel.ChannelInfo.MultiKeyPollingIndex
|
||||||
|
if start < 0 || start >= len(keys) {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
for i := 0; i < len(keys); i++ {
|
||||||
|
idx := (start + i) % len(keys)
|
||||||
|
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||||
|
// update polling index for next call (point to the next position)
|
||||||
|
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||||
|
return keys[idx], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback – should not happen, but return first enabled key
|
||||||
|
return keys[enabledIdx[0]], nil
|
||||||
|
default:
|
||||||
|
// Unknown mode, default to first enabled key (or original key string)
|
||||||
|
return keys[enabledIdx[0]], nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModels() []string {
|
func (channel *Channel) GetModels() []string {
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
|
|||||||
for i := range logs {
|
for i := range logs {
|
||||||
logs[i].ChannelName = ""
|
logs[i].ChannelName = ""
|
||||||
var otherMap map[string]interface{}
|
var otherMap map[string]interface{}
|
||||||
otherMap = common.StrToMap(logs[i].Other)
|
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||||
if otherMap != nil {
|
if otherMap != nil {
|
||||||
// delete admin
|
// delete admin
|
||||||
delete(otherMap, "admin_info")
|
delete(otherMap, "admin_info")
|
||||||
|
|||||||
@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
|
|||||||
user.AccessToken = &token
|
user.AccessToken = &token
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) GetSetting() map[string]interface{} {
|
func (user *User) GetSetting() (map[string]interface{}, error) {
|
||||||
if user.Setting == "" {
|
if user.Setting == "" {
|
||||||
return nil
|
return map[string]interface{}{}, nil
|
||||||
}
|
}
|
||||||
return common.StrToMap(user.Setting)
|
toMap, err := common.StrToMap(user.Setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to convert setting to map: " + err.Error())
|
||||||
|
return nil, fmt.Errorf("failed to convert setting to map")
|
||||||
|
}
|
||||||
|
return toMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) SetSetting(setting map[string]interface{}) {
|
func (user *User) SetSetting(setting map[string]interface{}) {
|
||||||
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
|||||||
return map[string]interface{}{}, err
|
return map[string]interface{}{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return common.StrToMap(setting), nil
|
toMap, err := common.StrToMap(setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to convert setting to map: " + err.Error())
|
||||||
|
return nil, fmt.Errorf("failed to convert setting to map")
|
||||||
|
}
|
||||||
|
return toMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||||
|
|||||||
@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
|
|||||||
if user.Setting == "" {
|
if user.Setting == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return common.StrToMap(user.Setting)
|
toMap, err := common.StrToMap(user.Setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to convert user setting to map: " + err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return toMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import "one-api/common"
|
|||||||
|
|
||||||
func GetModelRegion(other string, localModelName string) string {
|
func GetModelRegion(other string, localModelName string) string {
|
||||||
// if other is json string
|
// if other is json string
|
||||||
if common.IsJsonStr(other) {
|
if common.IsJsonObject(other) {
|
||||||
m := common.StrToMap(other)
|
m, err := common.StrToMap(other)
|
||||||
|
if err != nil {
|
||||||
|
return other // return original if parsing fails
|
||||||
|
}
|
||||||
if m[localModelName] != nil {
|
if m[localModelName] != nil {
|
||||||
return m[localModelName].(string)
|
return m[localModelName].(string)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||||
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
||||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||||
|
|
||||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||||
@@ -231,7 +231,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||||
isFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||||
}
|
}
|
||||||
playgroundRouter := router.Group("/pg")
|
playgroundRouter := router.Group("/pg")
|
||||||
playgroundRouter.Use(middleware.UserAuth())
|
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
playgroundRouter.POST("/chat/completions", controller.Playground)
|
playgroundRouter.POST("/chat/completions", controller.Playground)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user