✨ feat(ability): enhance FixAbility function
This commit is contained in:
@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
count, err := model.FixAbility()
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
"data": gin.H{
|
||||
"success": success,
|
||||
"fails": fails,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
4
main.go
4
main.go
@@ -68,9 +68,9 @@ func main() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||
// Retry once
|
||||
_, fixErr := model.FixAbility()
|
||||
_, _, fixErr := model.FixAbility()
|
||||
if fixErr != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
|
||||
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
|
||||
}
|
||||
|
||||
func FixAbility() (int, error) {
|
||||
var channelIds []int
|
||||
count := 0
|
||||
// Find all channel ids from channel table
|
||||
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
|
||||
var fixLock = sync.Mutex{}
|
||||
|
||||
func FixAbility() (int, int, error) {
|
||||
lock := fixLock.TryLock()
|
||||
if !lock {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
|
||||
if len(channelIds) > 0 {
|
||||
// Process deletion in chunks to avoid "too many placeholders" error
|
||||
for _, chunk := range lo.Chunk(channelIds, 100) {
|
||||
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no channels exist, delete all abilities
|
||||
err = DB.Delete(&Ability{}).Error
|
||||
if len(channels) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, chunk := range lo.Chunk(channels, 50) {
|
||||
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
|
||||
// Delete all abilities of this channel
|
||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
|
||||
return 0, err
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
failCount += len(chunk)
|
||||
continue
|
||||
}
|
||||
common.SysLog("Delete all abilities successfully")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
|
||||
count += len(channelIds)
|
||||
|
||||
// Use channelIds to find channel not in abilities table
|
||||
var abilityChannelIds []int
|
||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
}
|
||||
|
||||
var channels []Channel
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
// Process query in chunks to avoid "too many placeholders" error
|
||||
err = nil
|
||||
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
|
||||
var channelsChunk []Channel
|
||||
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
|
||||
// Then add new abilities
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
channels = append(channels, channelsChunk...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
|
||||
count++
|
||||
}
|
||||
}
|
||||
InitChannelCache()
|
||||
return count, nil
|
||||
return successCount, failCount, nil
|
||||
}
|
||||
|
||||
@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
|
||||
|
||||
const fixChannelsAbilities = async () => {
|
||||
const res = await API.post(`/api/channel/fix`);
|
||||
const { success, message, data } = res.data;
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data));
|
||||
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
|
||||
await refresh();
|
||||
} else {
|
||||
showError(message);
|
||||
|
||||
@@ -240,7 +240,7 @@ const EditChannel = (props) => {
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
|
||||
Reference in New Issue
Block a user