🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels

This commit is contained in:
CaIon
2025-07-06 12:37:56 +08:00
parent b695e67154
commit f0f277dc2a
15 changed files with 294 additions and 114 deletions

View File

@@ -31,16 +31,30 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
func StrToMap(str string) (map[string]interface{}, error) {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
err := UnmarshalJson([]byte(str), &m)
if err != nil {
return nil
return nil, err
}
return m
return m, nil
}
func IsJsonStr(str string) bool {
func StrToJsonArray(str string) ([]interface{}, error) {
var js []interface{}
err := json.Unmarshal([]byte(str), &js)
if err != nil {
return nil, err
}
return js, nil
}
func IsJsonArray(str string) bool {
var js []interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func IsJsonObject(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}

View File

@@ -17,11 +17,18 @@ const (
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */
ContextKeyBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
ContextKeyChannelBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyParamOverride ContextKey = "param_override"
ContextKeyChannelParamOverride ContextKey = "param_override"
ContextKeyChannelOrganization ContextKey = "channel_organization"
ContextKeyChannelAutoBan ContextKey = "auto_ban"
ContextKeyChannelModelMapping ContextKey = "model_mapping"
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
/* user related keys */
ContextKeyUserId ContextKey = "id"

View File

@@ -0,0 +1,8 @@
package constant
type MultiKeyMode string
const (
MultiKeyModeRandom MultiKeyMode = "random" // 随机
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
)

View File

@@ -379,9 +379,32 @@ func GetChannel(c *gin.Context) {
type AddChannelRequest struct {
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
Channel *model.Channel `json:"channel"`
}
func getVertexArrayKeys(keys string) ([]string, error) {
if keys == "" {
return nil, nil
}
var keyArray []interface{}
err := common.UnmarshalJson([]byte(keys), &keyArray)
if err != nil {
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入: %w", err)
}
cleanKeys := make([]string, 0, len(keyArray))
for _, key := range keyArray {
keyStr := fmt.Sprintf("%v", key)
if keyStr != "" {
cleanKeys = append(cleanKeys, strings.TrimSpace(keyStr))
}
}
if len(cleanKeys) == 0 {
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
}
return cleanKeys, nil
}
func AddChannel(c *gin.Context) {
addChannelRequest := AddChannelRequest{}
err := c.ShouldBindJSON(&addChannelRequest)
@@ -418,9 +441,14 @@ func AddChannel(c *gin.Context) {
})
return
} else {
if common.IsJsonStr(addChannelRequest.Channel.Other) {
// must have default
regionMap := common.StrToMap(addChannelRequest.Channel.Other)
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -430,57 +458,46 @@ func AddChannel(c *gin.Context) {
}
}
}
}
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
keys := make([]string, 0)
switch addChannelRequest.Mode {
case "multi_to_single":
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
if !common.IsJsonStr(addChannelRequest.Channel.Key) {
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入",
"message": err.Error(),
})
return
}
toMap := common.StrToMap(addChannelRequest.Channel.Key)
if toMap != nil {
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(toMap)
} else {
addChannelRequest.Channel.ChannelInfo.MultiKeySize = 0
}
addChannelRequest.Channel.Key = strings.Join(array, "\n")
} else {
cleanKeys := make([]string, 0)
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
if key == "" {
continue
}
key = strings.TrimSpace(key)
cleanKeys = append(cleanKeys, key)
}
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
}
keys = []string{addChannelRequest.Channel.Key}
case "batch":
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
// multi json
toMap := common.StrToMap(addChannelRequest.Channel.Key)
if toMap == nil {
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入",
"message": err.Error(),
})
return
}
keys = make([]string, 0, len(toMap))
for k := range toMap {
if k == "" {
continue
}
keys = append(keys, k)
}
} else {
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
}
@@ -694,9 +711,14 @@ func UpdateChannel(c *gin.Context) {
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
regionMap, err := common.StrToMap(channel.Other)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
})
return
}
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -706,7 +728,6 @@ func UpdateChannel(c *gin.Context) {
}
}
}
}
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -57,18 +57,24 @@ func Playground(c *gin.Context) {
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
userId := c.GetInt("id")
//c.Set("token_name", "playground-"+group)
tempToken := &model.Token{
UserId: userId,
Name: fmt.Sprintf("playground-%s", group),
Group: group,
}
_ = middleware.SetupContextForToken(c, tempToken)
_, err = getChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)

