diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index d0bba773..05fd00f1 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct { // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + exactTotal := false + if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" { + parsed, err := strconv.ParseBool(exactTotalRaw) + if err != nil { + response.BadRequest(c, "Invalid exact_total value, use true or false") + return + } + exactTotal = parsed + } // Parse filters var userID, apiKeyID, accountID, groupID int64 @@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) { BillingType: billingType, StartTime: startTime, EndTime: endTime, + ExactTotal: exactTotal, } records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go index 21add574..3f158316 100644 --- a/backend/internal/handler/admin/usage_handler_request_type_test.go +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestAdminUsageListExactTotalTrue(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, repo.listFilters.ExactTotal) +} + +func TestAdminUsageListInvalidExactTotal(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + func TestAdminUsageStatsRequestTypePriority(t *testing.T) { repo := &adminUsageRepoCapture{} router := newAdminUsageRequestTypeTestRouter(repo) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 314a6d3c..746188ea 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -154,6 +154,8 @@ type UsageLogFilters struct { BillingType *int8 StartTime *time.Time EndTime *time.Time + // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. + ExactTotal bool } // UsageStats represents usage statistics diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ff40e97d..44079a55 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1473,7 +1473,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat } whereClause := buildWhere(conditions) - logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params) + var ( + logs []service.UsageLog + page *pagination.PaginationResult + err error + ) + if shouldUseFastUsageLogTotal(filters) { + logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params) + } else { + logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) + } if err != nil { return nil, nil, err } @@ -1484,17 +1493,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat return logs, page, nil } +func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { + if filters.ExactTotal { + return false + } + // 强选择过滤下记录集通常较小,保留精确总数。 + return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 +} + // UsageStats represents usage statistics type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats +func normalizePositiveInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out +} + // GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. // If startTime is zero, defaults to 30 days ago. func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) - if len(userIDs) == 0 { + normalizedUserIDs := normalizePositiveInt64IDs(userIDs) + if len(normalizedUserIDs) == 0 { return result, nil } @@ -1506,58 +1543,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs endTime = time.Now() } - for _, id := range userIDs { + for _, id := range normalizedUserIDs { result[id] = &BatchUserUsageStats{UserID: id} } query := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + user_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE user_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var userID int64 var total float64 - if err := rows.Scan(&userID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&userID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[userID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 - GROUP BY user_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var userID int64 - var total float64 - if err := rows.Scan(&userID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[userID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1577,7 +1592,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats // If startTime is zero, defaults to 30 days ago. func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) - if len(apiKeyIDs) == 0 { + normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) + if len(normalizedAPIKeyIDs) == 0 { return result, nil } @@ -1589,58 +1605,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe endTime = time.Now() } - for _, id := range apiKeyIDs { + for _, id := range normalizedAPIKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } query := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + api_key_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 + WHERE api_key_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var apiKeyID int64 var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[apiKeyID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 - GROUP BY api_key_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var apiKeyID int64 - var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[apiKeyID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -2245,6 +2239,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh return logs, paginationResultFromTotal(total, params), nil } +func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + limit := params.Limit() + offset := params.Offset() + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), limit+1, offset) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + + logs, err := r.queryUsageLogs(ctx, query, listArgs...) + if err != nil { + return nil, nil, err + } + + hasMore := false + if len(logs) > limit { + hasMore = true + logs = logs[:limit] + } + + total := int64(offset) + int64(len(logs)) + if hasMore { + // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 + total = int64(offset) + int64(limit) + 1 + } + + return logs, paginationResultFromTotal(total, params), nil +} + func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 95cf2a2d..54eb81e1 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { filters := usagestats.UsageLogFilters{ RequestType: &requestType, Stream: &stream, + ExactTotal: true, } mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 66c84410..2d6212c5 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -75,6 +75,7 @@ export interface CreateUsageCleanupTaskRequest { export interface AdminUsageQueryParams extends UsageQueryParams { user_id?: number + exact_total?: boolean } // ==================== API Functions ==================== diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index b5aa63c8..9c39068a 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -88,6 +88,7 @@ const appStore = useAppStore() const usageStats = ref(null); const usageLogs = ref([]); const loading = ref(false); const exporting = ref(false) const trendData = ref([]); const modelStats = ref([]); const groupStats = ref([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day') let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null +let chartReqSeq = 0 const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' }) const cleanupDialogVisible = ref(false) @@ -109,7 +110,7 @@ const loadLogs = async () => { try { const requestType = filters.value.request_type const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream - const res = await adminAPI.usage.list({ page: pagination.page, page_size: pagination.page_size, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal }) + const res = await adminAPI.usage.list({ page: pagination.page, page_size: pagination.page_size, exact_total: false, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal }) if(!c.signal.aborted) { usageLogs.value = res.items; pagination.total = res.total } } catch (error: any) { if(error?.name !== 'AbortError') console.error('Failed to load usage logs:', error) } finally { if(abortController === c) loading.value = false } } @@ -124,15 +125,34 @@ const loadStats = async () => { } } const loadChartData = async () => { + const seq = ++chartReqSeq chartsLoading.value = true try { const requestType = filters.value.request_type const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream - const params = { start_date: filters.value.start_date || startDate.value, end_date: filters.value.end_date || endDate.value, granularity: granularity.value, user_id: filters.value.user_id, model: filters.value.model, api_key_id: filters.value.api_key_id, account_id: filters.value.account_id, group_id: filters.value.group_id, request_type: requestType, stream: legacyStream === null ? undefined : legacyStream, billing_type: filters.value.billing_type } - const statsParams = { start_date: params.start_date, end_date: params.end_date, user_id: params.user_id, model: params.model, api_key_id: params.api_key_id, account_id: params.account_id, group_id: params.group_id, request_type: params.request_type, stream: params.stream, billing_type: params.billing_type } - const [trendRes, modelRes, groupRes] = await Promise.all([adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getModelStats(statsParams), adminAPI.dashboard.getGroupStats(statsParams)]) - trendData.value = trendRes.trend || []; modelStats.value = modelRes.models || []; groupStats.value = groupRes.groups || [] - } catch (error) { console.error('Failed to load chart data:', error) } finally { chartsLoading.value = false } + const snapshot = await adminAPI.dashboard.getSnapshotV2({ + start_date: filters.value.start_date || startDate.value, + end_date: filters.value.end_date || endDate.value, + granularity: granularity.value, + user_id: filters.value.user_id, + model: filters.value.model, + api_key_id: filters.value.api_key_id, + account_id: filters.value.account_id, + group_id: filters.value.group_id, + request_type: requestType, + stream: legacyStream === null ? undefined : legacyStream, + billing_type: filters.value.billing_type, + include_stats: false, + include_trend: true, + include_model_stats: true, + include_group_stats: true, + include_users_trend: false + }) + if (seq !== chartReqSeq) return + trendData.value = snapshot.trend || [] + modelStats.value = snapshot.models || [] + groupStats.value = snapshot.groups || [] + } catch (error) { console.error('Failed to load chart data:', error) } finally { if (seq === chartReqSeq) chartsLoading.value = false } } const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() } const refreshData = () => { loadLogs(); loadStats(); loadChartData() } @@ -171,7 +191,7 @@ const exportToExcel = async () => { while (true) { const requestType = filters.value.request_type const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream - const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal }) + const res = await adminUsageAPI.list({ page: p, page_size: 100, exact_total: true, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal }) if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total } const rows = (res.items || []).map((log: AdminUsageLog) => [ log.created_at, log.user?.email || '', log.api_key?.name || '', log.account?.name || '', log.model, @@ -273,6 +293,14 @@ const handleColumnClickOutside = (event: MouseEvent) => { } } -onMounted(() => { loadLogs(); loadStats(); loadChartData(); loadSavedColumns(); document.addEventListener('click', handleColumnClickOutside) }) +onMounted(() => { + loadLogs() + loadStats() + window.setTimeout(() => { + void loadChartData() + }, 120) + loadSavedColumns() + document.addEventListener('click', handleColumnClickOutside) +}) onUnmounted(() => { abortController?.abort(); exportAbortController?.abort(); document.removeEventListener('click', handleColumnClickOutside) })