diff --git a/controller/channel.go b/controller/channel.go index a4ef87c3..1cfb7906 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -43,22 +43,23 @@ type OpenAIModelsResponse struct { func GetAllChannels(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 } - if pageSize < 0 { + if pageSize < 1 { pageSize = common.ItemsPerPage } channelData := make([]*model.Channel, 0) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) + + var total int64 + if enableTagMode { - tags, err := model.GetPaginatedTags(p*pageSize, pageSize) + // tag 分页:先分页 tag,再取各 tag 下 channels + tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) + c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) return } for _, tag := range tags { @@ -69,21 +70,27 @@ func GetAllChannels(c *gin.Context) { } } } + // 计算 tag 总数用于分页 + total, _ = model.CountAllTags() } else { - channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) + channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) + c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) return } channelData = channels + total, _ = model.CountAllChannels() } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": channelData, + "data": gin.H{ + "items": channelData, + "total": total, + "page": p, + "page_size": pageSize, + }, }) return } diff --git a/controller/midjourney.go b/controller/midjourney.go index 21027d8f..56bdcb80 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "io" - "log" "net/http" "one-api/common" "one-api/dto" @@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) func GetAllMidjourney(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 + } + pageSize, _ := strconv.Atoi(c.Query("page_size")) + if pageSize <= 0 { + pageSize = common.ItemsPerPage } // 解析其他查询参数 @@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) { EndTimestamp: c.Query("end_timestamp"), } - logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) - if logs == nil { - logs = make([]*model.Midjourney, 0) - } + items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams) + total := model.CountAllTasks(queryParams) + if setting.MjForwardUrlEnabled { - for i, midjourney := range logs { + for i, midjourney := range items { midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId - logs[i] = midjourney + items[i] = midjourney } } c.JSON(200, gin.H{ "success": true, "message": "", - "data": logs, + "data": gin.H{ + "items": items, + "total": total, + "page": p, + "page_size": pageSize, + }, }) } func GetUserMidjourney(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 + } + pageSize, _ := strconv.Atoi(c.Query("page_size")) + if pageSize <= 0 { + pageSize = common.ItemsPerPage } userId := c.GetInt("id") - log.Printf("userId = %d \n", userId) queryParams := model.TaskQueryParams{ MjID: c.Query("mj_id"), @@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) { EndTimestamp: c.Query("end_timestamp"), } - logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) - if logs == nil { - logs = make([]*model.Midjourney, 0) - } + items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams) + total := model.CountAllUserTask(userId, queryParams) + if setting.MjForwardUrlEnabled { - for i, midjourney := range logs { + for i, midjourney := range items { midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId - logs[i] = midjourney + items[i] = midjourney } } c.JSON(200, gin.H{ "success": true, "message": "", - "data": logs, + "data": gin.H{ + "items": items, + "total": total, + "page": p, + "page_size": pageSize, + }, }) } diff --git a/controller/task.go b/controller/task.go index 65f79ead..34e14f3f 100644 --- a/controller/task.go +++ b/controller/task.go @@ -224,9 +224,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool func GetAllTask(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 } + pageSize, _ := strconv.Atoi(c.Query("page_size")) + if pageSize <= 0 { + pageSize = common.ItemsPerPage + } + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) // 解析其他查询参数 @@ -237,24 +242,32 @@ func GetAllTask(c *gin.Context) { Action: c.Query("action"), StartTimestamp: startTimestamp, EndTimestamp: endTimestamp, + ChannelID: c.Query("channel_id"), } - logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) - if logs == nil { - logs = make([]*model.Task, 0) - } + items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams) + total := model.TaskCountAllTasks(queryParams) c.JSON(200, gin.H{ "success": true, "message": "", - "data": logs, + "data": gin.H{ + "items": items, + "total": total, + "page": p, + "page_size": pageSize, + }, }) } func GetUserTask(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 + } + pageSize, _ := strconv.Atoi(c.Query("page_size")) + if pageSize <= 0 { + pageSize = common.ItemsPerPage } userId := c.GetInt("id") @@ -271,14 +284,17 @@ func GetUserTask(c *gin.Context) { EndTimestamp: endTimestamp, } - logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) - if logs == nil { - logs = make([]*model.Task, 0) - } + items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams) + total := model.TaskCountAllUserTask(userId, queryParams) c.JSON(200, gin.H{ "success": true, "message": "", - "data": logs, + "data": gin.H{ + "items": items, + "total": total, + "page": p, + "page_size": pageSize, + }, }) } diff --git a/controller/token.go b/controller/token.go index a8803279..c57552c0 100644 --- a/controller/token.go +++ b/controller/token.go @@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) { userId := c.GetInt("id") p, _ := strconv.Atoi(c.Query("p")) size, _ := strconv.Atoi(c.Query("size")) - if p < 0 { - p = 0 + if p < 1 { + p = 1 } if size <= 0 { size = common.ItemsPerPage } else if size > 100 { size = 100 } - tokens, err := model.GetAllUserTokens(userId, p*size, size) + tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) { }) return } + // Get total count for pagination + total, _ := model.CountUserTokens(userId) + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": tokens, + "data": gin.H{ + "items": tokens, + "total": total, + "page": p, + "page_size": size, + }, }) return } diff --git a/model/channel.go b/model/channel.go index ed7a0a7e..a302df40 100644 --- a/model/channel.go +++ b/model/channel.go @@ -583,3 +583,17 @@ func BatchSetChannelTag(ids []int, tag *string) error { // 提交事务 return tx.Commit().Error } + +// CountAllChannels returns total channels in DB +func CountAllChannels() (int64, error) { + var total int64 + err := DB.Model(&Channel{}).Count(&total).Error + return total, err +} + +// CountAllTags returns number of non-empty distinct tags +func CountAllTags() (int64, error) { + var total int64 + err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error + return total, err +} diff --git a/model/midjourney.go b/model/midjourney.go index 5f85abfd..e8140447 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error { Where("id in (?)", taskIDs). Updates(params).Error } + +// CountAllTasks returns total midjourney tasks for admin query +func CountAllTasks(queryParams TaskQueryParams) int64 { + var total int64 + query := DB.Model(&Midjourney{}) + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.MjID != "" { + query = query.Where("mj_id = ?", queryParams.MjID) + } + if queryParams.StartTimestamp != "" { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != "" { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + _ = query.Count(&total).Error + return total +} + +// CountAllUserTask returns total midjourney tasks for user +func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 { + var total int64 + query := DB.Model(&Midjourney{}).Where("user_id = ?", userId) + if queryParams.MjID != "" { + query = query.Where("mj_id = ?", queryParams.MjID) + } + if queryParams.StartTimestamp != "" { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != "" { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + _ = query.Count(&total).Error + return total +} diff --git a/model/task.go b/model/task.go index df221edf..9e4177ba 100644 --- a/model/task.go +++ b/model/task.go @@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error return stat, err } + +// TaskCountAllTasks returns total tasks that match the given query params (admin usage) +func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { + var total int64 + query := DB.Model(&Task{}) + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.UserID != "" { + query = query.Where("user_id = ?", queryParams.UserID) + } + if len(queryParams.UserIDs) != 0 { + query = query.Where("user_id in (?)", queryParams.UserIDs) + } + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + _ = query.Count(&total).Error + return total +} + +// TaskCountAllUserTask returns total tasks for given user +func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 { + var total int64 + query := DB.Model(&Task{}).Where("user_id = ?", userId) + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + _ = query.Count(&total).Error + return total +} diff --git a/model/token.go b/model/token.go index 8587ea62..d4b26afe 100644 --- a/model/token.go +++ b/model/token.go @@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) { ).Error return err } + +// CountUserTokens returns total number of tokens for the given user, used for pagination +func CountUserTokens(userId int) (int64, error) { + var total int64 + err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error + return total, err +} diff --git a/web/src/components/table/ChannelsTable.js b/web/src/components/table/ChannelsTable.js index 6eeeab9a..f5a78490 100644 --- a/web/src/components/table/ChannelsTable.js +++ b/web/src/components/table/ChannelsTable.js @@ -865,32 +865,22 @@ const ChannelsTable = () => { tagChannelDates.response_time = tagChannelDates.response_time / 2; } } - // data.key = '' + data.id setChannels(channelDates); - if (channelDates.length >= pageSize) { - setChannelCount(channelDates.length + pageSize); - } else { - setChannelCount(channelDates.length); - } }; - const loadChannels = async (startIdx, pageSize, idSort, enableTagMode) => { + const loadChannels = async (page, pageSize, idSort, enableTagMode) => { setLoading(true); const res = await API.get( - `/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`, + `/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`, ); if (res === undefined) { return; } const { success, message, data } = res.data; if (success) { - if (startIdx === 0) { - setChannelFormat(data, enableTagMode); - } else { - let newChannels = [...channels]; - newChannels.splice(startIdx * pageSize, data.length, ...data); - setChannelFormat(newChannels, enableTagMode); - } + const { items, total } = data; + setChannelFormat(items, enableTagMode); + setChannelCount(total); } else { showError(message); } @@ -903,7 +893,6 @@ const ChannelsTable = () => { channelToCopy.created_time = null; channelToCopy.balance = 0; channelToCopy.used_quota = 0; - // 删除可能导致类型不匹配的字段 delete channelToCopy.test_time; delete channelToCopy.response_time; if (!channelToCopy) { @@ -927,7 +916,7 @@ const ChannelsTable = () => { const refresh = async () => { const { searchKeyword, searchGroup, searchModel } = getFormValues(); if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + await loadChannels(activePage, pageSize, idSort, enableTagMode); } else { await searchChannels(enableTagMode); } @@ -944,7 +933,7 @@ const ChannelsTable = () => { setPageSize(localPageSize); setEnableTagMode(localEnableTagMode); setEnableBatchDelete(localEnableBatchDelete); - loadChannels(0, localPageSize, localIdSort, localEnableTagMode) + loadChannels(1, localPageSize, localIdSort, localEnableTagMode) .then() .catch((reason) => { showError(reason); @@ -1052,7 +1041,6 @@ const ChannelsTable = () => { try { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); - // setActivePage(1); return; } @@ -1191,24 +1179,18 @@ const ChannelsTable = () => { } }; - let pageData = channels.slice( - (activePage - 1) * pageSize, - activePage * pageSize, - ); + let pageData = channels; const handlePageChange = (page) => { setActivePage(page); - if (page === Math.ceil(channels.length / pageSize) + 1) { - // In this case we have to load more data and then append them. - loadChannels(page - 1, pageSize, idSort, enableTagMode).then((r) => { }); - } + loadChannels(page, pageSize, idSort, enableTagMode).then(() => { }); }; const handlePageSizeChange = async (size) => { localStorage.setItem('page-size', size + ''); setPageSize(size); setActivePage(1); - loadChannels(0, size, idSort, enableTagMode) + loadChannels(1, size, idSort, enableTagMode) .then() .catch((reason) => { showError(reason); @@ -1218,8 +1200,6 @@ const ChannelsTable = () => { const fetchGroups = async () => { try { let res = await API.get(`/api/group/`); - // add 'all' option - // res.data.data.unshift('all'); if (res === undefined) { return; } @@ -1514,7 +1494,7 @@ const ChannelsTable = () => { onChange={(v) => { localStorage.setItem('id-sort', v + ''); setIdSort(v); - loadChannels(0, pageSize, v, enableTagMode); + loadChannels(activePage, pageSize, v, enableTagMode); }} /> @@ -1541,7 +1521,8 @@ const ChannelsTable = () => { onChange={(v) => { localStorage.setItem('enable-tag-mode', v + ''); setEnableTagMode(v); - loadChannels(0, pageSize, idSort, v); + setActivePage(1); + loadChannels(1, pageSize, idSort, v); }} /> @@ -1703,7 +1684,7 @@ const ChannelsTable = () => { formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', { start: page.currentStart, end: page.currentEnd, - total: channels.length, + total: channelCount, }), onPageSizeChange: (size) => { handlePageSizeChange(size); diff --git a/web/src/components/table/MjLogsTable.js b/web/src/components/table/MjLogsTable.js index 08376641..869db485 100644 --- a/web/src/components/table/MjLogsTable.js +++ b/web/src/components/table/MjLogsTable.js @@ -601,7 +601,7 @@ const LogsTable = () => { const [logs, setLogs] = useState([]); const [loading, setLoading] = useState(true); const [activePage, setActivePage] = useState(1); - const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logCount, setLogCount] = useState(0); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [showBanner, setShowBanner] = useState(false); @@ -649,69 +649,53 @@ const LogsTable = () => { }; }; - const setLogsFormat = (logs) => { - for (let i = 0; i < logs.length; i++) { - logs[i].timestamp2string = timestamp2string(logs[i].created_at); - logs[i].key = '' + logs[i].id; - } - // data.key = '' + data.id - setLogs(logs); - setLogCount(logs.length + pageSize); - // console.log(logCount); + const enrichLogs = (items) => { + return items.map((log) => ({ + ...log, + timestamp2string: timestamp2string(log.created_at), + key: '' + log.id, + })); }; - const loadLogs = async (startIdx, pageSize = ITEMS_PER_PAGE) => { - setLoading(true); + const syncPageData = (payload) => { + const items = enrichLogs(payload.items || []); + setLogs(items); + setLogCount(payload.total || 0); + setActivePage(payload.page || 1); + setPageSize(payload.page_size || pageSize); + }; - let url = ''; + const loadLogs = async (page = 1, size = pageSize) => { + setLoading(true); const { channel_id, mj_id, start_timestamp, end_timestamp } = getFormValues(); let localStartTimestamp = Date.parse(start_timestamp); let localEndTimestamp = Date.parse(end_timestamp); - if (isAdminUser) { - url = `/api/mj/?p=${startIdx}&page_size=${pageSize}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; - } else { - url = `/api/mj/self/?p=${startIdx}&page_size=${pageSize}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; - } + const url = isAdminUser + ? `/api/mj/?p=${page}&page_size=${size}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}` + : `/api/mj/self/?p=${page}&page_size=${size}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; const res = await API.get(url); const { success, message, data } = res.data; if (success) { - if (startIdx === 0) { - setLogsFormat(data); - } else { - let newLogs = [...logs]; - newLogs.splice(startIdx * pageSize, data.length, ...data); - setLogsFormat(newLogs); - } + syncPageData(data); } else { showError(message); } setLoading(false); }; - const pageData = logs.slice( - (activePage - 1) * pageSize, - activePage * pageSize, - ); + const pageData = logs; const handlePageChange = (page) => { - setActivePage(page); - if (page === Math.ceil(logs.length / pageSize) + 1) { - // In this case we have to load more data and then append them. - loadLogs(page - 1, pageSize).then((r) => { }); - } + loadLogs(page, pageSize).then(); }; const handlePageSizeChange = async (size) => { localStorage.setItem('mj-page-size', size + ''); - setPageSize(size); - setActivePage(1); - await loadLogs(0, size); + await loadLogs(1, size); }; const refresh = async () => { - // setLoading(true); - setActivePage(1); - await loadLogs(0, pageSize); + await loadLogs(1, pageSize); }; const copyText = async (text) => { @@ -726,7 +710,7 @@ const LogsTable = () => { useEffect(() => { const localPageSize = parseInt(localStorage.getItem('mj-page-size')) || ITEMS_PER_PAGE; setPageSize(localPageSize); - loadLogs(0, localPageSize).then(); + loadLogs(1, localPageSize).then(); }, []); useEffect(() => { @@ -936,7 +920,7 @@ const LogsTable = () => { >