View File

@@ -259,9 +259,12 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
if group == "auto" {
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
}
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
@@ -388,9 +391,10 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
@@ -398,7 +402,7 @@ func RelayTask(c *gin.Context) {
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

View File

@@ -1,6 +1,7 @@
package middleware
import (
"fmt"
"net/http"
"one-api/common"
"one-api/model"
@@ -233,6 +234,18 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
}
c.Next()
}
}
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
if token == nil {
return fmt.Errorf("token is nil")
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
@@ -254,9 +267,8 @@ func TokenAuth() func(c *gin.Context) {
c.Set("specific_channel_id", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
return fmt.Errorf("普通用户不支持指定渠道")
}
}
c.Next()
}
return nil
}

View File

@@ -21,6 +21,7 @@ import (
type ModelRequest struct {
Model string `json:"model"`
Group string `json:"group,omitempty"`
}
func Distribute() func(c *gin.Context) {
@@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("relay_mode", relayMode)
}
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
// playground chat completions
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的请求, " + err.Error())
}
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
}
return &modelRequest, shouldSelectChannel, nil
}
@@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil {
return
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
}
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
}
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case constant.ChannelTypeAzure:

View File

@@ -3,7 +3,10 @@ package model
import (
"database/sql/driver"
"encoding/json"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"strings"
"sync"
@@ -43,20 +46,93 @@ type Channel struct {
}
type ChannelInfo struct {
MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return json.Marshal(c)
func (c *ChannelInfo) Value() (driver.Value, error) {
return common.EncodeJson(c)
}
// Scan implements sql.Scanner interface
func (c *ChannelInfo) Scan(value interface{}) error {
bytesValue, _ := value.([]byte)
return json.Unmarshal(bytesValue, c)
return common.UnmarshalJson(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
if channel.Key == "" {
return []string{}
}
// use \n to split keys
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
return keys
}
func (channel *Channel) GetNextEnabledKey() (string, error) {
// If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey {
return channel.Key, nil
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", fmt.Errorf("no valid keys in channel")
}
statusList := channel.ChannelInfo.MultiKeyStatusList
// helper to get key status, default to enabled when missing
getStatus := func(idx int) int {
if statusList == nil {
return common.ChannelStatusEnabled
}
if status, ok := statusList[idx]; ok {
return status
}
return common.ChannelStatusEnabled
}
// Collect indexes of enabled keys
enabledIdx := make([]int, 0, len(keys))
for i := range keys {
if getStatus(i) == common.ChannelStatusEnabled {
enabledIdx = append(enabledIdx, i)
}
}
// If no specific status list or none enabled, fall back to first key
if len(enabledIdx) == 0 {
return keys[0], nil
}
switch channel.ChannelInfo.MultiKeyMode {
case constant.MultiKeyModeRandom:
// Randomly pick one enabled key
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
case constant.MultiKeyModePolling:
// Start from the saved polling index and look for the next enabled key
start := channel.ChannelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
}
for i := 0; i < len(keys); i++ {
idx := (start + i) % len(keys)
if getStatus(idx) == common.ChannelStatusEnabled {
// update polling index for next call (point to the next position)
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
return keys[idx], nil
}
}
// Fallback should not happen, but return first enabled key
return keys[enabledIdx[0]], nil
default:
// Unknown mode, default to first enabled key (or original key string)
return keys[enabledIdx[0]], nil
}
}
func (channel *Channel) GetModels() []string {

View File

@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")

View File

@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
func (user *User) GetSetting() (map[string]interface{}, error) {
if user.Setting == "" {
return nil
return map[string]interface{}{}, nil
}
return common.StrToMap(user.Setting)
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
}
func (user *User) SetSetting(setting map[string]interface{}) {
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
toMap, err := common.StrToMap(setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert user setting to map: " + err.Error())
return nil
}
return toMap
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {

View File

@@ -4,8 +4,11 @@ import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if common.IsJsonObject(other) {
m, err := common.StrToMap(other)
if err != nil {
return other // return original if parsing fails
}
if m[localModelName] != nil {
return m[localModelName].(string)
} else {

View File

@@ -214,7 +214,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
@@ -231,7 +231,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,

View File

@@ -20,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth())
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}