diff --git a/controller/channel.go b/controller/channel.go index 3cbef144..1551369e 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -419,7 +419,8 @@ func EditTagChannels(c *gin.Context) { } type ChannelBatch struct { - Ids []int `json:"ids"` + Ids []int `json:"ids"` + Tag *string `json:"tag"` } func DeleteChannelBatch(c *gin.Context) { @@ -570,3 +571,29 @@ func FetchModels(c *gin.Context) { "data": models, }) } + +func BatchSetChannelTag(c *gin.Context) { + channelBatch := ChannelBatch{} + err := c.ShouldBindJSON(&channelBatch) + if err != nil || len(channelBatch.Ids) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": len(channelBatch.Ids), + }) + return +} diff --git a/model/ability.go b/model/ability.go index 8e084cf9..e21f928f 100644 --- a/model/ability.go +++ b/model/ability.go @@ -3,10 +3,11 @@ package model import ( "errors" "fmt" - "github.com/samber/lo" - "gorm.io/gorm" "one-api/common" "strings" + + "github.com/samber/lo" + "gorm.io/gorm" ) type Ability struct { @@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error { // UpdateAbilities updates abilities of this channel. // Make sure the channel is completed before calling this function. -func (channel *Channel) UpdateAbilities() error { - // A quick and dirty way to update abilities +func (channel *Channel) UpdateAbilities(tx *gorm.DB) error { + isNewTx := false + // 如果没有传入事务,创建新的事务 + if tx == nil { + tx = DB.Begin() + if tx.Error != nil { + return tx.Error + } + isNewTx = true + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + } + // First delete all abilities of this channel - err := channel.DeleteAbilities() + err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error if err != nil { + if isNewTx { + tx.Rollback() + } return err } + // Then add new abilities - err = channel.AddAbilities() - if err != nil { - return err + models_ := strings.Split(channel.Models, ",") + groups_ := strings.Split(channel.Group, ",") + abilities := make([]Ability, 0, len(models_)) + for _, model := range models_ { + for _, group := range groups_ { + ability := Ability{ + Group: group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == common.ChannelStatusEnabled, + Priority: channel.Priority, + Weight: uint(channel.GetWeight()), + Tag: channel.Tag, + } + abilities = append(abilities, ability) + } } + + if len(abilities) > 0 { + for _, chunk := range lo.Chunk(abilities, 50) { + err = tx.Create(&chunk).Error + if err != nil { + if isNewTx { + tx.Rollback() + } + return err + } + } + } + + // 如果是新创建的事务,需要提交 + if isNewTx { + return tx.Commit().Error + } + return nil } @@ -246,7 +296,7 @@ func FixAbility() (int, error) { return 0, err } for _, channel := range channels { - err := channel.UpdateAbilities() + err := channel.UpdateAbilities(nil) if err != nil { common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error())) } else { diff --git a/model/channel.go b/model/channel.go index 6be16fd2..2024bafd 100644 --- a/model/channel.go +++ b/model/channel.go @@ -257,7 +257,7 @@ func (channel *Channel) Update() error { return err } DB.Model(channel).First(channel, "id = ?", channel.Id) - err = channel.UpdateAbilities() + err = channel.UpdateAbilities(nil) return err } @@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * channels, err := GetChannelsByTag(updatedTag, false) if err == nil { for _, channel := range channels { - err = channel.UpdateAbilities() + err = channel.UpdateAbilities(nil) if err != nil { common.SysError("failed to update abilities: " + err.Error()) } @@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) { } channel.Setting = string(settingBytes) } + +func GetChannelsByIds(ids []int) ([]*Channel, error) { + var channels []*Channel + err := DB.Where("id in (?)", ids).Find(&channels).Error + return channels, err +} + +func BatchSetChannelTag(ids []int, tag *string) error { + // 开启事务 + tx := DB.Begin() + if tx.Error != nil { + return tx.Error + } + + // 更新标签 + err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error + if err != nil { + tx.Rollback() + return err + } + + // update ability status + channels, err := GetChannelsByIds(ids) + if err != nil { + tx.Rollback() + return err + } + + for _, channel := range channels { + err = channel.UpdateAbilities(tx) + if err != nil { + tx.Rollback() + return err + } + } + + // 提交事务 + return tx.Commit().Error +} diff --git a/router/api-router.go b/router/api-router.go index a64bcf52..bb87d8f5 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -99,7 +99,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.POST("/fix", controller.FixChannelsAbilities) channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels) channelRoute.POST("/fetch_models", controller.FetchModels) - + channelRoute.POST("/batch/tag", controller.BatchSetChannelTag) } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index ceb3fe44..890e32ba 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -162,9 +162,15 @@ const ChannelsTable = () => { return (