diff --git a/model/channel.go b/model/channel.go index 12186712..6ff0901d 100644 --- a/model/channel.go +++ b/model/channel.go @@ -290,35 +290,42 @@ func (channel *Channel) Delete() error { var channelStatusLock sync.Mutex -func UpdateChannelStatusById(id int, status int, reason string) { +func UpdateChannelStatusById(id int, status int, reason string) bool { if common.MemoryCacheEnabled { channelStatusLock.Lock() + defer channelStatusLock.Unlock() + channelCache, _ := CacheGetChannel(id) // 如果缓存渠道存在,且状态已是目标状态,直接返回 if channelCache != nil && channelCache.Status == status { - channelStatusLock.Unlock() - return + return false } // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回 if channelCache == nil && status != common.ChannelStatusEnabled { - channelStatusLock.Unlock() - return + return false } CacheUpdateChannelStatus(id, status) - channelStatusLock.Unlock() } err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { common.SysError("failed to update ability status: " + err.Error()) + return false } channel, err := GetChannelById(id, true) if err != nil { // find channel by id error, directly update status - err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error - if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status) + if result.Error != nil { + common.SysError("failed to update channel status: " + result.Error.Error()) + return false + } + if result.RowsAffected == 0 { + return false } } else { + if channel.Status == status { + return false + } // find channel by id success, update status and other info info := channel.GetOtherInfo() info["status_reason"] = reason @@ -328,9 +335,10 @@ func UpdateChannelStatusById(id int, status int, reason string) { err = channel.Save() if err != nil { common.SysError("failed to update channel status: " + err.Error()) + return false } } - + return true } func EnableChannelByTag(tag string) error { diff --git a/service/channel.go b/service/channel.go index 0f4270a4..e3a76af4 100644 --- a/service/channel.go +++ b/service/channel.go @@ -10,19 +10,27 @@ import ( "strings" ) +func formatNotifyType(channelId int, status int) string { + return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status) +} + // disable & notify func DisableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - NotifyRootUser(dto.NotifyTypeChannelUpdate, subject, content) + success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason) + if success { + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content) + } } func EnableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") - subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) - NotifyRootUser(dto.NotifyTypeChannelUpdate, subject, content) + success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "") + if success { + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content) + } } func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {