diff --git a/controller/channel.go b/controller/channel.go index 4c7d28f2..85b14b43 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) { } func FixChannelsAbilities(c *gin.Context) { - count, err := model.FixAbility() + success, fails, err := model.FixAbility() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": count, + "data": gin.H{ + "success": success, + "fails": fails, + }, }) } diff --git a/main.go b/main.go index b89350b0..ca3da601 100644 --- a/main.go +++ b/main.go @@ -68,9 +68,9 @@ func main() { if r := recover(); r != nil { common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once - _, fixErr := model.FixAbility() + _, _, fixErr := model.FixAbility() if fixErr != nil { - common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() diff --git a/model/ability.go b/model/ability.go index fb5301fe..ed124676 100644 --- a/model/ability.go +++ b/model/ability.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "strings" + "sync" "github.com/samber/lo" "gorm.io/gorm" @@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error } -func FixAbility() (int, error) { - var channelIds []int - count := 0 - // Find all channel ids from channel table - err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error +var fixLock = sync.Mutex{} + +func FixAbility() (int, int, error) { + lock := fixLock.TryLock() + if !lock { + return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试") + } + defer fixLock.Unlock() + var channels []*Channel + // Find all channels + err := DB.Model(&Channel{}).Find(&channels).Error if err != nil { - common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error())) - return 0, err + return 0, 0, err } - - // Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders - if len(channelIds) > 0 { - // Process deletion in chunks to avoid "too many placeholders" error - for _, chunk := range lo.Chunk(channelIds, 100) { - err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error - if err != nil { - common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error())) - return 0, err - } - } - } else { - // If no channels exist, delete all abilities - err = DB.Delete(&Ability{}).Error + if len(channels) == 0 { + return 0, 0, nil + } + successCount := 0 + failCount := 0 + for _, chunk := range lo.Chunk(channels, 50) { + ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id }) + // Delete all abilities of this channel + err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error())) - return 0, err + common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + failCount += len(chunk) + continue } - common.SysLog("Delete all abilities successfully") - return 0, nil - } - - common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds)) - count += len(channelIds) - - // Use channelIds to find channel not in abilities table - var abilityChannelIds []int - err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error - if err != nil { - common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error())) - return count, err - } - - var channels []Channel - if len(abilityChannelIds) == 0 { - err = DB.Find(&channels).Error - } else { - // Process query in chunks to avoid "too many placeholders" error - err = nil - for _, chunk := range lo.Chunk(abilityChannelIds, 100) { - var channelsChunk []Channel - err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error + // Then add new abilities + for _, channel := range chunk { + err = channel.AddAbilities() if err != nil { - common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error())) - return count, err + common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + failCount++ + } else { + successCount++ } - channels = append(channels, channelsChunk...) - } - } - - for _, channel := range channels { - err := channel.UpdateAbilities(nil) - if err != nil { - common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error())) - } else { - common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id)) - count++ } } InitChannelCache() - return count, nil + return successCount, failCount, nil } diff --git a/web/src/components/table/ChannelsTable.js b/web/src/components/table/ChannelsTable.js index 0e84437d..810993c4 100644 --- a/web/src/components/table/ChannelsTable.js +++ b/web/src/components/table/ChannelsTable.js @@ -1461,9 +1461,9 @@ const ChannelsTable = () => { const fixChannelsAbilities = async () => { const res = await API.post(`/api/channel/fix`); - const { success, message, data } = res.data; + const { success, message, data } = res.data; if (success) { - showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data)); + showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails)); await refresh(); } else { showError(message); diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index cfed54e4..f53b4abb 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -240,7 +240,7 @@ const EditChannel = (props) => { if (isEdit) { // 如果是编辑模式,使用已有的channel id获取模型列表 const res = await API.get('/api/channel/fetch_models/' + channelId); - if (res.data && res.data?.success) { + if (res.data && res.data.success) { models.push(...res.data.data); } else { err = true;