🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels

This commit is contained in:
CaIon
2025-07-06 12:37:56 +08:00
parent b695e67154
commit f0f277dc2a
15 changed files with 294 additions and 114 deletions

View File

@@ -31,16 +31,30 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes) return string(bytes)
} }
func StrToMap(str string) map[string]interface{} { func StrToMap(str string) (map[string]interface{}, error) {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := UnmarshalJson([]byte(str), &m)
if err != nil { if err != nil {
return nil return nil, err
} }
return m return m, nil
} }
func IsJsonStr(str string) bool { func StrToJsonArray(str string) ([]interface{}, error) {
var js []interface{}
err := json.Unmarshal([]byte(str), &js)
if err != nil {
return nil, err
}
return js, nil
}
func IsJsonArray(str string) bool {
var js []interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func IsJsonObject(str string) bool {
var js map[string]interface{} var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil return json.Unmarshal([]byte(str), &js) == nil
} }

View File

@@ -17,11 +17,18 @@ const (
ContextKeyTokenModelLimit ContextKey = "token_model_limit" ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */ /* channel related keys */
ContextKeyBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelId ContextKey = "channel_id" ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
ContextKeyChannelBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelSetting ContextKey = "channel_setting" ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyParamOverride ContextKey = "param_override" ContextKeyChannelParamOverride ContextKey = "param_override"
ContextKeyChannelOrganization ContextKey = "channel_organization"
ContextKeyChannelAutoBan ContextKey = "auto_ban"
ContextKeyChannelModelMapping ContextKey = "model_mapping"
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
/* user related keys */ /* user related keys */
ContextKeyUserId ContextKey = "id" ContextKeyUserId ContextKey = "id"

View File

@@ -0,0 +1,8 @@
package constant
type MultiKeyMode string
const (
MultiKeyModeRandom MultiKeyMode = "random" // 随机
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
)

View File

@@ -379,9 +379,32 @@ func GetChannel(c *gin.Context) {
type AddChannelRequest struct { type AddChannelRequest struct {
Mode string `json:"mode"` Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
Channel *model.Channel `json:"channel"` Channel *model.Channel `json:"channel"`
} }
func getVertexArrayKeys(keys string) ([]string, error) {
if keys == "" {
return nil, nil
}
var keyArray []interface{}
err := common.UnmarshalJson([]byte(keys), &keyArray)
if err != nil {
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入: %w", err)
}
cleanKeys := make([]string, 0, len(keyArray))
for _, key := range keyArray {
keyStr := fmt.Sprintf("%v", key)
if keyStr != "" {
cleanKeys = append(cleanKeys, strings.TrimSpace(keyStr))
}
}
if len(cleanKeys) == 0 {
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
}
return cleanKeys, nil
}
func AddChannel(c *gin.Context) { func AddChannel(c *gin.Context) {
addChannelRequest := AddChannelRequest{} addChannelRequest := AddChannelRequest{}
err := c.ShouldBindJSON(&addChannelRequest) err := c.ShouldBindJSON(&addChannelRequest)
@@ -418,9 +441,14 @@ func AddChannel(c *gin.Context) {
}) })
return return
} else { } else {
if common.IsJsonStr(addChannelRequest.Channel.Other) { regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
// must have default if err != nil {
regionMap := common.StrToMap(addChannelRequest.Channel.Other) c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil { if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -430,57 +458,46 @@ func AddChannel(c *gin.Context) {
} }
} }
} }
}
addChannelRequest.Channel.CreatedTime = common.GetTimestamp() addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
keys := make([]string, 0) keys := make([]string, 0)
switch addChannelRequest.Mode { switch addChannelRequest.Mode {
case "multi_to_single": case "multi_to_single":
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
if !common.IsJsonStr(addChannelRequest.Channel.Key) { array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入", "message": err.Error(),
}) })
return return
} }
toMap := common.StrToMap(addChannelRequest.Channel.Key) addChannelRequest.Channel.Key = strings.Join(array, "\n")
if toMap != nil {
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap)
} else {
addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0
}
} else { } else {
cleanKeys := make([]string, 0) cleanKeys := make([]string, 0)
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") { for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
if key == "" { if key == "" {
continue continue
} }
key = strings.TrimSpace(key)
cleanKeys = append(cleanKeys, key) cleanKeys = append(cleanKeys, key)
} }
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n") addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
} }
keys = []string{addChannelRequest.Channel.Key} keys = []string{addChannelRequest.Channel.Key}
case "batch": case "batch":
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
// multi json // multi json
toMap := common.StrToMap(addChannelRequest.Channel.Key) keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
if toMap == nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入", "message": err.Error(),
}) })
return return
} }
keys = make([]string, 0, len(toMap))
for k := range toMap {
if k == "" {
continue
}
keys = append(keys, k)
}
} else { } else {
keys = strings.Split(addChannelRequest.Channel.Key, "\n") keys = strings.Split(addChannelRequest.Channel.Key, "\n")
} }
@@ -694,9 +711,14 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} else { } else {
if common.IsJsonStr(channel.Other) { regionMap, err := common.StrToMap(channel.Other)
// must have default if err != nil {
regionMap := common.StrToMap(channel.Other) c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil { if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -706,7 +728,6 @@ func UpdateChannel(c *gin.Context) {
} }
} }
} }
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -57,18 +57,24 @@ func Playground(c *gin.Context) {
} }
c.Set("group", group) c.Set("group", group)
} }
c.Set("token_name", "playground-"+group)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0) userId := c.GetInt("id")
//c.Set("token_name", "playground-"+group)
tempToken := &model.Token{
UserId: userId,
Name: fmt.Sprintf("playground-%s", group),
Group: group,
}
_ = middleware.SetupContextForToken(c, tempToken)
_, err = getChannel(c, group, playgroundRequest.Model, 0)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model) openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return return
} }
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available // Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId) userCache, err := model.GetUserCache(userId)
if err != nil { if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError) openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)

View File

@@ -259,9 +259,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt, AutoBan: &autoBanInt,
}, nil }, nil
} }
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil { if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) if group == "auto" {
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
}
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
} }
middleware.SetupContextForSelectedChannel(c, channel, originalModel) middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil return channel, nil
@@ -388,9 +391,10 @@ func RelayTask(c *gin.Context) {
retryTimes = 0 retryTimes = 0
} }
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break break
} }
channelId = channel.Id channelId = channel.Id
@@ -398,7 +402,7 @@ func RelayTask(c *gin.Context) {
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel) //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
@@ -233,6 +234,18 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c) userCache.WriteContext(c)
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
}
c.Next()
}
}
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
if token == nil {
return fmt.Errorf("token is nil")
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_key", token.Key) c.Set("token_key", token.Key)
@@ -254,9 +267,8 @@ func TokenAuth() func(c *gin.Context) {
c.Set("specific_channel_id", parts[1]) c.Set("specific_channel_id", parts[1])
} else { } else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return return fmt.Errorf("普通用户不支持指定渠道")
} }
} }
c.Next() return nil
}
} }

View File

@@ -21,6 +21,7 @@ import (
type ModelRequest struct { type ModelRequest struct {
Model string `json:"model"` Model string `json:"model"`
Group string `json:"group,omitempty"`
} }
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
@@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} }
c.Set("relay_mode", relayMode) c.Set("relay_mode", relayMode)
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
// playground chat completions
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
}
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
}
return &modelRequest, shouldSelectChannel, nil return &modelRequest, shouldSelectChannel, nil
} }
@@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil { if channel == nil {
return return
} }
c.Set("channel_id", channel.Id) common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
c.Set("channel_name", channel.Name) common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
c.Set("channel_type", channel.Type) common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
c.Set("channel_create_time", channel.CreatedTime) common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting()) common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
c.Set("param_override", channel.GetParamOverride()) common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
c.Set("channel_organization", *channel.OpenAIOrganization) common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
}
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
} }
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
// TODO: api_version统一 // TODO: api_version统一
switch channel.Type { switch channel.Type {
case constant.ChannelTypeAzure: case constant.ChannelTypeAzure:

View File

@@ -3,7 +3,10 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt"
"math/rand"
"one-api/common" "one-api/common"
"one-api/constant"
"strings" "strings"
"sync" "sync"
@@ -43,20 +46,93 @@ type Channel struct {
} }
type ChannelInfo struct { type ChannelInfo struct {
MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式 IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量 MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
} }
// Value implements driver.Valuer interface // Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) { func (c *ChannelInfo) Value() (driver.Value, error) {
return json.Marshal(c) return common.EncodeJson(c)
} }
// Scan implements sql.Scanner interface // Scan implements sql.Scanner interface
func (c *ChannelInfo) Scan(value interface{}) error { func (c *ChannelInfo) Scan(value interface{}) error {
bytesValue, _ := value.([]byte) bytesValue, _ := value.([]byte)
return json.Unmarshal(bytesValue, c) return common.UnmarshalJson(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
if channel.Key == "" {
return []string{}
}
// use \n to split keys
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
return keys
}
func (channel *Channel) GetNextEnabledKey() (string, error) {
// If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey {
return channel.Key, nil
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", fmt.Errorf("no valid keys in channel")
}
statusList := channel.ChannelInfo.MultiKeyStatusList
// helper to get key status, default to enabled when missing
getStatus := func(idx int) int {
if statusList == nil {
return common.ChannelStatusEnabled
}
if status, ok := statusList[idx]; ok {
return status
}
return common.ChannelStatusEnabled
}
// Collect indexes of enabled keys
enabledIdx := make([]int, 0, len(keys))
for i := range keys {
if getStatus(i) == common.ChannelStatusEnabled {
enabledIdx = append(enabledIdx, i)
}
}
// If no specific status list or none enabled, fall back to first key
if len(enabledIdx) == 0 {
return keys[0], nil
}
switch channel.ChannelInfo.MultiKeyMode {
case constant.MultiKeyModeRandom:
// Randomly pick one enabled key
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
case constant.MultiKeyModePolling:
// Start from the saved polling index and look for the next enabled key
start := channel.ChannelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
}
for i := 0; i < len(keys); i++ {
idx := (start + i) % len(keys)
if getStatus(idx) == common.ChannelStatusEnabled {
// update polling index for next call (point to the next position)
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
return keys[idx], nil
}
}
// Fallback should not happen, but return first enabled key
return keys[enabledIdx[0]], nil
default:
// Unknown mode, default to first enabled key (or original key string)
return keys[enabledIdx[0]], nil
}
} }
func (channel *Channel) GetModels() []string { func (channel *Channel) GetModels() []string {

View File

@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
for i := range logs { for i := range logs {
logs[i].ChannelName = "" logs[i].ChannelName = ""
var otherMap map[string]interface{} var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other) otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil { if otherMap != nil {
// delete admin // delete admin
delete(otherMap, "admin_info") delete(otherMap, "admin_info")

View File

@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token user.AccessToken = &token
} }
func (user *User) GetSetting() map[string]interface{} { func (user *User) GetSetting() (map[string]interface{}, error) {
if user.Setting == "" { if user.Setting == "" {
return nil return map[string]interface{}{}, nil
} }
return common.StrToMap(user.Setting) toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
} }
func (user *User) SetSetting(setting map[string]interface{}) { func (user *User) SetSetting(setting map[string]interface{}) {
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return map[string]interface{}{}, err return map[string]interface{}{}, err
} }
return common.StrToMap(setting), nil toMap, err := common.StrToMap(setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
} }
func IncreaseUserQuota(id int, quota int, db bool) (err error) { func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" { if user.Setting == "" {
return nil return nil
} }
return common.StrToMap(user.Setting) toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert user setting to map: " + err.Error())
return nil
}
return toMap
} }
func (user *UserBase) SetSetting(setting map[string]interface{}) { func (user *UserBase) SetSetting(setting map[string]interface{}) {

View File

@@ -4,8 +4,11 @@ import "one-api/common"
func GetModelRegion(other string, localModelName string) string { func GetModelRegion(other string, localModelName string) string {
// if other is json string // if other is json string
if common.IsJsonStr(other) { if common.IsJsonObject(other) {
m := common.StrToMap(other) m, err := common.StrToMap(other)
if err != nil {
return other // return original if parsing fails
}
if m[localModelName] != nil { if m[localModelName] != nil {
return m[localModelName].(string) return m[localModelName].(string)
} else { } else {

View File

@@ -214,7 +214,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting) channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride) paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
@@ -231,7 +231,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true, isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl), BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),
ChannelType: channelType, ChannelType: channelType,
ChannelId: channelId, ChannelId: channelId,

View File

@@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel) modelsRouter.GET("/:model", controller.RetrieveModel)
} }
playgroundRouter := router.Group("/pg") playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth()) playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
{ {
playgroundRouter.POST("/chat/completions", controller.Playground) playgroundRouter.POST("/chat/completions", controller.Playground)
} }