Merge pull request #755 from xvhuan/perf/admin-usage-fast-pagination-main

perf(admin-usage): 优化 usage 大表分页,默认避免全量 COUNT(*)
This commit is contained in:
Wesley Liddick
2026-03-04 14:15:57 +08:00
committed by GitHub
7 changed files with 167 additions and 79 deletions

View File

@@ -61,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
exactTotal := false
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
parsed, err := strconv.ParseBool(exactTotalRaw)
if err != nil {
response.BadRequest(c, "Invalid exact_total value, use true or false")
return
}
exactTotal = parsed
}
// Parse filters
var userID, apiKeyID, accountID, groupID int64
@@ -167,6 +176,7 @@ func (h *UsageHandler) List(c *gin.Context) {
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)

View File

@@ -80,6 +80,29 @@ func TestAdminUsageListInvalidStream(t *testing.T) {
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListExactTotalTrue(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, repo.listFilters.ExactTotal)
}
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)

View File

@@ -154,6 +154,8 @@ type UsageLogFilters struct {
BillingType *int8
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
ExactTotal bool
}
// UsageStats represents usage statistics

View File

@@ -1473,7 +1473,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
}
whereClause := buildWhere(conditions)
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
var (
logs []service.UsageLog
page *pagination.PaginationResult
err error
)
if shouldUseFastUsageLogTotal(filters) {
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
} else {
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
}
if err != nil {
return nil, nil, err
}
@@ -1484,17 +1493,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return logs, page, nil
}
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
if filters.ExactTotal {
return false
}
// 强选择过滤下记录集通常较小,保留精确总数。
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
}
// UsageStats represents usage statistics
type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
func normalizePositiveInt64IDs(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
seen := make(map[int64]struct{}, len(ids))
out := make([]int64, 0, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
if len(normalizedUserIDs) == 0 {
return result, nil
}
@@ -1506,58 +1543,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
endTime = time.Now()
}
for _, id := range userIDs {
for _, id := range normalizedUserIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
SELECT
user_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
WHERE user_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
var todayTotal float64
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1) AND created_at >= $2
GROUP BY user_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TodayActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -1577,7 +1592,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
// If startTime is zero, defaults to 30 days ago.
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
result := make(map[int64]*BatchAPIKeyUsageStats)
if len(apiKeyIDs) == 0 {
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
if len(normalizedAPIKeyIDs) == 0 {
return result, nil
}
@@ -1589,58 +1605,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
endTime = time.Now()
}
for _, id := range apiKeyIDs {
for _, id := range normalizedAPIKeyIDs {
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
}
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
SELECT
api_key_id,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
WHERE api_key_id = ANY($1)
AND created_at >= LEAST($2, $4)
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
var todayTotal float64
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1) AND created_at >= $2
GROUP BY api_key_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TodayActualCost = total
stats.TodayActualCost = todayTotal
}
}
if err := rows.Close(); err != nil {
@@ -2245,6 +2239,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
limit := params.Limit()
offset := params.Offset()
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset)
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
}
hasMore := false
if len(logs) > limit {
hasMore = true
logs = logs[:limit]
}
total := int64(offset) + int64(len(logs))
if hasMore {
// 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。
total = int64(offset) + int64(limit) + 1
}
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {

View File

@@ -96,6 +96,7 @@ func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
filters := usagestats.UsageLogFilters{
RequestType: &requestType,
Stream: &stream,
ExactTotal: true,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").