feat: Implement batch tagging functionality for channels

- Added a new endpoint to batch set tags for multiple channels, allowing users to update tags efficiently.
- Introduced a new `BatchSetChannelTag` function in the controller to handle incoming requests and validate parameters.
- Updated the `BatchSetChannelTag` method in the model to manage database transactions and ensure data integrity during tag updates.
- Enhanced the ChannelsTable component in the frontend to support batch tag setting, including UI elements for user interaction.
- Updated localization files to include new translation keys related to batch operations and tag settings.
This commit is contained in:
CalciumIon
2024-12-25 14:19:00 +08:00
parent f2c9388139
commit 72d6898eb5
7 changed files with 201 additions and 24 deletions

View File

@@ -419,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
} }
type ChannelBatch struct { type ChannelBatch struct {
Ids []int `json:"ids"` Ids []int `json:"ids"`
Tag *string `json:"tag"`
} }
func DeleteChannelBatch(c *gin.Context) { func DeleteChannelBatch(c *gin.Context) {
@@ -570,3 +571,29 @@ func FetchModels(c *gin.Context) {
"data": models, "data": models,
}) })
} }
func BatchSetChannelTag(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}

View File

@@ -3,10 +3,11 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"github.com/samber/lo"
"gorm.io/gorm"
) )
type Ability struct { type Ability struct {
@@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error {
// UpdateAbilities updates abilities of this channel. // UpdateAbilities updates abilities of this channel.
// Make sure the channel is completed before calling this function. // Make sure the channel is completed before calling this function.
func (channel *Channel) UpdateAbilities() error { func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
// A quick and dirty way to update abilities isNewTx := false
// 如果没有传入事务,创建新的事务
if tx == nil {
tx = DB.Begin()
if tx.Error != nil {
return tx.Error
}
isNewTx = true
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
}
// First delete all abilities of this channel // First delete all abilities of this channel
err := channel.DeleteAbilities() err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
if err != nil { if err != nil {
if isNewTx {
tx.Rollback()
}
return err return err
} }
// Then add new abilities // Then add new abilities
err = channel.AddAbilities() models_ := strings.Split(channel.Models, ",")
if err != nil { groups_ := strings.Split(channel.Group, ",")
return err abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
} }
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
err = tx.Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
}
}
// 如果是新创建的事务,需要提交
if isNewTx {
return tx.Commit().Error
}
return nil return nil
} }
@@ -246,7 +296,7 @@ func FixAbility() (int, error) {
return 0, err return 0, err
} }
for _, channel := range channels { for _, channel := range channels {
err := channel.UpdateAbilities() err := channel.UpdateAbilities(nil)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error())) common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else { } else {

View File

@@ -257,7 +257,7 @@ func (channel *Channel) Update() error {
return err return err
} }
DB.Model(channel).First(channel, "id = ?", channel.Id) DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities() err = channel.UpdateAbilities(nil)
return err return err
} }
@@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
channels, err := GetChannelsByTag(updatedTag, false) channels, err := GetChannelsByTag(updatedTag, false)
if err == nil { if err == nil {
for _, channel := range channels { for _, channel := range channels {
err = channel.UpdateAbilities() err = channel.UpdateAbilities(nil)
if err != nil { if err != nil {
common.SysError("failed to update abilities: " + err.Error()) common.SysError("failed to update abilities: " + err.Error())
} }
@@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
} }
channel.Setting = string(settingBytes) channel.Setting = string(settingBytes)
} }
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error
return channels, err
}
func BatchSetChannelTag(ids []int, tag *string) error {
// 开启事务
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
// 更新标签
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
if err != nil {
tx.Rollback()
return err
}
// update ability status
channels, err := GetChannelsByIds(ids)
if err != nil {
tx.Rollback()
return err
}
for _, channel := range channels {
err = channel.UpdateAbilities(tx)
if err != nil {
tx.Rollback()
return err
}
}
// 提交事务
return tx.Commit().Error
}

