🔧 refactor(auth, channel, context): improve context setup and validation for multi-key channels

This commit is contained in:
CaIon
2025-07-06 12:37:56 +08:00
parent b695e67154
commit f0f277dc2a
15 changed files with 294 additions and 114 deletions

View File

@@ -3,7 +3,10 @@ package model
import (
"database/sql/driver"
"encoding/json"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"strings"
"sync"
@@ -43,20 +46,93 @@ type Channel struct {
}
type ChannelInfo struct {
MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的key数量
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return json.Marshal(c)
func (c *ChannelInfo) Value() (driver.Value, error) {
return common.EncodeJson(c)
}
// Scan implements sql.Scanner interface
func (c *ChannelInfo) Scan(value interface{}) error {
bytesValue, _ := value.([]byte)
return json.Unmarshal(bytesValue, c)
return common.UnmarshalJson(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
if channel.Key == "" {
return []string{}
}
// use \n to split keys
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
return keys
}
func (channel *Channel) GetNextEnabledKey() (string, error) {
// If not in multi-key mode, return the original key string directly.
if !channel.ChannelInfo.IsMultiKey {
return channel.Key, nil
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", fmt.Errorf("no valid keys in channel")
}
statusList := channel.ChannelInfo.MultiKeyStatusList
// helper to get key status, default to enabled when missing
getStatus := func(idx int) int {
if statusList == nil {
return common.ChannelStatusEnabled
}
if status, ok := statusList[idx]; ok {
return status
}
return common.ChannelStatusEnabled
}
// Collect indexes of enabled keys
enabledIdx := make([]int, 0, len(keys))
for i := range keys {
if getStatus(i) == common.ChannelStatusEnabled {
enabledIdx = append(enabledIdx, i)
}
}
// If no specific status list or none enabled, fall back to first key
if len(enabledIdx) == 0 {
return keys[0], nil
}
switch channel.ChannelInfo.MultiKeyMode {
case constant.MultiKeyModeRandom:
// Randomly pick one enabled key
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
case constant.MultiKeyModePolling:
// Start from the saved polling index and look for the next enabled key
start := channel.ChannelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
}
for i := 0; i < len(keys); i++ {
idx := (start + i) % len(keys)
if getStatus(idx) == common.ChannelStatusEnabled {
// update polling index for next call (point to the next position)
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
return keys[idx], nil
}
}
// Fallback should not happen, but return first enabled key
return keys[enabledIdx[0]], nil
default:
// Unknown mode, default to first enabled key (or original key string)
return keys[enabledIdx[0]], nil
}
}
func (channel *Channel) GetModels() []string {

View File

@@ -50,7 +50,7 @@ func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")

View File

@@ -68,11 +68,16 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() map[string]interface{} {
func (user *User) GetSetting() (map[string]interface{}, error) {
if user.Setting == "" {
return nil
return map[string]interface{}{}, nil
}
return common.StrToMap(user.Setting)
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
}
func (user *User) SetSetting(setting map[string]interface{}) {
@@ -651,7 +656,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
return map[string]interface{}{}, err
}
return common.StrToMap(setting), nil
toMap, err := common.StrToMap(setting)
if err != nil {
common.SysError("failed to convert setting to map: " + err.Error())
return nil, fmt.Errorf("failed to convert setting to map")
}
return toMap, nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -36,7 +36,12 @@ func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return common.StrToMap(user.Setting)
toMap, err := common.StrToMap(user.Setting)
if err != nil {
common.SysError("failed to convert user setting to map: " + err.Error())
return nil
}
return toMap
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {