feat: enhance AddAbilities and BatchInsertChannels to support transaction handling

This commit is contained in:
CaIon
2025-08-08 18:36:09 +08:00
parent f6c7828160
commit 962c40c1a7
2 changed files with 48 additions and 26 deletions

View File

@@ -142,7 +142,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
return &channel, err return &channel, err
} }
func (channel *Channel) AddAbilities() error { func (channel *Channel) AddAbilities(tx *gorm.DB) error {
models_ := strings.Split(channel.Models, ",") models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",") groups_ := strings.Split(channel.Group, ",")
abilitySet := make(map[string]struct{}) abilitySet := make(map[string]struct{})
@@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
if len(abilities) == 0 { if len(abilities) == 0 {
return nil return nil
} }
// choose DB or provided tx
useDB := DB
if tx != nil {
useDB = tx
}
for _, chunk := range lo.Chunk(abilities, 50) { for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil { if err != nil {
return err return err
} }
@@ -321,7 +326,7 @@ func FixAbility() (int, int, error) {
} }
// Then add new abilities // Then add new abilities
for _, channel := range chunk { for _, channel := range chunk {
err = channel.AddAbilities() err = channel.AddAbilities(nil)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++ failCount++

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/samber/lo"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -337,38 +338,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
} }
func BatchInsertChannels(channels []Channel) error { func BatchInsertChannels(channels []Channel) error {
var err error if len(channels) == 0 {
err = DB.Create(&channels).Error return nil
if err != nil {
return err
} }
for _, channel_ := range channels { tx := DB.Begin()
err = channel_.AddAbilities() if tx.Error != nil {
if err != nil { return tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
for _, chunk := range lo.Chunk(channels, 50) {
if err := tx.Create(&chunk).Error; err != nil {
tx.Rollback()
return err return err
} }
for _, channel_ := range chunk {
if err := channel_.AddAbilities(tx); err != nil {
tx.Rollback()
return err
}
}
} }
return nil return tx.Commit().Error
} }
func BatchDeleteChannels(ids []int) error { func BatchDeleteChannels(ids []int) error {
//使用事务 删除channel表和channel_ability表 if len(ids) == 0 {
return nil
}
// 使用事务 分批删除channel表和abilities表
tx := DB.Begin() tx := DB.Begin()
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error if tx.Error != nil {
if err != nil { return tx.Error
// 回滚事务
tx.Rollback()
return err
} }
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error for _, chunk := range lo.Chunk(ids, 200) {
if err != nil { if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
// 回滚事务 tx.Rollback()
tx.Rollback() return err
return err }
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
tx.Rollback()
return err
}
} }
// 提交事务 return tx.Commit().Error
tx.Commit()
return err
} }
func (channel *Channel) GetPriority() int64 { func (channel *Channel) GetPriority() int64 {
@@ -412,7 +429,7 @@ func (channel *Channel) Insert() error {
if err != nil { if err != nil {
return err return err
} }
err = channel.AddAbilities() err = channel.AddAbilities(nil)
return err return err
} }