View File

@@ -99,7 +99,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/fix", controller.FixChannelsAbilities) channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels) channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels) channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
} }
tokenRoute := apiRouter.Group("/token") tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth()) tokenRoute.Use(middleware.UserAuth())

View File

@@ -162,9 +162,15 @@ const ChannelsTable = () => {
return ( return (
<div> <div>
<Space spacing={2}> <Space spacing={2}>
{text?.split(',').map((item, index) => { {text?.split(',')
return renderGroup(item); .sort((a, b) => {
})} if (a === 'default') return -1;
if (b === 'default') return 1;
return a.localeCompare(b);
})
.map((item, index) => {
return renderGroup(item);
})}
</Space> </Space>
</div> </div>
); );
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
const [selectedChannels, setSelectedChannels] = useState([]); const [selectedChannels, setSelectedChannels] = useState([]);
const [showEditPriority, setShowEditPriority] = useState(false); const [showEditPriority, setShowEditPriority] = useState(false);
const [enableTagMode, setEnableTagMode] = useState(false); const [enableTagMode, setEnableTagMode] = useState(false);
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
const [batchSetTagValue, setBatchSetTagValue] = useState('');
const removeRecord = (record) => { const removeRecord = (record) => {
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
} }
}; };
const batchSetChannelTag = async () => {
if (selectedChannels.length === 0) {
showError(t('请先选择要设置标签的渠道!'));
return;
}
if (batchSetTagValue === '') {
showError(t('标签不能为空!'));
return;
}
let ids = selectedChannels.map(channel => channel.id);
const res = await API.post('/api/channel/batch/tag', {
ids: ids,
tag: batchSetTagValue === '' ? null : batchSetTagValue
});
if (res.data.success) {
showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
await refresh();
setShowBatchSetTag(false);
} else {
showError(res.data.message);
}
};
return ( return (
<> <>
<EditTagModal <EditTagModal
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
</div> </div>
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Space> <Space>
<Typography.Text strong>{t('开启批量删除')}</Typography.Text> <Typography.Text strong>{t('开启批量操作')}</Typography.Text>
<Switch <Switch
label={t('开启批量删除')} label={t('开启批量操作')}
uncheckedText={t('关')} uncheckedText={t('关')}
aria-label={t('是否开启批量删除')} aria-label={t('是否开启批量操作')}
onChange={(v) => { onChange={(v) => {
setEnableBatchDelete(v); setEnableBatchDelete(v);
}} }}
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
loadChannels(0, pageSize, idSort, v); loadChannels(0, pageSize, idSort, v);
}} }}
/> />
<Button
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
</Space> </Space>
</div> </div>
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
: null : null
} }
/> />
<Modal
title={t('批量设置标签')}
visible={showBatchSetTag}
onOk={batchSetChannelTag}
onCancel={() => setShowBatchSetTag(false)}
maskClosable={false}
centered={true}
>
<div style={{ marginBottom: 20 }}>
<Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
</div>
<Input
placeholder={t('请输入标签名称')}
value={batchSetTagValue}
onChange={(v) => setBatchSetTagValue(v)}
/>
</Modal>
</> </>
); );
}; };

View File

@@ -546,8 +546,8 @@
"是否用ID排序": "Whether to sort by ID", "是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?", "确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?", "确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection", "开启批量操作": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection", "是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?", "确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?", "确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.", "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",

View File

@@ -548,8 +548,8 @@
"是否用ID排序": "Whether to sort by ID", "是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?", "确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?", "确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection", "开启批量操作": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection", "是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?", "确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?", "确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.", "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
@@ -1237,5 +1237,8 @@
"更多": "Expand more", "更多": "Expand more",
"个模型": "models", "个模型": "models",
"可用模型": "Available models", "可用模型": "Available models",
"时间范围": "Time range" "时间范围": "Time range",
"批量设置标签": "Batch set tag",
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name"
} }