🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels
This commit is contained in:
@@ -3,7 +3,10 @@ package model
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -43,20 +46,93 @@ type Channel struct {
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return json.Marshal(c)
|
||||
func (c *ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.EncodeJson(c)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
bytesValue, _ := value.([]byte)
|
||||
return json.Unmarshal(bytesValue, c)
|
||||
return common.UnmarshalJson(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
// use \n to split keys
|
||||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||
return keys
|
||||
}
|
||||
|
||||
func (channel *Channel) GetNextEnabledKey() (string, error) {
|
||||
// If not in multi-key mode, return the original key string directly.
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
return channel.Key, nil
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", fmt.Errorf("no valid keys in channel")
|
||||
}
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
if statusList == nil {
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
if status, ok := statusList[idx]; ok {
|
||||
return status
|
||||
}
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
|
||||
// Collect indexes of enabled keys
|
||||
enabledIdx := make([]int, 0, len(keys))
|
||||
for i := range keys {
|
||||
if getStatus(i) == common.ChannelStatusEnabled {
|
||||
enabledIdx = append(enabledIdx, i)
|
||||
}
|
||||
}
|
||||
// If no specific status list or none enabled, fall back to first key
|
||||
if len(enabledIdx) == 0 {
|
||||
return keys[0], nil
|
||||
}
|
||||
|
||||
switch channel.ChannelInfo.MultiKeyMode {
|
||||
case constant.MultiKeyModeRandom:
|
||||
// Randomly pick one enabled key
|
||||
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Start from the saved polling index and look for the next enabled key
|
||||
start := channel.ChannelInfo.MultiKeyPollingIndex
|
||||
if start < 0 || start >= len(keys) {
|
||||
start = 0
|
||||
}
|
||||
for i := 0; i < len(keys); i++ {
|
||||
idx := (start + i) % len(keys)
|
||||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||
// update polling index for next call (point to the next position)
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||
return keys[idx], nil
|
||||
}
|
||||
}
|
||||
// Fallback – should not happen, but return first enabled key
|
||||
return keys[enabledIdx[0]], nil
|
||||
default:
|
||||
// Unknown mode, default to first enabled key (or original key string)
|
||||
return keys[enabledIdx[0]], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
|
||||
@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
logs[i].ChannelName = ""
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
|
||||
@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
|
||||
user.AccessToken = &token
|
||||
}
|
||||
|
||||
func (user *User) GetSetting() map[string]interface{} {
|
||||
func (user *User) GetSetting() (map[string]interface{}, error) {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
toMap, err := common.StrToMap(user.Setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to convert setting to map: " + err.Error())
|
||||
return nil, fmt.Errorf("failed to convert setting to map")
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
|
||||
func (user *User) SetSetting(setting map[string]interface{}) {
|
||||
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
||||
return map[string]interface{}{}, err
|
||||
}
|
||||
|
||||
return common.StrToMap(setting), nil
|
||||
toMap, err := common.StrToMap(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to convert setting to map: " + err.Error())
|
||||
return nil, fmt.Errorf("failed to convert setting to map")
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||
|
||||
@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
toMap, err := common.StrToMap(user.Setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to convert user setting to map: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
return toMap
|
||||
}
|
||||
|
||||
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||
|
||||
Reference in New Issue
Block a user