perf(admin-usage): avoid expensive count on large usage_logs pagination
This commit is contained in:
@@ -1473,7 +1473,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
}
|
||||
|
||||
whereClause := buildWhere(conditions)
|
||||
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||||
var (
|
||||
logs []service.UsageLog
|
||||
page *pagination.PaginationResult
|
||||
err error
|
||||
)
|
||||
if shouldUseFastUsageLogTotal(filters) {
|
||||
logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params)
|
||||
} else {
|
||||
logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -1484,17 +1493,45 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
|
||||
return logs, page, nil
|
||||
}
|
||||
|
||||
func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool {
|
||||
if filters.ExactTotal {
|
||||
return false
|
||||
}
|
||||
// 强选择过滤下记录集通常较小,保留精确总数。
|
||||
return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0
|
||||
}
|
||||
|
||||
// UsageStats represents usage statistics
|
||||
type UsageStats = usagestats.UsageStats
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
func normalizePositiveInt64IDs(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
out := make([]int64, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||
result := make(map[int64]*BatchUserUsageStats)
|
||||
if len(userIDs) == 0 {
|
||||
normalizedUserIDs := normalizePositiveInt64IDs(userIDs)
|
||||
if len(normalizedUserIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1506,58 +1543,36 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range userIDs {
|
||||
for _, id := range normalizedUserIDs {
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
SELECT
|
||||
user_id,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
WHERE user_id = ANY($1)
|
||||
AND created_at >= LEAST($2, $4)
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||
today := timezone.Today()
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&userID, &total); err != nil {
|
||||
var todayTotal float64
|
||||
if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[userID]; ok {
|
||||
stats.TotalActualCost = total
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
today := timezone.Today()
|
||||
todayQuery := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1) AND created_at >= $2
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&userID, &total); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[userID]; ok {
|
||||
stats.TodayActualCost = total
|
||||
stats.TodayActualCost = todayTotal
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@@ -1577,7 +1592,8 @@ type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs)
|
||||
if len(normalizedAPIKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -1589,58 +1605,36 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
for _, id := range normalizedAPIKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
SELECT
|
||||
api_key_id,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
WHERE api_key_id = ANY($1)
|
||||
AND created_at >= LEAST($2, $4)
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||
today := timezone.Today()
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var apiKeyID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||||
var todayTotal float64
|
||||
if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[apiKeyID]; ok {
|
||||
stats.TotalActualCost = total
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
today := timezone.Today()
|
||||
todayQuery := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var apiKeyID int64
|
||||
var total float64
|
||||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
if stats, ok := result[apiKeyID]; ok {
|
||||
stats.TodayActualCost = total
|
||||
stats.TodayActualCost = todayTotal
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
@@ -2245,6 +2239,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
limit := params.Limit()
|
||||
offset := params.Offset()
|
||||
|
||||
limitPos := len(args) + 1
|
||||
offsetPos := len(args) + 2
|
||||
listArgs := append(append([]any{}, args...), limit+1, offset)
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||||
|
||||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hasMore := false
|
||||
if len(logs) > limit {
|
||||
hasMore = true
|
||||
logs = logs[:limit]
|
||||
}
|
||||
|
||||
total := int64(offset) + int64(len(logs))
|
||||
if hasMore {
|
||||
// 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。
|
||||
total = int64(offset) + int64(limit) + 1
|
||||
}
|
||||
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user