diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c010048..203c91a2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -110,7 +110,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { } cache.WriteContext(c) - c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + //c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) @@ -320,6 +320,11 @@ func TestChannel(c *gin.Context) { }) return } + //defer func() { + // if channel.ChannelInfo.IsMultiKey { + // go func() { _ = channel.SaveChannelInfo() }() + // } + //}() testModel := c.Query("model") tik := time.Now() result := testChannel(channel, testModel) diff --git a/model/channel.go b/model/channel.go index 90d0e9b9..c9f11953 100644 --- a/model/channel.go +++ b/model/channel.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "errors" + "fmt" "math/rand" "one-api/common" "one-api/constant" @@ -122,15 +123,23 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) { lock.Lock() defer lock.Unlock() + channelInfo, err := CacheGetChannelInfo(channel.Id) + if err != nil { + return "", types.NewError(err, types.ErrorCodeGetChannelFailed) + } + //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) defer func() { + if common.DebugEnabled { + println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex)) + } if !common.MemoryCacheEnabled { - _ = channel.Save() + _ = channel.SaveChannelInfo() } else { // CacheUpdateChannel(channel) } }() // Start from the saved polling index and look for the next enabled key - start := channel.ChannelInfo.MultiKeyPollingIndex + start := channelInfo.MultiKeyPollingIndex if start < 0 || start >= len(keys) { start = 0 } @@ -150,6 +159,10 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) { } } +func (channel *Channel) SaveChannelInfo() error { + return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error +} + func (channel *Channel) GetModels() []string { if channel.Models == "" { return []string{} @@ -500,7 +513,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channelCache.ChannelInfo.IsMultiKey { // 如果是多Key模式,更新缓存中的状态 handlerMultiKeyUpdate(channelCache, usingKey, status) - CacheUpdateChannel(channelCache) + //CacheUpdateChannel(channelCache) //return true } else { // 如果缓存渠道存在,且状态已是目标状态,直接返回 diff --git a/model/channel_cache.go b/model/channel_cache.go index c97ee78e..b2451248 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -14,8 +14,8 @@ import ( "github.com/gin-gonic/gin" ) -var group2model2channels map[string]map[string][]int -var channelsIDM map[int]*Channel +var group2model2channels map[string]map[string][]int // enabled channel +var channelsIDM map[int]*Channel // all channels include disabled var channelSyncLock sync.RWMutex func InitChannelCache() { @@ -24,7 +24,7 @@ func InitChannelCache() { } newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Find(&channels) for _, channel := range channels { newChannelId2channel[channel.Id] = channel } @@ -35,12 +35,13 @@ func InitChannelCache() { groups[ability.Group] = true } newGroup2model2channels := make(map[string]map[string][]int) - newChannelsIDM := make(map[int]*Channel) for group := range groups { newGroup2model2channels[group] = make(map[string][]int) } for _, channel := range channels { - newChannelsIDM[channel.Id] = channel + if channel.Status != common.ChannelStatusEnabled { + continue // skip disabled channels + } groups := strings.Split(channel.Group, ",") for _, group := range groups { models := strings.Split(channel.Models, ",") @@ -57,7 +58,7 @@ func InitChannelCache() { for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { sort.Slice(channels, func(i, j int) bool { - return newChannelsIDM[channels[i]].GetPriority() > newChannelsIDM[channels[j]].GetPriority() + return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority() }) newGroup2model2channels[group][model] = channels } @@ -65,7 +66,7 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels - channelsIDM = newChannelsIDM + channelsIDM = newChannelId2channel channelSyncLock.Unlock() common.SysLog("channels synced from database") } @@ -203,11 +204,35 @@ func CacheGetChannel(id int) (*Channel, error) { c, ok := channelsIDM[id] if !ok { - return nil, fmt.Errorf("当前渠道# %d,已不存在", id) + return nil, fmt.Errorf("渠道# %d,已不存在", id) + } + if c.Status != common.ChannelStatusEnabled { + return nil, fmt.Errorf("渠道# %d,已被禁用", id) } return c, nil } +func CacheGetChannelInfo(id int) (*ChannelInfo, error) { + if !common.MemoryCacheEnabled { + channel, err := GetChannelById(id, true) + if err != nil { + return nil, err + } + return &channel.ChannelInfo, nil + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + + c, ok := channelsIDM[id] + if !ok { + return nil, fmt.Errorf("渠道# %d,已不存在", id) + } + if c.Status != common.ChannelStatusEnabled { + return nil, fmt.Errorf("渠道# %d,已被禁用", id) + } + return &c.ChannelInfo, nil +} + func CacheUpdateChannelStatus(id int, status int) { if !common.MemoryCacheEnabled { return