feat: add tag aggregation mode to channels API and UI

This commit is contained in:
CalciumIon
2024-12-01 09:24:43 +08:00
parent bb0c504709
commit 88b0e6a768
3 changed files with 71 additions and 48 deletions

View File

@@ -3,12 +3,13 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type OpenAIModel struct { type OpenAIModel struct {
@@ -48,41 +49,36 @@ func GetAllChannels(c *gin.Context) {
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
} }
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort")) idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
if err != nil { if enableTagMode {
c.JSON(http.StatusOK, gin.H{ tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
"success": false, if err != nil {
"message": err.Error(), c.JSON(http.StatusOK, gin.H{
}) "success": false,
return "message": err.Error(),
} })
tags := make(map[string]bool) return
channelData := make([]*model.Channel, 0, len(channels))
tagChannels := make([]*model.Channel, 0)
for _, channel := range channels {
channelTag := channel.GetTag()
if channelTag != "" && !tags[channelTag] {
tags[channelTag] = true
tagChannel, err := model.GetChannelsByTag(channelTag)
if err == nil {
tagChannels = append(tagChannels, tagChannel...)
}
} else {
channelData = append(channelData, channel)
} }
} for _, tag := range tags {
for i, channel := range tagChannels { if tag != nil && *tag != "" {
find := false tagChannel, err := model.GetChannelsByTag(*tag)
for _, can := range channelData { if err == nil {
if channel.Id == can.Id { channelData = append(channelData, tagChannel...)
find = true }
break
} }
} }
if !find { } else {
channelData = append(channelData, tagChannels[i]) channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
} }
channelData = channels
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,

View File

@@ -2,9 +2,10 @@ package model
import ( import (
"encoding/json" "encoding/json"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"gorm.io/gorm"
) )
type Channel struct { type Channel struct {
@@ -403,3 +404,9 @@ func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
var tags []*string
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
return tags, err
}

View File

@@ -439,6 +439,7 @@ const ChannelsTable = () => {
const [editingTag, setEditingTag] = useState(''); const [editingTag, setEditingTag] = useState('');
const [selectedChannels, setSelectedChannels] = useState([]); const [selectedChannels, setSelectedChannels] = useState([]);
const [showEditPriority, setShowEditPriority] = useState(false); const [showEditPriority, setShowEditPriority] = useState(false);
const [enableTagMode, setEnableTagMode] = useState(false);
const removeRecord = (record) => { const removeRecord = (record) => {
@@ -464,13 +465,12 @@ const ChannelsTable = () => {
} }
}; };
const setChannelFormat = (channels) => { const setChannelFormat = (channels, enableTagMode) => {
let channelDates = []; let channelDates = [];
let channelTags = {}; let channelTags = {};
for (let i = 0; i < channels.length; i++) { for (let i = 0; i < channels.length; i++) {
channels[i].key = '' + channels[i].id; channels[i].key = '' + channels[i].id;
if (!enableTagMode) {
if (channels[i].tag === '' || channels[i].tag === null) {
let test_models = []; let test_models = [];
channels[i].models.split(',').forEach((item, index) => { channels[i].models.split(',').forEach((item, index) => {
test_models.push({ test_models.push({
@@ -554,10 +554,10 @@ const ChannelsTable = () => {
} }
}; };
const loadChannels = async (startIdx, pageSize, idSort) => { const loadChannels = async (startIdx, pageSize, idSort, enableTagMode) => {
setLoading(true); setLoading(true);
const res = await API.get( const res = await API.get(
`/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}` `/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`
); );
if (res === undefined) { if (res === undefined) {
return; return;
@@ -565,11 +565,11 @@ const ChannelsTable = () => {
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
if (startIdx === 0) { if (startIdx === 0) {
setChannelFormat(data); setChannelFormat(data, enableTagMode);
} else { } else {
let newChannels = [...channels]; let newChannels = [...channels];
newChannels.splice(startIdx * pageSize, data.length, ...data); newChannels.splice(startIdx * pageSize, data.length, ...data);
setChannelFormat(newChannels); setChannelFormat(newChannels, enableTagMode);
} }
} else { } else {
showError(message); showError(message);
@@ -602,7 +602,7 @@ const ChannelsTable = () => {
}; };
const refresh = async () => { const refresh = async () => {
await loadChannels(activePage - 1, pageSize, idSort); await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
}; };
useEffect(() => { useEffect(() => {
@@ -612,7 +612,7 @@ const ChannelsTable = () => {
parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
setIdSort(localIdSort); setIdSort(localIdSort);
setPageSize(localPageSize); setPageSize(localPageSize);
loadChannels(0, localPageSize, localIdSort) loadChannels(0, localPageSize, localIdSort, enableTagMode)
.then() .then()
.catch((reason) => { .catch((reason) => {
showError(reason); showError(reason);
@@ -770,18 +770,22 @@ const ChannelsTable = () => {
const searchChannels = async (searchKeyword, searchGroup, searchModel) => { const searchChannels = async (searchKeyword, searchGroup, searchModel) => {
if (searchKeyword === '' && searchGroup === '' && searchModel === '') { if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
// if keyword is blank, load files instead. await loadChannels(0, pageSize, idSort, enableTagMode);
await loadChannels(0, pageSize, idSort);
setActivePage(1); setActivePage(1);
return; return;
} }
setSearching(true); setSearching(true);
const res = await API.get( const res = await API.get(
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}` `/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}`
); );
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setChannelFormat(data); if (enableTagMode) {
setChannelFormat(data, enableTagMode);
} else {
setChannels(data.map(channel => ({...channel, key: '' + channel.id})));
setChannelCount(data.length);
}
setActivePage(1); setActivePage(1);
} else { } else {
showError(message); showError(message);
@@ -887,7 +891,7 @@ const ChannelsTable = () => {
setActivePage(page); setActivePage(page);
if (page === Math.ceil(channels.length / pageSize) + 1) { if (page === Math.ceil(channels.length / pageSize) + 1) {
// In this case we have to load more data and then append them. // In this case we have to load more data and then append them.
loadChannels(page - 1, pageSize, idSort).then((r) => { loadChannels(page - 1, pageSize, idSort, enableTagMode).then((r) => {
}); });
} }
}; };
@@ -896,7 +900,7 @@ const ChannelsTable = () => {
localStorage.setItem('page-size', size + ''); localStorage.setItem('page-size', size + '');
setPageSize(size); setPageSize(size);
setActivePage(1); setActivePage(1);
loadChannels(0, size, idSort) loadChannels(0, size, idSort, enableTagMode)
.then() .then()
.catch((reason) => { .catch((reason) => {
showError(reason); showError(reason);
@@ -1052,7 +1056,7 @@ const ChannelsTable = () => {
onChange={(v) => { onChange={(v) => {
localStorage.setItem('id-sort', v + ''); localStorage.setItem('id-sort', v + '');
setIdSort(v); setIdSort(v);
loadChannels(0, pageSize, v) loadChannels(0, pageSize, v, enableTagMode)
.then() .then()
.catch((reason) => { .catch((reason) => {
showError(reason); showError(reason);
@@ -1153,6 +1157,22 @@ const ChannelsTable = () => {
</Popconfirm> </Popconfirm>
</Space> </Space>
</div> </div>
<div style={{ marginTop: 20 }}>
<Space>
<Typography.Text strong>标签聚合模式</Typography.Text>
<Switch
checked={enableTagMode}
label="标签聚合模式"
uncheckedText="关"
aria-label="是否启用标签聚合"
onChange={(v) => {
setEnableTagMode(v);
// 切换模式时重新加载数据
loadChannels(0, pageSize, idSort, v);
}}
/>
</Space>
</div>
<Table <Table