sub2api/backend/internal/repository/usage_log_repo.go
yangjianbo 5906f9ab98 fix(data layer): fix data-integrity and repository-consistency issues
## Data-integrity fixes (fix-critical-data-integrity)
- Add error_translate.go as a unified error-translation layer (see the sketch after this list)
- Fix handling of nil inputs and NotFound errors
- Make error handling consistent across the repository layer
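
A minimal sketch of the error-translation idea, assuming ent's generated IsNotFound helper and the existing service.ErrUsageLogNotFound sentinel; the actual error_translate.go likely covers more cases:

```go
// Hypothetical illustration only; not the actual contents of error_translate.go.
package repository

import (
	dbent "github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/internal/service"
)

// translateUsageLogErr maps ent-level errors onto domain sentinels so callers
// never depend on ent types.
func translateUsageLogErr(err error) error {
	switch {
	case err == nil:
		return nil
	case dbent.IsNotFound(err):
		return service.ErrUsageLogNotFound
	default:
		return err
	}
}
```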

## Repository-consistency fixes (fix-high-repository-consistency)
- Add a default_validity_days field to the Group schema
- Add a proxy edge to the Account schema
- Add a UsageLog ent schema definition
- Validate affected row counts in UpdateBalance/UpdateConcurrency (sketched after this list)
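
A minimal sketch of the affected-rows validation, assuming a plain *sql.DB handle; the real UpdateBalance/UpdateConcurrency live in other repository files, and the table, column names, and error message here are placeholders:

```go
// Hypothetical illustration of the affected-rows check, not the actual implementation.
package repository

import (
	"context"
	"database/sql"
	"errors"
)

func updateBalanceSketch(ctx context.Context, db *sql.DB, userID int64, delta float64) error {
	res, err := db.ExecContext(ctx,
		"UPDATE users SET balance = balance + $1 WHERE id = $2 AND deleted_at IS NULL",
		delta, userID)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	// Treat "no rows updated" as a not-found failure instead of a silent success.
	if affected == 0 {
		return errors.New("update balance: user not found or deleted")
	}
	return nil
}
```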

## Data-hygiene fixes (fix-medium-data-hygiene)
- Add soft-delete support to UserSubscription (SoftDeleteMixin; see the sketch after this list)
- Document the hard-delete policy for RedeemCode/Setting
- Declare created_at on account_groups/user_allowed_groups as timestamptz
- Stop writing to the legacy users.allowed_groups column
- Add migrations 011-014 (index optimization, soft delete, orphaned-data audit, column cleanup)
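
A minimal sketch of what a SoftDeleteMixin might look like as a field-only ent mixin; the project's actual mixin may also install hooks or interceptors to filter soft-deleted rows:

```go
// Hypothetical illustration; not necessarily the project's SoftDeleteMixin.
package schema

import (
	"entgo.io/ent"
	"entgo.io/ent/schema/field"
	"entgo.io/ent/schema/mixin"
)

// SoftDeleteMixin adds a nullable deleted_at column to any schema that embeds it.
type SoftDeleteMixin struct {
	mixin.Schema
}

func (SoftDeleteMixin) Fields() []ent.Field {
	return []ent.Field{
		field.Time("deleted_at").
			Optional().
			Nillable(),
	}
}
```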

## Test additions
- Add UserSubscription soft-delete tests
- Add migration regression tests
- Add NotFound error tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-31 14:11:57 +08:00

package repository
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
dbapikey "github.com/Wei-Shaw/sub2api/ent/apikey"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, created_at"
type usageLogRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
return newUsageLogRepositoryWithSQL(client, sqlDB)
}
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// Use scanSingleRow instead of QueryRowContext so that ent.Tx can also be used as the sqlExecutor.
return &usageLogRepository{client: client, sql: sqlq}
}
// getPerformanceStats returns RPM and TPM averaged over the last 5 minutes, optionally filtered by user (userID > 0).
func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64, err error) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
query := `
SELECT
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
FROM usage_logs
WHERE created_at >= $1`
args := []any{fiveMinutesAgo}
if userID > 0 {
query += " AND user_id = $2"
args = append(args, userID)
}
var requestCount int64
var tokenCount int64
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
}
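// Create inserts a usage log row and writes the generated id and created_at back onto log. A nil log is a no-op.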
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
if log == nil {
return nil
}
createdAt := log.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now()
}
rateMultiplier := log.RateMultiplier
query := `
INSERT INTO usage_logs (
user_id,
api_key_id,
account_id,
request_id,
model,
group_id,
subscription_id,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
total_cost,
actual_cost,
rate_multiplier,
billing_type,
stream,
duration_ms,
first_token_ms,
created_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7,
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25
)
RETURNING id, created_at
`
groupID := nullInt64(log.GroupID)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
args := []any{
log.UserID,
log.ApiKeyID,
log.AccountID,
log.RequestID,
log.Model,
groupID,
subscriptionID,
log.InputTokens,
log.OutputTokens,
log.CacheCreationTokens,
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
log.CacheReadCost,
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.BillingType,
log.Stream,
duration,
firstToken,
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
}
log.RateMultiplier = rateMultiplier
return nil
}
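// GetByID returns a single usage log by id, or service.ErrUsageLogNotFound when no row matches.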
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
log = nil
}
}()
if !rows.Next() {
if err = rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUsageLogNotFound
}
log, err = scanUsageLog(rows)
if err != nil {
return nil, err
}
if err = rows.Err(); err != nil {
return nil, err
}
return log, nil
}
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
}
// UserStats holds per-user usage statistics.
type UserStats struct {
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
}
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(actual_cost), 0) as total_cost,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
`
stats := &UserStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalTokens,
&stats.TotalCost,
&stats.InputTokens,
&stats.OutputTokens,
&stats.CacheReadTokens,
); err != nil {
return nil, err
}
return stats, nil
}
// DashboardStats holds admin dashboard statistics.
type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats
today := timezone.Today()
now := time.Now()
// Combined user statistics query.
userStatsQuery := `
SELECT
COUNT(*) as total_users,
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
FROM users
WHERE deleted_at IS NULL
`
if err := scanSingleRow(
ctx,
r.sql,
userStatsQuery,
[]any{today, today},
&stats.TotalUsers,
&stats.TodayNewUsers,
&stats.ActiveUsers,
); err != nil {
return nil, err
}
// Combined API key statistics query.
apiKeyStatsQuery := `
SELECT
COUNT(*) as total_api_keys,
COUNT(CASE WHEN status = $1 THEN 1 END) as active_api_keys
FROM api_keys
WHERE deleted_at IS NULL
`
if err := scanSingleRow(
ctx,
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
// Combined account statistics query.
accountStatsQuery := `
SELECT
COUNT(*) as total_accounts,
COUNT(CASE WHEN status = $1 AND schedulable = true THEN 1 END) as normal_accounts,
COUNT(CASE WHEN status = $2 THEN 1 END) as error_accounts,
COUNT(CASE WHEN rate_limited_at IS NOT NULL AND rate_limit_reset_at > $3 THEN 1 END) as ratelimit_accounts,
COUNT(CASE WHEN overload_until IS NOT NULL AND overload_until > $4 THEN 1 END) as overload_accounts
FROM accounts
WHERE deleted_at IS NULL
`
if err := scanSingleRow(
ctx,
r.sql,
accountStatsQuery,
[]any{service.StatusActive, service.StatusError, now, now},
&stats.TotalAccounts,
&stats.NormalAccounts,
&stats.ErrorAccounts,
&stats.RateLimitAccounts,
&stats.OverloadAccounts,
); err != nil {
return nil, err
}
// Cumulative token statistics.
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
`
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
nil,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
// Today's token statistics.
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// Performance metrics: RPM and TPM, averaged over the last 5 minutes (global).
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return &stats, nil
}
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
return r.listUsageLogsWithPagination(ctx, "WHERE account_id = $1", []any{accountID}, params)
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
return logs, nil, err
}
// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation
func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetApiKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{apiKeyID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetAccountStatsAggregated aggregates account usage statistics in SQL.
//
// Performance note:
// The previous implementation fetched every log row and computed the statistics
// in an application-level loop:
// 1. Large amounts of data had to be transferred to the application.
// 2. The application-level loop added CPU and memory overhead.
//
// The new implementation uses SQL aggregate functions:
// 1. COUNT/SUM/AVG are computed in the database.
// 2. Only a single aggregated row is returned, greatly reducing data transfer.
// 3. Database indexes can be used to speed up the aggregation.
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetModelStatsAggregated aggregates per-model usage statistics in SQL.
// Performance: aggregation happens in the database, avoiding an application-level loop.
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{modelName, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetDailyStatsAggregated aggregates a user's daily usage statistics in SQL.
// Performance: GROUP BY groups and aggregates by date in the database, avoiding application-level grouping.
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
tzName := resolveUsageStatsTimezone()
query := `
SELECT
-- Group by the application time zone to avoid day-boundary shifts caused by the database session time zone.
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY 1
ORDER BY 1
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
result = nil
}
}()
result = make([]map[string]any, 0)
for rows.Next() {
var (
date string
totalRequests int64
totalInputTokens int64
totalOutputTokens int64
totalCacheTokens int64
totalCost float64
totalActualCost float64
avgDurationMs float64
)
if err = rows.Scan(
&date,
&totalRequests,
&totalInputTokens,
&totalOutputTokens,
&totalCacheTokens,
&totalCost,
&totalActualCost,
&avgDurationMs,
); err != nil {
return nil, err
}
result = append(result, map[string]any{
"date": date,
"total_requests": totalRequests,
"total_input_tokens": totalInputTokens,
"total_output_tokens": totalOutputTokens,
"total_cache_tokens": totalCacheTokens,
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
"total_cost": totalCost,
"total_actual_cost": totalActualCost,
"average_duration_ms": avgDurationMs,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// resolveUsageStatsTimezone returns the time zone name used for SQL grouping.
// It prefers the time zone initialized by the application, then the TZ environment variable, and finally falls back to UTC.
func resolveUsageStatsTimezone() string {
tzName := timezone.Name()
if tzName != "" && tzName != "Local" {
return tzName
}
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
return envTZ
}
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
return logs, nil, err
}
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE id = $1", id)
return err
}
// GetAccountTodayStats returns today's statistics for an account.
func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today()
query := `
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
stats := &usagestats.AccountStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, today},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
}
// GetAccountWindowStats returns an account's statistics within the time window starting at startTime.
func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
query := `
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
stats := &usagestats.AccountStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint
// ModelStat represents usage statistics for a single model
type ModelStat = usagestats.ModelStat
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(`
WITH top_keys AS (
SELECT api_key_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY api_key_id
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
LIMIT $3
)
SELECT
TO_CHAR(u.created_at, '%s') as date,
u.api_key_id,
COALESCE(k.name, '') as key_name,
COUNT(*) as requests,
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens
FROM usage_logs u
LEFT JOIN api_keys k ON u.api_key_id = k.id
WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys)
AND u.created_at >= $4 AND u.created_at < $5
GROUP BY date, u.api_key_id, k.name
ORDER BY date ASC, tokens DESC
`, dateFormat)
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
}
if err = rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(`
WITH top_users AS (
SELECT user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY user_id
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
LIMIT $3
)
SELECT
TO_CHAR(u.created_at, '%s') as date,
u.user_id,
COALESCE(us.email, '') as email,
COUNT(*) as requests,
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens,
COALESCE(SUM(u.total_cost), 0) as cost,
COALESCE(SUM(u.actual_cost), 0) as actual_cost
FROM usage_logs u
LEFT JOIN users us ON u.user_id = us.id
WHERE u.user_id IN (SELECT user_id FROM top_users)
AND u.created_at >= $4 AND u.created_at < $5
GROUP BY date, u.user_id, us.email
ORDER BY date ASC, tokens DESC
`, dateFormat)
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results = make([]UserUsageTrendPoint, 0)
for rows.Next() {
var row UserUsageTrendPoint
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err = rows.Err(); err != nil {
return nil, err
}
return results, nil
}
// UserDashboardStats holds user-facing dashboard statistics.
type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats returns dashboard statistics scoped to a single user.
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
today := timezone.Today()
// API key statistics.
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
); err != nil {
return nil, err
}
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
// Cumulative token statistics.
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1
`
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{userID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
// Today's token statistics.
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{userID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// Performance metrics: RPM and TPM, averaged over the last 5 minutes, counting only this user's requests.
rpm, tpm, err := r.getPerformanceStats(ctx, userID)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
// GetUserUsageTrendByUserID returns the usage trend for a specific user.
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(`
SELECT
TO_CHAR(created_at, '%s') as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
ORDER BY date ASC
`, dateFormat)
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetUserModelStats returns per-model statistics for a specific user.
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
query := `
SELECT
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY model
ORDER BY total_tokens DESC
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
conditions := make([]string, 0, 8)
args := make([]any, 0, 8)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.ApiKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.ApiKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
whereClause := buildWhere(conditions)
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
if err != nil {
return nil, nil, err
}
if err := r.hydrateUsageLogAssociations(ctx, logs); err != nil {
return nil, nil, err
}
return logs, page, nil
}
// UsageStats represents usage statistics
type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
result := make(map[int64]*BatchUserUsageStats)
if len(userIDs) == 0 {
return result, nil
}
for _, id := range userIDs {
result[id] = &BatchUserUsageStats{UserID: id}
}
query := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE user_id = ANY($1)
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE user_id = ANY($1) AND created_at >= $2
GROUP BY user_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
var total float64
if err := rows.Scan(&userID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[userID]; ok {
stats.TodayActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
result := make(map[int64]*BatchApiKeyUsageStats)
if len(apiKeyIDs) == 0 {
return result, nil
}
for _, id := range apiKeyIDs {
result[id] = &BatchApiKeyUsageStats{ApiKeyID: id}
}
query := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
FROM usage_logs
WHERE api_key_id = ANY($1)
GROUP BY api_key_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TotalActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
FROM usage_logs
WHERE api_key_id = ANY($1) AND created_at >= $2
GROUP BY api_key_id
`
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
if err != nil {
return nil, err
}
for rows.Next() {
var apiKeyID int64
var total float64
if err := rows.Scan(&apiKeyID, &total); err != nil {
_ = rows.Close()
return nil, err
}
if stats, ok := result[apiKeyID]; ok {
stats.TodayActualCost = total
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
}
query := fmt.Sprintf(`
SELECT
TO_CHAR(created_at, '%s') as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`, dateFormat)
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
SELECT
model,
COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0) as input_tokens,
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
args := []any{startTime, endTime}
if userID > 0 {
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
args = append(args, userID)
}
if apiKeyID > 0 {
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at <= $2
`
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary = usagestats.AccountUsageSummary
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 {
daysCount = 30
}
query := `
SELECT
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
ORDER BY date ASC
`
rows, err := r.sql.QueryContext(ctx, query, accountID, startTime, endTime)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
resp = nil
}
}()
history := make([]AccountUsageHistory, 0)
for rows.Next() {
var date string
var requests int64
var tokens int64
var cost float64
var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
history = append(history, AccountUsageHistory{
Date: date,
Label: t.Format("01/02"),
Requests: requests,
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
var totalActualCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
highestCostDay = h
}
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
highestRequestDay = h
}
}
actualDaysUsed := len(history)
if actualDaysUsed == 0 {
actualDaysUsed = 1
}
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
var avgDuration float64
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
return nil, err
}
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
}
todayStr := timezone.Now().Format("2006-01-02")
for i := range history {
if history[i].Date == todayStr {
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
break
}
}
if highestCostDay != nil {
summary.HighestCostDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
Requests: highestCostDay.Requests,
}
}
if highestRequestDay != nil {
summary.HighestRequestDay = &struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
if err != nil {
models = []ModelStat{}
}
resp = &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
}
return resp, nil
}
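// listUsageLogsWithPagination runs a COUNT query plus a LIMIT/OFFSET page query for the given WHERE clause and args.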
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
var total int64
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
return nil, nil, err
}
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
}
return logs, paginationResultFromTotal(total, params), nil
}
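// queryUsageLogs executes the given usage-log SELECT and scans every row into a slice.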
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() {
// Keep the primary error; only surface a Close failure when no other error occurred.
// Also clear the return values so an incomplete result is not used by mistake.
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
logs = nil
}
}()
logs = make([]service.UsageLog, 0)
for rows.Next() {
var log *service.UsageLog
log, err = scanUsageLog(rows)
if err != nil {
return nil, err
}
logs = append(logs, *log)
}
if err = rows.Err(); err != nil {
return nil, err
}
return logs, nil
}
func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, logs []service.UsageLog) error {
// Associations are batch-loaded via ent to keep the already complex SQL from growing further.
if len(logs) == 0 {
return nil
}
ids := collectUsageLogIDs(logs)
users, err := r.loadUsers(ctx, ids.userIDs)
if err != nil {
return err
}
apiKeys, err := r.loadApiKeys(ctx, ids.apiKeyIDs)
if err != nil {
return err
}
accounts, err := r.loadAccounts(ctx, ids.accountIDs)
if err != nil {
return err
}
groups, err := r.loadGroups(ctx, ids.groupIDs)
if err != nil {
return err
}
subs, err := r.loadSubscriptions(ctx, ids.subscriptionIDs)
if err != nil {
return err
}
for i := range logs {
if user, ok := users[logs[i].UserID]; ok {
logs[i].User = user
}
if key, ok := apiKeys[logs[i].ApiKeyID]; ok {
logs[i].ApiKey = key
}
if acc, ok := accounts[logs[i].AccountID]; ok {
logs[i].Account = acc
}
if logs[i].GroupID != nil {
if group, ok := groups[*logs[i].GroupID]; ok {
logs[i].Group = group
}
}
if logs[i].SubscriptionID != nil {
if sub, ok := subs[*logs[i].SubscriptionID]; ok {
logs[i].Subscription = sub
}
}
}
return nil
}
type usageLogIDs struct {
userIDs []int64
apiKeyIDs []int64
accountIDs []int64
groupIDs []int64
subscriptionIDs []int64
}
func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
idSet := func() map[int64]struct{} { return make(map[int64]struct{}) }
userIDs := idSet()
apiKeyIDs := idSet()
accountIDs := idSet()
groupIDs := idSet()
subscriptionIDs := idSet()
for i := range logs {
userIDs[logs[i].UserID] = struct{}{}
apiKeyIDs[logs[i].ApiKeyID] = struct{}{}
accountIDs[logs[i].AccountID] = struct{}{}
if logs[i].GroupID != nil {
groupIDs[*logs[i].GroupID] = struct{}{}
}
if logs[i].SubscriptionID != nil {
subscriptionIDs[*logs[i].SubscriptionID] = struct{}{}
}
}
return usageLogIDs{
userIDs: setToSlice(userIDs),
apiKeyIDs: setToSlice(apiKeyIDs),
accountIDs: setToSlice(accountIDs),
groupIDs: setToSlice(groupIDs),
subscriptionIDs: setToSlice(subscriptionIDs),
}
}
func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[int64]*service.User, error) {
out := make(map[int64]*service.User)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.User.Query().Where(dbuser.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
for _, m := range models {
out[m.ID] = userEntityToService(m)
}
return out, nil
}
func (r *usageLogRepository) loadApiKeys(ctx context.Context, ids []int64) (map[int64]*service.ApiKey, error) {
out := make(map[int64]*service.ApiKey)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.ApiKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
for _, m := range models {
out[m.ID] = apiKeyEntityToService(m)
}
return out, nil
}
func (r *usageLogRepository) loadAccounts(ctx context.Context, ids []int64) (map[int64]*service.Account, error) {
out := make(map[int64]*service.Account)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.Account.Query().Where(dbaccount.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
for _, m := range models {
out[m.ID] = accountEntityToService(m)
}
return out, nil
}
func (r *usageLogRepository) loadGroups(ctx context.Context, ids []int64) (map[int64]*service.Group, error) {
out := make(map[int64]*service.Group)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.Group.Query().Where(dbgroup.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
for _, m := range models {
out[m.ID] = groupEntityToService(m)
}
return out, nil
}
func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) (map[int64]*service.UserSubscription, error) {
out := make(map[int64]*service.UserSubscription)
if len(ids) == 0 {
return out, nil
}
models, err := r.client.UserSubscription.Query().Where(dbusersub.IDIn(ids...)).All(ctx)
if err != nil {
return nil, err
}
for _, m := range models {
out[m.ID] = userSubscriptionEntityToService(m)
}
return out, nil
}
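// scanUsageLog maps one row, in usageLogSelectColumns order, onto a service.UsageLog.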
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var (
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
createdAt time.Time
)
if err := scanner.Scan(
&id,
&userID,
&apiKeyID,
&accountID,
&requestID,
&model,
&groupID,
&subscriptionID,
&inputTokens,
&outputTokens,
&cacheCreationTokens,
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
&inputCost,
&outputCost,
&cacheCreationCost,
&cacheReadCost,
&totalCost,
&actualCost,
&rateMultiplier,
&billingType,
&stream,
&durationMs,
&firstTokenMs,
&createdAt,
); err != nil {
return nil, err
}
log := &service.UsageLog{
ID: id,
UserID: userID,
ApiKeyID: apiKeyID,
AccountID: accountID,
Model: model,
InputTokens: inputTokens,
OutputTokens: outputTokens,
CacheCreationTokens: cacheCreationTokens,
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
CacheReadCost: cacheReadCost,
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
BillingType: int8(billingType),
Stream: stream,
CreatedAt: createdAt,
}
if requestID.Valid {
log.RequestID = requestID.String
}
if groupID.Valid {
value := groupID.Int64
log.GroupID = &value
}
if subscriptionID.Valid {
value := subscriptionID.Int64
log.SubscriptionID = &value
}
if durationMs.Valid {
value := int(durationMs.Int64)
log.DurationMs = &value
}
if firstTokenMs.Valid {
value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value
}
return log, nil
}
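// scanTrendRows scans grouped trend rows into TrendDataPoint values.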
func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
results := make([]TrendDataPoint, 0)
for rows.Next() {
var row TrendDataPoint
if err := rows.Scan(
&row.Date,
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.CacheTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
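// scanModelStatsRows scans grouped per-model rows into ModelStat values.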
func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
results := make([]ModelStat, 0)
for rows.Next() {
var row ModelStat
if err := rows.Scan(
&row.Model,
&row.Requests,
&row.InputTokens,
&row.OutputTokens,
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return nil, err
}
return results, nil
}
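// buildWhere joins the conditions with AND into a WHERE clause, or returns an empty string when there are none.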
func buildWhere(conditions []string) string {
if len(conditions) == 0 {
return ""
}
return "WHERE " + strings.Join(conditions, " AND ")
}
func nullInt64(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: *v, Valid: true}
}
func nullInt(v *int) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
func setToSlice(set map[int64]struct{}) []int64 {
out := make([]int64, 0, len(set))
for id := range set {
out = append(out, id)
}
return out
}