diff --git a/common/str.go b/common/str.go index d42fd837..da44a987 100644 --- a/common/str.go +++ b/common/str.go @@ -31,16 +31,30 @@ func MapToJsonStr(m map[string]interface{}) string { return string(bytes) } -func StrToMap(str string) map[string]interface{} { +func StrToMap(str string) (map[string]interface{}, error) { m := make(map[string]interface{}) - err := json.Unmarshal([]byte(str), &m) + err := UnmarshalJson([]byte(str), &m) if err != nil { - return nil + return nil, err } - return m + return m, nil } -func IsJsonStr(str string) bool { +func StrToJsonArray(str string) ([]interface{}, error) { + var js []interface{} + err := json.Unmarshal([]byte(str), &js) + if err != nil { + return nil, err + } + return js, nil +} + +func IsJsonArray(str string) bool { + var js []interface{} + return json.Unmarshal([]byte(str), &js) == nil +} + +func IsJsonObject(str string) bool { var js map[string]interface{} return json.Unmarshal([]byte(str), &js) == nil } diff --git a/constant/context_key.go b/constant/context_key.go index 71e02f01..d58f1205 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -17,11 +17,18 @@ const ( ContextKeyTokenModelLimit ContextKey = "token_model_limit" /* channel related keys */ - ContextKeyBaseUrl ContextKey = "base_url" - ContextKeyChannelType ContextKey = "channel_type" - ContextKeyChannelId ContextKey = "channel_id" - ContextKeyChannelSetting ContextKey = "channel_setting" - ContextKeyParamOverride ContextKey = "param_override" + ContextKeyChannelId ContextKey = "channel_id" + ContextKeyChannelName ContextKey = "channel_name" + ContextKeyChannelCreateTime ContextKey = "channel_create_name" + ContextKeyChannelBaseUrl ContextKey = "base_url" + ContextKeyChannelType ContextKey = "channel_type" + ContextKeyChannelSetting ContextKey = "channel_setting" + ContextKeyChannelParamOverride ContextKey = "param_override" + ContextKeyChannelOrganization ContextKey = "channel_organization" + ContextKeyChannelAutoBan ContextKey = "auto_ban" + ContextKeyChannelModelMapping ContextKey = "model_mapping" + ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping" + ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key" /* user related keys */ ContextKeyUserId ContextKey = "id" diff --git a/constant/multi_key_mode.go b/constant/multi_key_mode.go new file mode 100644 index 00000000..cd0cdbff --- /dev/null +++ b/constant/multi_key_mode.go @@ -0,0 +1,8 @@ +package constant + +type MultiKeyMode string + +const ( + MultiKeyModeRandom MultiKeyMode = "random" // 随机 + MultiKeyModePolling MultiKeyMode = "polling" // 轮询 +) diff --git a/controller/channel.go b/controller/channel.go index d65a53df..bf735779 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -378,8 +378,31 @@ func GetChannel(c *gin.Context) { } type AddChannelRequest struct { - Mode string `json:"mode"` - Channel *model.Channel `json:"channel"` + Mode string `json:"mode"` + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + Channel *model.Channel `json:"channel"` +} + +func getVertexArrayKeys(keys string) ([]string, error) { + if keys == "" { + return nil, nil + } + var keyArray []interface{} + err := common.UnmarshalJson([]byte(keys), &keyArray) + if err != nil { + return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err) + } + cleanKeys := make([]string, 0, len(keyArray)) + for _, key := range keyArray { + keyStr := fmt.Sprintf("%v", key) + if keyStr != "" { + cleanKeys = append(cleanKeys, strings.TrimSpace(keyStr)) + } + } + if len(cleanKeys) == 0 { + return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空") + } + return cleanKeys, nil } func AddChannel(c *gin.Context) { @@ -418,16 +441,20 @@ func AddChannel(c *gin.Context) { }) return } else { - if common.IsJsonStr(addChannelRequest.Channel.Other) { - // must have default - regionMap := common.StrToMap(addChannelRequest.Channel.Other) - if regionMap["default"] == nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } + regionMap, err := common.StrToMap(addChannelRequest.Channel.Other) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", + }) + return + } + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return } } } @@ -436,51 +463,41 @@ func AddChannel(c *gin.Context) { keys := make([]string, 0) switch addChannelRequest.Mode { case "multi_to_single": - addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true + addChannelRequest.Channel.ChannelInfo.IsMultiKey = true + addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { - if !common.IsJsonStr(addChannelRequest.Channel.Key) { + array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入", + "message": err.Error(), }) return } - toMap := common.StrToMap(addChannelRequest.Channel.Key) - if toMap != nil { - addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap) - } else { - addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0 - } + addChannelRequest.Channel.Key = strings.Join(array, "\n") } else { cleanKeys := make([]string, 0) for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") { if key == "" { continue } + key = strings.TrimSpace(key) cleanKeys = append(cleanKeys, key) } - addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys) addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") } keys = []string{addChannelRequest.Channel.Key} case "batch": if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { // multi json - toMap := common.StrToMap(addChannelRequest.Channel.Key) - if toMap == nil { + keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入", + "message": err.Error(), }) return } - keys = make([]string, 0, len(toMap)) - for k := range toMap { - if k == "" { - continue - } - keys = append(keys, k) - } } else { keys = strings.Split(addChannelRequest.Channel.Key, "\n") } @@ -694,16 +711,20 @@ func UpdateChannel(c *gin.Context) { }) return } else { - if common.IsJsonStr(channel.Other) { - // must have default - regionMap := common.StrToMap(channel.Other) - if regionMap["default"] == nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } + regionMap, err := common.StrToMap(channel.Other) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", + }) + return + } + if regionMap["default"] == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "部署地区必须包含default字段", + }) + return } } } diff --git a/controller/playground.go b/controller/playground.go index 33471455..98b93031 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -57,18 +57,24 @@ func Playground(c *gin.Context) { } c.Set("group", group) } - c.Set("token_name", "playground-"+group) - channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0) + + userId := c.GetInt("id") + //c.Set("token_name", "playground-"+group) + tempToken := &model.Token{ + UserId: userId, + Name: fmt.Sprintf("playground-%s", group), + Group: group, + } + _ = middleware.SetupContextForToken(c, tempToken) + _, err = getChannel(c, group, playgroundRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model) - openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) + openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError) return } - middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) + //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) // Write user context to ensure acceptUnsetRatio is available - userId := c.GetInt("id") userCache, err := model.GetUserCache(userId) if err != nil { openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError) diff --git a/controller/relay.go b/controller/relay.go index e375120b..7b113221 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -259,9 +259,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) + channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { - return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) + if group == "auto" { + return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())) + } + return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())) } middleware.SetupContextForSelectedChannel(c, channel, originalModel) return channel, nil @@ -388,9 +391,10 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i) + channel, err := getChannel(c, group, originalModel, i) if err != nil { common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) break } channelId = channel.Id @@ -398,7 +402,7 @@ func RelayTask(c *gin.Context) { useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) - middleware.SetupContextForSelectedChannel(c, channel, originalModel) + //middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) diff --git a/middleware/auth.go b/middleware/auth.go index ecf4844b..47d033a9 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "one-api/common" "one-api/model" @@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) { userCache.WriteContext(c) - c.Set("id", token.UserId) - c.Set("token_id", token.Id) - c.Set("token_key", token.Key) - c.Set("token_name", token.Name) - c.Set("token_unlimited_quota", token.UnlimitedQuota) - if !token.UnlimitedQuota { - c.Set("token_quota", token.RemainQuota) - } - if token.ModelLimitsEnabled { - c.Set("token_model_limit_enabled", true) - c.Set("token_model_limit", token.GetModelLimitsMap()) - } else { - c.Set("token_model_limit_enabled", false) - } - c.Set("allow_ips", token.GetIpLimitsMap()) - c.Set("token_group", token.Group) - if len(parts) > 1 { - if model.IsAdmin(token.UserId) { - c.Set("specific_channel_id", parts[1]) - } else { - abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") - return - } + err = SetupContextForToken(c, token, parts...) + if err != nil { + return } c.Next() } } + +func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error { + if token == nil { + return fmt.Errorf("token is nil") + } + c.Set("id", token.UserId) + c.Set("token_id", token.Id) + c.Set("token_key", token.Key) + c.Set("token_name", token.Name) + c.Set("token_unlimited_quota", token.UnlimitedQuota) + if !token.UnlimitedQuota { + c.Set("token_quota", token.RemainQuota) + } + if token.ModelLimitsEnabled { + c.Set("token_model_limit_enabled", true) + c.Set("token_model_limit", token.GetModelLimitsMap()) + } else { + c.Set("token_model_limit_enabled", false) + } + c.Set("allow_ips", token.GetIpLimitsMap()) + c.Set("token_group", token.Group) + if len(parts) > 1 { + if model.IsAdmin(token.UserId) { + c.Set("specific_channel_id", parts[1]) + } else { + abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + return fmt.Errorf("普通用户不支持指定渠道") + } + } + return nil +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 17916e7a..18959e61 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -21,6 +21,7 @@ import ( type ModelRequest struct { Model string `json:"model"` + Group string `json:"group,omitempty"` } func Distribute() func(c *gin.Context) { @@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("relay_mode", relayMode) } + if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { + // playground chat completions + err = common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return nil, false, errors.New("无效的请求, " + err.Error()) + } + common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group) + } return &modelRequest, shouldSelectChannel, nil } @@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel == nil { return } - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - c.Set("channel_type", channel.Type) - c.Set("channel_create_time", channel.CreatedTime) - c.Set("channel_setting", channel.GetSetting()) - c.Set("param_override", channel.GetParamOverride()) - if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { - c.Set("channel_organization", *channel.OpenAIOrganization) + common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id) + common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name) + common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) + common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime) + common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting()) + common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride()) + if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" { + common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization) + } + common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan()) + common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping()) + common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping()) + if channel.ChannelInfo.IsMultiKey { + common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) + } - c.Set("auto_ban", channel.GetAutoBan()) - c.Set("model_mapping", channel.GetModelMapping()) - c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) + // TODO: api_version统一 switch channel.Type { case constant.ChannelTypeAzure: diff --git a/model/channel.go b/model/channel.go index ed9a478a..0b8d9ba2 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,7 +3,10 @@ package model import ( "database/sql/driver" "encoding/json" + "fmt" + "math/rand" "one-api/common" + "one-api/constant" "strings" "sync" @@ -43,20 +46,93 @@ type Channel struct { } type ChannelInfo struct { - MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式 - MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status - MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量 + IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status + MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` } // Value implements driver.Valuer interface -func (c ChannelInfo) Value() (driver.Value, error) { - return json.Marshal(c) +func (c *ChannelInfo) Value() (driver.Value, error) { + return common.EncodeJson(c) } // Scan implements sql.Scanner interface func (c *ChannelInfo) Scan(value interface{}) error { bytesValue, _ := value.([]byte) - return json.Unmarshal(bytesValue, c) + return common.UnmarshalJson(bytesValue, c) +} + +func (channel *Channel) getKeys() []string { + if channel.Key == "" { + return []string{} + } + // use \n to split keys + keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n") + return keys +} + +func (channel *Channel) GetNextEnabledKey() (string, error) { + // If not in multi-key mode, return the original key string directly. + if !channel.ChannelInfo.IsMultiKey { + return channel.Key, nil + } + + // Obtain all keys (split by \n) + keys := channel.getKeys() + if len(keys) == 0 { + // No keys available, return error, should disable the channel + return "", fmt.Errorf("no valid keys in channel") + } + + statusList := channel.ChannelInfo.MultiKeyStatusList + // helper to get key status, default to enabled when missing + getStatus := func(idx int) int { + if statusList == nil { + return common.ChannelStatusEnabled + } + if status, ok := statusList[idx]; ok { + return status + } + return common.ChannelStatusEnabled + } + + // Collect indexes of enabled keys + enabledIdx := make([]int, 0, len(keys)) + for i := range keys { + if getStatus(i) == common.ChannelStatusEnabled { + enabledIdx = append(enabledIdx, i) + } + } + // If no specific status list or none enabled, fall back to first key + if len(enabledIdx) == 0 { + return keys[0], nil + } + + switch channel.ChannelInfo.MultiKeyMode { + case constant.MultiKeyModeRandom: + // Randomly pick one enabled key + return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil + case constant.MultiKeyModePolling: + // Start from the saved polling index and look for the next enabled key + start := channel.ChannelInfo.MultiKeyPollingIndex + if start < 0 || start >= len(keys) { + start = 0 + } + for i := 0; i < len(keys); i++ { + idx := (start + i) % len(keys) + if getStatus(idx) == common.ChannelStatusEnabled { + // update polling index for next call (point to the next position) + channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys) + return keys[idx], nil + } + } + // Fallback – should not happen, but return first enabled key + return keys[enabledIdx[0]], nil + default: + // Unknown mode, default to first enabled key (or original key string) + return keys[enabledIdx[0]], nil + } } func (channel *Channel) GetModels() []string { diff --git a/model/log.go b/model/log.go index b3fd1ad2..1550aa91 100644 --- a/model/log.go +++ b/model/log.go @@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) { for i := range logs { logs[i].ChannelName = "" var otherMap map[string]interface{} - otherMap = common.StrToMap(logs[i].Other) + otherMap, _ = common.StrToMap(logs[i].Other) if otherMap != nil { // delete admin delete(otherMap, "admin_info") diff --git a/model/user.go b/model/user.go index bd685e54..634d6754 100644 --- a/model/user.go +++ b/model/user.go @@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) { user.AccessToken = &token } -func (user *User) GetSetting() map[string]interface{} { +func (user *User) GetSetting() (map[string]interface{}, error) { if user.Setting == "" { - return nil + return map[string]interface{}{}, nil } - return common.StrToMap(user.Setting) + toMap, err := common.StrToMap(user.Setting) + if err != nil { + common.SysError("failed to convert setting to map: " + err.Error()) + return nil, fmt.Errorf("failed to convert setting to map") + } + return toMap, nil } func (user *User) SetSetting(setting map[string]interface{}) { @@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err return map[string]interface{}{}, err } - return common.StrToMap(setting), nil + toMap, err := common.StrToMap(setting) + if err != nil { + common.SysError("failed to convert setting to map: " + err.Error()) + return nil, fmt.Errorf("failed to convert setting to map") + } + return toMap, nil } func IncreaseUserQuota(id int, quota int, db bool) (err error) { diff --git a/model/user_cache.go b/model/user_cache.go index b4bc2f1e..44eaa842 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} { if user.Setting == "" { return nil } - return common.StrToMap(user.Setting) + toMap, err := common.StrToMap(user.Setting) + if err != nil { + common.SysError("failed to convert user setting to map: " + err.Error()) + return nil + } + return toMap } func (user *UserBase) SetSetting(setting map[string]interface{}) { diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go index d2596320..5ed87665 100644 --- a/relay/channel/vertex/relay-vertex.go +++ b/relay/channel/vertex/relay-vertex.go @@ -4,8 +4,11 @@ import "one-api/common" func GetModelRegion(other string, localModelName string) string { // if other is json string - if common.IsJsonStr(other) { - m := common.StrToMap(other) + if common.IsJsonObject(other) { + m, err := common.StrToMap(other) + if err != nil { + return other // return original if parsing fails + } if m[localModelName] != nil { return m[localModelName].(string) } else { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 37161c16..ce170df4 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -214,7 +214,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting) - paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride) tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) @@ -231,7 +231,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl), + BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl), RequestURLPath: c.Request.URL.String(), ChannelType: channelType, ChannelId: channelId, diff --git a/router/relay-router.go b/router/relay-router.go index b48c9dc7..5b293dbd 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } playgroundRouter := router.Group("/pg") - playgroundRouter.Use(middleware.UserAuth()) + playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) { playgroundRouter.POST("/chat/completions", controller.Playground) }