✨ feat(channel): improve channel cache handling and add error checks for disabled channels
This commit is contained in:
@@ -110,7 +110,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
}
|
}
|
||||||
cache.WriteContext(c)
|
cache.WriteContext(c)
|
||||||
|
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
@@ -320,6 +320,11 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
//defer func() {
|
||||||
|
// if channel.ChannelInfo.IsMultiKey {
|
||||||
|
// go func() { _ = channel.SaveChannelInfo() }()
|
||||||
|
// }
|
||||||
|
//}()
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
result := testChannel(channel, testModel)
|
result := testChannel(channel, testModel)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -122,15 +123,23 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
|
|||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||||||
|
if err != nil {
|
||||||
|
return "", types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||||
|
}
|
||||||
|
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||||||
defer func() {
|
defer func() {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||||||
|
}
|
||||||
if !common.MemoryCacheEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
_ = channel.Save()
|
_ = channel.SaveChannelInfo()
|
||||||
} else {
|
} else {
|
||||||
// CacheUpdateChannel(channel)
|
// CacheUpdateChannel(channel)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
// Start from the saved polling index and look for the next enabled key
|
// Start from the saved polling index and look for the next enabled key
|
||||||
start := channel.ChannelInfo.MultiKeyPollingIndex
|
start := channelInfo.MultiKeyPollingIndex
|
||||||
if start < 0 || start >= len(keys) {
|
if start < 0 || start >= len(keys) {
|
||||||
start = 0
|
start = 0
|
||||||
}
|
}
|
||||||
@@ -150,6 +159,10 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) SaveChannelInfo() error {
|
||||||
|
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModels() []string {
|
func (channel *Channel) GetModels() []string {
|
||||||
if channel.Models == "" {
|
if channel.Models == "" {
|
||||||
return []string{}
|
return []string{}
|
||||||
@@ -500,7 +513,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
if channelCache.ChannelInfo.IsMultiKey {
|
if channelCache.ChannelInfo.IsMultiKey {
|
||||||
// 如果是多Key模式,更新缓存中的状态
|
// 如果是多Key模式,更新缓存中的状态
|
||||||
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
handlerMultiKeyUpdate(channelCache, usingKey, status)
|
||||||
CacheUpdateChannel(channelCache)
|
//CacheUpdateChannel(channelCache)
|
||||||
//return true
|
//return true
|
||||||
} else {
|
} else {
|
||||||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]int
|
var group2model2channels map[string]map[string][]int // enabled channel
|
||||||
var channelsIDM map[int]*Channel
|
var channelsIDM map[int]*Channel // all channels include disabled
|
||||||
var channelSyncLock sync.RWMutex
|
var channelSyncLock sync.RWMutex
|
||||||
|
|
||||||
func InitChannelCache() {
|
func InitChannelCache() {
|
||||||
@@ -24,7 +24,7 @@ func InitChannelCache() {
|
|||||||
}
|
}
|
||||||
newChannelId2channel := make(map[int]*Channel)
|
newChannelId2channel := make(map[int]*Channel)
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
DB.Find(&channels)
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
newChannelId2channel[channel.Id] = channel
|
newChannelId2channel[channel.Id] = channel
|
||||||
}
|
}
|
||||||
@@ -35,12 +35,13 @@ func InitChannelCache() {
|
|||||||
groups[ability.Group] = true
|
groups[ability.Group] = true
|
||||||
}
|
}
|
||||||
newGroup2model2channels := make(map[string]map[string][]int)
|
newGroup2model2channels := make(map[string]map[string][]int)
|
||||||
newChannelsIDM := make(map[int]*Channel)
|
|
||||||
for group := range groups {
|
for group := range groups {
|
||||||
newGroup2model2channels[group] = make(map[string][]int)
|
newGroup2model2channels[group] = make(map[string][]int)
|
||||||
}
|
}
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
newChannelsIDM[channel.Id] = channel
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
continue // skip disabled channels
|
||||||
|
}
|
||||||
groups := strings.Split(channel.Group, ",")
|
groups := strings.Split(channel.Group, ",")
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
models := strings.Split(channel.Models, ",")
|
models := strings.Split(channel.Models, ",")
|
||||||
@@ -57,7 +58,7 @@ func InitChannelCache() {
|
|||||||
for group, model2channels := range newGroup2model2channels {
|
for group, model2channels := range newGroup2model2channels {
|
||||||
for model, channels := range model2channels {
|
for model, channels := range model2channels {
|
||||||
sort.Slice(channels, func(i, j int) bool {
|
sort.Slice(channels, func(i, j int) bool {
|
||||||
return newChannelsIDM[channels[i]].GetPriority() > newChannelsIDM[channels[j]].GetPriority()
|
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||||||
})
|
})
|
||||||
newGroup2model2channels[group][model] = channels
|
newGroup2model2channels[group][model] = channels
|
||||||
}
|
}
|
||||||
@@ -65,7 +66,7 @@ func InitChannelCache() {
|
|||||||
|
|
||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
channelsIDM = newChannelsIDM
|
channelsIDM = newChannelId2channel
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
common.SysLog("channels synced from database")
|
common.SysLog("channels synced from database")
|
||||||
}
|
}
|
||||||
@@ -203,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) {
|
|||||||
|
|
||||||
c, ok := channelsIDM[id]
|
c, ok := channelsIDM[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("当前渠道# %d,已不存在", id)
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
|
}
|
||||||
|
if c.Status != common.ChannelStatusEnabled {
|
||||||
|
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
channel, err := GetChannelById(id, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &channel.ChannelInfo, nil
|
||||||
|
}
|
||||||
|
channelSyncLock.RLock()
|
||||||
|
defer channelSyncLock.RUnlock()
|
||||||
|
|
||||||
|
c, ok := channelsIDM[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
|
}
|
||||||
|
if c.Status != common.ChannelStatusEnabled {
|
||||||
|
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
||||||
|
}
|
||||||
|
return &c.ChannelInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
func CacheUpdateChannelStatus(id int, status int) {
|
func CacheUpdateChannelStatus(id int, status int) {
|
||||||
if !common.MemoryCacheEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user