feat(channel): improve channel cache handling and add error checks for disabled channels

This commit is contained in:
CaIon
2025-07-12 14:20:59 +08:00
parent 23e4e25e9a
commit 50b76f4466
3 changed files with 55 additions and 12 deletions

View File

@@ -110,7 +110,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
@@ -320,6 +320,11 @@ func TestChannel(c *gin.Context) {
})
return
}
//defer func() {
// if channel.ChannelInfo.IsMultiKey {
// go func() { _ = channel.SaveChannelInfo() }()
// }
//}()
testModel := c.Query("model")
tik := time.Now()
result := testChannel(channel, testModel)

View File

@@ -4,6 +4,7 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
@@ -122,15 +123,23 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
lock.Lock()
defer lock.Unlock()
channelInfo, err := CacheGetChannelInfo(channel.Id)
if err != nil {
return "", types.NewError(err, types.ErrorCodeGetChannelFailed)
}
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
defer func() {
if common.DebugEnabled {
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
}
if !common.MemoryCacheEnabled {
_ = channel.Save()
_ = channel.SaveChannelInfo()
} else {
// CacheUpdateChannel(channel)
}
}()
// Start from the saved polling index and look for the next enabled key
start := channel.ChannelInfo.MultiKeyPollingIndex
start := channelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
}
@@ -150,6 +159,10 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
}
}
func (channel *Channel) SaveChannelInfo() error {
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
@@ -500,7 +513,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if channelCache.ChannelInfo.IsMultiKey {
// 如果是多Key模式更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status)
CacheUpdateChannel(channelCache)
//CacheUpdateChannel(channelCache)
//return true
} else {
// 如果缓存渠道存在,且状态已是目标状态,直接返回

View File

@@ -14,8 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]int
var channelsIDM map[int]*Channel
var group2model2channels map[string]map[string][]int // enabled channel
var channelsIDM map[int]*Channel // all channels include disabled
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -24,7 +24,7 @@ func InitChannelCache() {
}
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
DB.Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -35,12 +35,13 @@ func InitChannelCache() {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]int)
newChannelsIDM := make(map[int]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]int)
}
for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
if channel.Status != common.ChannelStatusEnabled {
continue // skip disabled channels
}
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
@@ -57,7 +58,7 @@ func InitChannelCache() {
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return newChannelsIDM[channels[i]].GetPriority() > newChannelsIDM[channels[j]].GetPriority()
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
@@ -65,7 +66,7 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -203,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) {
c, ok := channelsIDM[id]
if !ok {
return nil, fmt.Errorf("当前渠道# %d已不存在", id)
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return c, nil
}
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
if !common.MemoryCacheEnabled {
channel, err := GetChannelById(id, true)
if err != nil {
return nil, err
}
return &channel.ChannelInfo, nil
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
c, ok := channelsIDM[id]
if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return &c.ChannelInfo, nil
}
func CacheUpdateChannelStatus(id int, status int) {
if !common.MemoryCacheEnabled {
return