diff --git a/controller/channel-test.go b/controller/channel-test.go index 82bb1d7f..8c010048 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -312,7 +312,7 @@ func TestChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(channelId, true) + channel, err := model.CacheGetChannel(channelId) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/model/channel.go b/model/channel.go index 6079bf3c..90d0e9b9 100644 --- a/model/channel.go +++ b/model/channel.go @@ -117,15 +117,19 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) { // Randomly pick one enabled key return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil case constant.MultiKeyModePolling: + // Use channel-specific lock to ensure thread-safe polling + lock := getChannelPollingLock(channel.Id) + lock.Lock() + defer lock.Unlock() + defer func() { if !common.MemoryCacheEnabled { _ = channel.Save() } else { - CacheUpdateChannel(channel) + // CacheUpdateChannel(channel) } }() // Start from the saved polling index and look for the next enabled key - println(channel.ChannelInfo.MultiKeyPollingIndex) start := channel.ChannelInfo.MultiKeyPollingIndex if start < 0 || start >= len(keys) { start = 0 @@ -135,7 +139,6 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) { if getStatus(idx) == common.ChannelStatusEnabled { // update polling index for next call (point to the next position) channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys) - println(channel.ChannelInfo.MultiKeyPollingIndex) return keys[idx], nil } } @@ -421,6 +424,40 @@ func (channel *Channel) Delete() error { var channelStatusLock sync.Mutex +// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling +var channelPollingLocks sync.Map + +// getChannelPollingLock returns or creates a mutex for the given channel ID +func getChannelPollingLock(channelId int) *sync.Mutex { + if lock, exists := channelPollingLocks.Load(channelId); exists { + return lock.(*sync.Mutex) + } + // Create new lock for this channel + newLock := &sync.Mutex{} + actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock) + return actual.(*sync.Mutex) +} + +// CleanupChannelPollingLocks removes locks for channels that no longer exist +// This is optional and can be called periodically to prevent memory leaks +func CleanupChannelPollingLocks() { + var activeChannelIds []int + DB.Model(&Channel{}).Pluck("id", &activeChannelIds) + + activeChannelSet := make(map[int]bool) + for _, id := range activeChannelIds { + activeChannelSet[id] = true + } + + channelPollingLocks.Range(func(key, value interface{}) bool { + channelId := key.(int) + if !activeChannelSet[channelId] { + channelPollingLocks.Delete(channelId) + } + return true + }) +} + func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) { keys := channel.getKeys() if len(keys) == 0 { diff --git a/model/channel_cache.go b/model/channel_cache.go index f4e38a39..c97ee78e 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -128,13 +128,20 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, } channelSyncLock.RLock() + defer channelSyncLock.RUnlock() channels := group2model2channels[group][model] - channelSyncLock.RUnlock() if len(channels) == 0 { return nil, errors.New("channel not found") } + if len(channels) == 1 { + if channel, ok := channelsIDM[channels[0]]; ok { + return channel, nil + } + return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]) + } + uniquePriorities := make(map[int]bool) for _, channelId := range channels { if channel, ok := channelsIDM[channelId]; ok { @@ -196,7 +203,7 @@ func CacheGetChannel(id int) (*Channel, error) { c, ok := channelsIDM[id] if !ok { - return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id)) + return nil, fmt.Errorf("当前渠道# %d,已不存在", id) } return c, nil } @@ -224,5 +231,7 @@ func CacheUpdateChannel(channel *Channel) { println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex) + println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) channelsIDM[channel.Id] = channel + println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) }