feat: enhance AddAbilities and BatchInsertChannels to support transaction handling

This commit is contained in:
CaIon
2025-08-08 18:36:09 +08:00
parent f6c7828160
commit 962c40c1a7
2 changed files with 48 additions and 26 deletions

View File

@@ -142,7 +142,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
return &channel, err
}
func (channel *Channel) AddAbilities() error {
func (channel *Channel) AddAbilities(tx *gorm.DB) error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilitySet := make(map[string]struct{})
@@ -169,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
if len(abilities) == 0 {
return nil
}
// choose DB or provided tx
useDB := DB
if tx != nil {
useDB = tx
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
return err
}
@@ -321,7 +326,7 @@ func FixAbility() (int, int, error) {
}
// Then add new abilities
for _, channel := range chunk {
err = channel.AddAbilities()
err = channel.AddAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++

View File

@@ -13,6 +13,7 @@ import (
"strings"
"sync"
"github.com/samber/lo"
"gorm.io/gorm"
)
@@ -337,38 +338,54 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
}
func BatchInsertChannels(channels []Channel) error {
var err error
err = DB.Create(&channels).Error
if err != nil {
return err
}
for _, channel_ := range channels {
err = channel_.AddAbilities()
if err != nil {
return err
}
}
if len(channels) == 0 {
return nil
}
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
for _, chunk := range lo.Chunk(channels, 50) {
if err := tx.Create(&chunk).Error; err != nil {
tx.Rollback()
return err
}
for _, channel_ := range chunk {
if err := channel_.AddAbilities(tx); err != nil {
tx.Rollback()
return err
}
}
}
return tx.Commit().Error
}
func BatchDeleteChannels(ids []int) error {
//使用事务 删除channel表和channel_ability表
if len(ids) == 0 {
return nil
}
// 使用事务 分批删除channel表和abilities表
tx := DB.Begin()
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
if err != nil {
// 回滚事务
if tx.Error != nil {
return tx.Error
}
for _, chunk := range lo.Chunk(ids, 200) {
if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
tx.Rollback()
return err
}
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
if err != nil {
// 回滚事务
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
tx.Rollback()
return err
}
// 提交事务
tx.Commit()
return err
}
return tx.Commit().Error
}
func (channel *Channel) GetPriority() int64 {
@@ -412,7 +429,7 @@ func (channel *Channel) Insert() error {
if err != nil {
return err
}
err = channel.AddAbilities()
err = channel.AddAbilities(nil)
return err
}