From 962c40c1a7da4579162d7abc2aab1d43598842b7 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 8 Aug 2025 18:36:09 +0800 Subject: [PATCH] feat: enhance AddAbilities and BatchInsertChannels to support transaction handling --- model/ability.go | 11 ++++++--- model/channel.go | 63 ++++++++++++++++++++++++++++++------------------ 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/model/ability.go b/model/ability.go index 2df45917..ce2f299c 100644 --- a/model/ability.go +++ b/model/ability.go @@ -142,7 +142,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, return &channel, err } -func (channel *Channel) AddAbilities() error { +func (channel *Channel) AddAbilities(tx *gorm.DB) error { models_ := strings.Split(channel.Models, ",") groups_ := strings.Split(channel.Group, ",") abilitySet := make(map[string]struct{}) @@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error { if len(abilities) == 0 { return nil } + // choose DB or provided tx + useDB := DB + if tx != nil { + useDB = tx + } for _, chunk := range lo.Chunk(abilities, 50) { - err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error + err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error if err != nil { return err } @@ -321,7 +326,7 @@ func FixAbility() (int, int, error) { } // Then add new abilities for _, channel := range chunk { - err = channel.AddAbilities() + err = channel.AddAbilities(nil) if err != nil { common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ diff --git a/model/channel.go b/model/channel.go index a5fb463e..b670b9dc 100644 --- a/model/channel.go +++ b/model/channel.go @@ -13,6 +13,7 @@ import ( "strings" "sync" + "github.com/samber/lo" "gorm.io/gorm" ) @@ -337,38 +338,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { } func BatchInsertChannels(channels []Channel) error { - var err error - err = DB.Create(&channels).Error - if err != nil { - return err + if len(channels) == 0 { + return nil } - for _, channel_ := range channels { - err = channel_.AddAbilities() - if err != nil { + tx := DB.Begin() + if tx.Error != nil { + return tx.Error + } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + + for _, chunk := range lo.Chunk(channels, 50) { + if err := tx.Create(&chunk).Error; err != nil { + tx.Rollback() return err } + for _, channel_ := range chunk { + if err := channel_.AddAbilities(tx); err != nil { + tx.Rollback() + return err + } + } } - return nil + return tx.Commit().Error } func BatchDeleteChannels(ids []int) error { - //使用事务 删除channel表和channel_ability表 + if len(ids) == 0 { + return nil + } + // 使用事务 分批删除channel表和abilities表 tx := DB.Begin() - err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error - if err != nil { - // 回滚事务 - tx.Rollback() - return err + if tx.Error != nil { + return tx.Error } - err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error - if err != nil { - // 回滚事务 - tx.Rollback() - return err + for _, chunk := range lo.Chunk(ids, 200) { + if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil { + tx.Rollback() + return err + } + if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil { + tx.Rollback() + return err + } } - // 提交事务 - tx.Commit() - return err + return tx.Commit().Error } func (channel *Channel) GetPriority() int64 { @@ -412,7 +429,7 @@ func (channel *Channel) Insert() error { if err != nil { return err } - err = channel.AddAbilities() + err = channel.AddAbilities(nil) return err }