feat(channel): implement thread-safe polling

This commit is contained in:
CaIon
2025-07-12 11:17:08 +08:00
parent 85efea3fb8
commit 23e4e25e9a
3 changed files with 52 additions and 6 deletions

View File

@@ -117,15 +117,19 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
// Randomly pick one enabled key
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
case constant.MultiKeyModePolling:
// Use channel-specific lock to ensure thread-safe polling
lock := getChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
defer func() {
if !common.MemoryCacheEnabled {
_ = channel.Save()
} else {
CacheUpdateChannel(channel)
// CacheUpdateChannel(channel)
}
}()
// Start from the saved polling index and look for the next enabled key
println(channel.ChannelInfo.MultiKeyPollingIndex)
start := channel.ChannelInfo.MultiKeyPollingIndex
if start < 0 || start >= len(keys) {
start = 0
@@ -135,7 +139,6 @@ func (channel *Channel) GetNextEnabledKey() (string, *types.NewAPIError) {
if getStatus(idx) == common.ChannelStatusEnabled {
// update polling index for next call (point to the next position)
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
println(channel.ChannelInfo.MultiKeyPollingIndex)
return keys[idx], nil
}
}
@@ -421,6 +424,40 @@ func (channel *Channel) Delete() error {
var channelStatusLock sync.Mutex
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
var channelPollingLocks sync.Map
// getChannelPollingLock returns or creates a mutex for the given channel ID
func getChannelPollingLock(channelId int) *sync.Mutex {
if lock, exists := channelPollingLocks.Load(channelId); exists {
return lock.(*sync.Mutex)
}
// Create new lock for this channel
newLock := &sync.Mutex{}
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
return actual.(*sync.Mutex)
}
// CleanupChannelPollingLocks removes locks for channels that no longer exist
// This is optional and can be called periodically to prevent memory leaks
func CleanupChannelPollingLocks() {
var activeChannelIds []int
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
activeChannelSet := make(map[int]bool)
for _, id := range activeChannelIds {
activeChannelSet[id] = true
}
channelPollingLocks.Range(func(key, value interface{}) bool {
channelId := key.(int)
if !activeChannelSet[channelId] {
channelPollingLocks.Delete(channelId)
}
return true
})
}
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
if len(keys) == 0 {

View File

@@ -128,13 +128,20 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
channelSyncLock.RUnlock()
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
if len(channels) == 1 {
if channel, ok := channelsIDM[channels[0]]; ok {
return channel, nil
}
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
}
uniquePriorities := make(map[int]bool)
for _, channelId := range channels {
if channel, ok := channelsIDM[channelId]; ok {
@@ -196,7 +203,7 @@ func CacheGetChannel(id int) (*Channel, error) {
c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d已不存在", id))
return nil, fmt.Errorf("当前渠道# %d已不存在", id)
}
return c, nil
}
@@ -224,5 +231,7 @@ func CacheUpdateChannel(channel *Channel) {
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
channelsIDM[channel.Id] = channel
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
}