2203 lines
64 KiB
Go
2203 lines
64 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
|
||
dbapikey "github.com/Wei-Shaw/sub2api/ent/apikey"
|
||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/lib/pq"
|
||
)
|
||
|
||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||
|
||
type usageLogRepository struct {
|
||
client *dbent.Client
|
||
sql sqlExecutor
|
||
}
|
||
|
||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||
return newUsageLogRepositoryWithSQL(client, sqlDB)
|
||
}
|
||
|
||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||
return &usageLogRepository{client: client, sql: sqlq}
|
||
}
|
||
|
||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||
func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64, err error) {
|
||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as request_count,
|
||
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
|
||
FROM usage_logs
|
||
WHERE created_at >= $1`
|
||
args := []any{fiveMinutesAgo}
|
||
if userID > 0 {
|
||
query += " AND user_id = $2"
|
||
args = append(args, userID)
|
||
}
|
||
|
||
var requestCount int64
|
||
var tokenCount int64
|
||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||
return 0, 0, err
|
||
}
|
||
return requestCount / 5, tokenCount / 5, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||
if log == nil {
|
||
return false, nil
|
||
}
|
||
|
||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||
sqlq := r.sql
|
||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||
sqlq = tx.Client()
|
||
}
|
||
|
||
createdAt := log.CreatedAt
|
||
if createdAt.IsZero() {
|
||
createdAt = time.Now()
|
||
}
|
||
|
||
requestID := strings.TrimSpace(log.RequestID)
|
||
log.RequestID = requestID
|
||
|
||
rateMultiplier := log.RateMultiplier
|
||
|
||
query := `
|
||
INSERT INTO usage_logs (
|
||
user_id,
|
||
api_key_id,
|
||
account_id,
|
||
request_id,
|
||
model,
|
||
group_id,
|
||
subscription_id,
|
||
input_tokens,
|
||
output_tokens,
|
||
cache_creation_tokens,
|
||
cache_read_tokens,
|
||
cache_creation_5m_tokens,
|
||
cache_creation_1h_tokens,
|
||
input_cost,
|
||
output_cost,
|
||
cache_creation_cost,
|
||
cache_read_cost,
|
||
total_cost,
|
||
actual_cost,
|
||
rate_multiplier,
|
||
billing_type,
|
||
stream,
|
||
duration_ms,
|
||
first_token_ms,
|
||
user_agent,
|
||
ip_address,
|
||
image_count,
|
||
image_size,
|
||
created_at
|
||
) VALUES (
|
||
$1, $2, $3, $4, $5,
|
||
$6, $7,
|
||
$8, $9, $10, $11,
|
||
$12, $13,
|
||
$14, $15, $16, $17, $18, $19,
|
||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
|
||
)
|
||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||
RETURNING id, created_at
|
||
`
|
||
|
||
groupID := nullInt64(log.GroupID)
|
||
subscriptionID := nullInt64(log.SubscriptionID)
|
||
duration := nullInt(log.DurationMs)
|
||
firstToken := nullInt(log.FirstTokenMs)
|
||
userAgent := nullString(log.UserAgent)
|
||
ipAddress := nullString(log.IPAddress)
|
||
imageSize := nullString(log.ImageSize)
|
||
|
||
var requestIDArg any
|
||
if requestID != "" {
|
||
requestIDArg = requestID
|
||
}
|
||
|
||
args := []any{
|
||
log.UserID,
|
||
log.APIKeyID,
|
||
log.AccountID,
|
||
requestIDArg,
|
||
log.Model,
|
||
groupID,
|
||
subscriptionID,
|
||
log.InputTokens,
|
||
log.OutputTokens,
|
||
log.CacheCreationTokens,
|
||
log.CacheReadTokens,
|
||
log.CacheCreation5mTokens,
|
||
log.CacheCreation1hTokens,
|
||
log.InputCost,
|
||
log.OutputCost,
|
||
log.CacheCreationCost,
|
||
log.CacheReadCost,
|
||
log.TotalCost,
|
||
log.ActualCost,
|
||
rateMultiplier,
|
||
log.BillingType,
|
||
log.Stream,
|
||
duration,
|
||
firstToken,
|
||
userAgent,
|
||
ipAddress,
|
||
log.ImageCount,
|
||
imageSize,
|
||
createdAt,
|
||
}
|
||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||
return false, err
|
||
}
|
||
log.RateMultiplier = rateMultiplier
|
||
return false, nil
|
||
} else {
|
||
return false, err
|
||
}
|
||
}
|
||
log.RateMultiplier = rateMultiplier
|
||
return true, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
log = nil
|
||
}
|
||
}()
|
||
if !rows.Next() {
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return nil, service.ErrUsageLogNotFound
|
||
}
|
||
log, err = scanUsageLog(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return log, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||
}
|
||
|
||
// UserStats 用户使用统计
|
||
type UserStats struct {
|
||
TotalRequests int64 `json:"total_requests"`
|
||
TotalTokens int64 `json:"total_tokens"`
|
||
TotalCost float64 `json:"total_cost"`
|
||
InputTokens int64 `json:"input_tokens"`
|
||
OutputTokens int64 `json:"output_tokens"`
|
||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||
}
|
||
|
||
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(actual_cost), 0) as total_cost,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
stats := &UserStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{userID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalTokens,
|
||
&stats.TotalCost,
|
||
&stats.InputTokens,
|
||
&stats.OutputTokens,
|
||
&stats.CacheReadTokens,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// DashboardStats 仪表盘统计
|
||
type DashboardStats = usagestats.DashboardStats
|
||
|
||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||
stats := &DashboardStats{}
|
||
now := time.Now().UTC()
|
||
todayUTC := truncateToDayUTC(now)
|
||
|
||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||
startUTC := start.UTC()
|
||
endUTC := end.UTC()
|
||
if !endUTC.After(startUTC) {
|
||
return nil, errors.New("统计时间范围无效")
|
||
}
|
||
|
||
stats := &DashboardStats{}
|
||
now := time.Now().UTC()
|
||
todayUTC := truncateToDayUTC(now)
|
||
|
||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||
userStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_users,
|
||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||
FROM users
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
userStatsQuery,
|
||
[]any{todayUTC},
|
||
&stats.TotalUsers,
|
||
&stats.TodayNewUsers,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
apiKeyStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_api_keys,
|
||
COUNT(CASE WHEN status = $1 THEN 1 END) as active_api_keys
|
||
FROM api_keys
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
apiKeyStatsQuery,
|
||
[]any{service.StatusActive},
|
||
&stats.TotalAPIKeys,
|
||
&stats.ActiveAPIKeys,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
accountStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_accounts,
|
||
COUNT(CASE WHEN status = $1 AND schedulable = true THEN 1 END) as normal_accounts,
|
||
COUNT(CASE WHEN status = $2 THEN 1 END) as error_accounts,
|
||
COUNT(CASE WHEN rate_limited_at IS NOT NULL AND rate_limit_reset_at > $3 THEN 1 END) as ratelimit_accounts,
|
||
COUNT(CASE WHEN overload_until IS NOT NULL AND overload_until > $4 THEN 1 END) as overload_accounts
|
||
FROM accounts
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
accountStatsQuery,
|
||
[]any{service.StatusActive, service.StatusError, now, now},
|
||
&stats.TotalAccounts,
|
||
&stats.NormalAccounts,
|
||
&stats.ErrorAccounts,
|
||
&stats.RateLimitAccounts,
|
||
&stats.OverloadAccounts,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COALESCE(SUM(total_requests), 0) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||
FROM usage_dashboard_daily
|
||
`
|
||
var totalDurationMs int64
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
nil,
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&totalDurationMs,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
if stats.TotalRequests > 0 {
|
||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||
}
|
||
|
||
todayStatsQuery := `
|
||
SELECT
|
||
total_requests as today_requests,
|
||
input_tokens as today_input_tokens,
|
||
output_tokens as today_output_tokens,
|
||
cache_creation_tokens as today_cache_creation_tokens,
|
||
cache_read_tokens as today_cache_read_tokens,
|
||
total_cost as today_cost,
|
||
actual_cost as today_actual_cost,
|
||
active_users as active_users
|
||
FROM usage_dashboard_daily
|
||
WHERE bucket_date = $1::date
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{todayUTC},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
&stats.ActiveUsers,
|
||
); err != nil {
|
||
if err != sql.ErrNoRows {
|
||
return err
|
||
}
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
hourlyActiveQuery := `
|
||
SELECT active_users
|
||
FROM usage_dashboard_hourly
|
||
WHERE bucket_start = $1
|
||
`
|
||
hourStart := now.UTC().Truncate(time.Hour)
|
||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||
if err != sql.ErrNoRows {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
var totalDurationMs int64
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
[]any{startUTC, endUTC},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&totalDurationMs,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
if stats.TotalRequests > 0 {
|
||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||
}
|
||
|
||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||
todayStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as today_requests,
|
||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{todayUTC, todayEnd},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
activeUsersQuery := `
|
||
SELECT COUNT(DISTINCT user_id) as active_users
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||
return err
|
||
}
|
||
|
||
hourStart := now.UTC().Truncate(time.Hour)
|
||
hourEnd := hourStart.Add(time.Hour)
|
||
hourlyActiveQuery := `
|
||
SELECT COUNT(DISTINCT user_id) as active_users
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE account_id = $1", []any{accountID}, params)
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation
|
||
func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{userID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{apiKeyID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
|
||
//
|
||
// 性能优化说明:
|
||
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
|
||
// 1. 需要传输大量数据到应用层
|
||
// 2. 应用层循环计算增加 CPU 和内存开销
|
||
//
|
||
// 新实现使用 SQL 聚合函数:
|
||
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
|
||
// 2. 只返回单行聚合结果,大幅减少数据传输量
|
||
// 3. 利用数据库索引优化聚合查询性能
|
||
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
|
||
// 性能优化:数据库层聚合计算,避免应用层循环统计
|
||
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE model = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{modelName, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
|
||
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
|
||
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
|
||
tzName := resolveUsageStatsTimezone()
|
||
query := `
|
||
SELECT
|
||
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
|
||
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY 1
|
||
ORDER BY 1
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
result = nil
|
||
}
|
||
}()
|
||
|
||
result = make([]map[string]any, 0)
|
||
for rows.Next() {
|
||
var (
|
||
date string
|
||
totalRequests int64
|
||
totalInputTokens int64
|
||
totalOutputTokens int64
|
||
totalCacheTokens int64
|
||
totalCost float64
|
||
totalActualCost float64
|
||
avgDurationMs float64
|
||
)
|
||
if err = rows.Scan(
|
||
&date,
|
||
&totalRequests,
|
||
&totalInputTokens,
|
||
&totalOutputTokens,
|
||
&totalCacheTokens,
|
||
&totalCost,
|
||
&totalActualCost,
|
||
&avgDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
result = append(result, map[string]any{
|
||
"date": date,
|
||
"total_requests": totalRequests,
|
||
"total_input_tokens": totalInputTokens,
|
||
"total_output_tokens": totalOutputTokens,
|
||
"total_cache_tokens": totalCacheTokens,
|
||
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
|
||
"total_cost": totalCost,
|
||
"total_actual_cost": totalActualCost,
|
||
"average_duration_ms": avgDurationMs,
|
||
})
|
||
}
|
||
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
|
||
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
|
||
func resolveUsageStatsTimezone() string {
|
||
tzName := timezone.Name()
|
||
if tzName != "" && tzName != "Local" {
|
||
return tzName
|
||
}
|
||
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
|
||
return envTZ
|
||
}
|
||
return "UTC"
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
|
||
_, err := r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE id = $1", id)
|
||
return err
|
||
}
|
||
|
||
// GetAccountTodayStats 获取账号今日统计
|
||
func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||
today := timezone.Today()
|
||
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(actual_cost), 0) as cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2
|
||
`
|
||
|
||
stats := &usagestats.AccountStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, today},
|
||
&stats.Requests,
|
||
&stats.Tokens,
|
||
&stats.Cost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||
func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(actual_cost), 0) as cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2
|
||
`
|
||
|
||
stats := &usagestats.AccountStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, startTime},
|
||
&stats.Requests,
|
||
&stats.Tokens,
|
||
&stats.Cost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// TrendDataPoint represents a single point in trend data
|
||
type TrendDataPoint = usagestats.TrendDataPoint
|
||
|
||
// ModelStat represents usage statistics for a single model
|
||
type ModelStat = usagestats.ModelStat
|
||
|
||
// UserUsageTrendPoint represents user usage trend data point
|
||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||
|
||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||
|
||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
WITH top_keys AS (
|
||
SELECT api_key_id
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
GROUP BY api_key_id
|
||
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
|
||
LIMIT $3
|
||
)
|
||
SELECT
|
||
TO_CHAR(u.created_at, '%s') as date,
|
||
u.api_key_id,
|
||
COALESCE(k.name, '') as key_name,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens
|
||
FROM usage_logs u
|
||
LEFT JOIN api_keys k ON u.api_key_id = k.id
|
||
WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys)
|
||
AND u.created_at >= $4 AND u.created_at < $5
|
||
GROUP BY date, u.api_key_id, k.name
|
||
ORDER BY date ASC, tokens DESC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results = make([]APIKeyUsageTrendPoint, 0)
|
||
for rows.Next() {
|
||
var row APIKeyUsageTrendPoint
|
||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
WITH top_users AS (
|
||
SELECT user_id
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
GROUP BY user_id
|
||
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
|
||
LIMIT $3
|
||
)
|
||
SELECT
|
||
TO_CHAR(u.created_at, '%s') as date,
|
||
u.user_id,
|
||
COALESCE(us.email, '') as email,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(u.total_cost), 0) as cost,
|
||
COALESCE(SUM(u.actual_cost), 0) as actual_cost
|
||
FROM usage_logs u
|
||
LEFT JOIN users us ON u.user_id = us.id
|
||
WHERE u.user_id IN (SELECT user_id FROM top_users)
|
||
AND u.created_at >= $4 AND u.created_at < $5
|
||
GROUP BY date, u.user_id, us.email
|
||
ORDER BY date ASC, tokens DESC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results = make([]UserUsageTrendPoint, 0)
|
||
for rows.Next() {
|
||
var row UserUsageTrendPoint
|
||
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// UserDashboardStats 用户仪表盘统计
|
||
type UserDashboardStats = usagestats.UserDashboardStats
|
||
|
||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||
stats := &UserDashboardStats{}
|
||
today := timezone.Today()
|
||
|
||
// API Key 统计
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||
[]any{userID},
|
||
&stats.TotalAPIKeys,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||
[]any{userID, service.StatusActive},
|
||
&stats.ActiveAPIKeys,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 累计 Token 统计
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
[]any{userID},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
|
||
// 今日 Token 统计
|
||
todayStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as today_requests,
|
||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{userID, today},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
// 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求)
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
TO_CHAR(created_at, '%s') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY date
|
||
ORDER BY date ASC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanTrendRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetUserModelStats 获取指定用户的模型统计
|
||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
|
||
query := `
|
||
SELECT
|
||
model,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY model
|
||
ORDER BY total_tokens DESC
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanModelStatsRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// UsageLogFilters represents filters for usage log queries
|
||
type UsageLogFilters = usagestats.UsageLogFilters
|
||
|
||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
conditions := make([]string, 0, 8)
|
||
args := make([]any, 0, 8)
|
||
|
||
if filters.UserID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||
args = append(args, filters.UserID)
|
||
}
|
||
if filters.APIKeyID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||
args = append(args, filters.APIKeyID)
|
||
}
|
||
if filters.AccountID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||
args = append(args, filters.AccountID)
|
||
}
|
||
if filters.GroupID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||
args = append(args, filters.GroupID)
|
||
}
|
||
if filters.Model != "" {
|
||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||
args = append(args, filters.Model)
|
||
}
|
||
if filters.Stream != nil {
|
||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||
args = append(args, *filters.Stream)
|
||
}
|
||
if filters.BillingType != nil {
|
||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||
args = append(args, int16(*filters.BillingType))
|
||
}
|
||
if filters.StartTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||
args = append(args, *filters.StartTime)
|
||
}
|
||
if filters.EndTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||
args = append(args, *filters.EndTime)
|
||
}
|
||
|
||
whereClause := buildWhere(conditions)
|
||
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
if err := r.hydrateUsageLogAssociations(ctx, logs); err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return logs, page, nil
|
||
}
|
||
|
||
// UsageStats represents usage statistics
|
||
type UsageStats = usagestats.UsageStats
|
||
|
||
// BatchUserUsageStats represents usage stats for a single user
|
||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||
|
||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||
result := make(map[int64]*BatchUserUsageStats)
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
for _, id := range userIDs {
|
||
result[id] = &BatchUserUsageStats{UserID: id}
|
||
}
|
||
|
||
query := `
|
||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||
FROM usage_logs
|
||
WHERE user_id = ANY($1)
|
||
GROUP BY user_id
|
||
`
|
||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var userID int64
|
||
var total float64
|
||
if err := rows.Scan(&userID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[userID]; ok {
|
||
stats.TotalActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
today := timezone.Today()
|
||
todayQuery := `
|
||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||
FROM usage_logs
|
||
WHERE user_id = ANY($1) AND created_at >= $2
|
||
GROUP BY user_id
|
||
`
|
||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var userID int64
|
||
var total float64
|
||
if err := rows.Scan(&userID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[userID]; ok {
|
||
stats.TodayActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||
|
||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||
if len(apiKeyIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
for _, id := range apiKeyIDs {
|
||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||
}
|
||
|
||
query := `
|
||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||
FROM usage_logs
|
||
WHERE api_key_id = ANY($1)
|
||
GROUP BY api_key_id
|
||
`
|
||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var apiKeyID int64
|
||
var total float64
|
||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[apiKeyID]; ok {
|
||
stats.TotalActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
today := timezone.Today()
|
||
todayQuery := `
|
||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||
FROM usage_logs
|
||
WHERE api_key_id = ANY($1) AND created_at >= $2
|
||
GROUP BY api_key_id
|
||
`
|
||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var apiKeyID int64
|
||
var total float64
|
||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[apiKeyID]; ok {
|
||
stats.TodayActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
TO_CHAR(created_at, '%s') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`, dateFormat)
|
||
|
||
args := []any{startTime, endTime}
|
||
if userID > 0 {
|
||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||
args = append(args, userID)
|
||
}
|
||
if apiKeyID > 0 {
|
||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||
args = append(args, apiKeyID)
|
||
}
|
||
query += " GROUP BY date ORDER BY date ASC"
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanTrendRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
||
query := `
|
||
SELECT
|
||
model,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
|
||
args := []any{startTime, endTime}
|
||
if userID > 0 {
|
||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||
args = append(args, userID)
|
||
}
|
||
if apiKeyID > 0 {
|
||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||
args = append(args, apiKeyID)
|
||
}
|
||
if accountID > 0 {
|
||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||
args = append(args, accountID)
|
||
}
|
||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanModelStatsRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetGlobalStats gets usage statistics for all users within a time range
|
||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at <= $2
|
||
`
|
||
|
||
stats := &UsageStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return stats, nil
|
||
}
|
||
|
||
// GetStatsWithFilters gets usage statistics with optional filters
|
||
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
|
||
conditions := make([]string, 0, 9)
|
||
args := make([]any, 0, 9)
|
||
|
||
if filters.UserID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||
args = append(args, filters.UserID)
|
||
}
|
||
if filters.APIKeyID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||
args = append(args, filters.APIKeyID)
|
||
}
|
||
if filters.AccountID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||
args = append(args, filters.AccountID)
|
||
}
|
||
if filters.GroupID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||
args = append(args, filters.GroupID)
|
||
}
|
||
if filters.Model != "" {
|
||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||
args = append(args, filters.Model)
|
||
}
|
||
if filters.Stream != nil {
|
||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||
args = append(args, *filters.Stream)
|
||
}
|
||
if filters.BillingType != nil {
|
||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||
args = append(args, int16(*filters.BillingType))
|
||
}
|
||
if filters.StartTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||
args = append(args, *filters.StartTime)
|
||
}
|
||
if filters.EndTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||
args = append(args, *filters.EndTime)
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
%s
|
||
`, buildWhere(conditions))
|
||
|
||
stats := &UsageStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
args,
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return stats, nil
|
||
}
|
||
|
||
// AccountUsageHistory represents daily usage history for an account
|
||
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||
|
||
// AccountUsageSummary represents summary statistics for an account
|
||
type AccountUsageSummary = usagestats.AccountUsageSummary
|
||
|
||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||
|
||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||
if daysCount <= 0 {
|
||
daysCount = 30
|
||
}
|
||
|
||
query := `
|
||
SELECT
|
||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY date
|
||
ORDER BY date ASC
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, accountID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
resp = nil
|
||
}
|
||
}()
|
||
|
||
history := make([]AccountUsageHistory, 0)
|
||
for rows.Next() {
|
||
var date string
|
||
var requests int64
|
||
var tokens int64
|
||
var cost float64
|
||
var actualCost float64
|
||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
||
return nil, err
|
||
}
|
||
t, _ := time.Parse("2006-01-02", date)
|
||
history = append(history, AccountUsageHistory{
|
||
Date: date,
|
||
Label: t.Format("01/02"),
|
||
Requests: requests,
|
||
Tokens: tokens,
|
||
Cost: cost,
|
||
ActualCost: actualCost,
|
||
})
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var totalActualCost, totalStandardCost float64
|
||
var totalRequests, totalTokens int64
|
||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||
|
||
for i := range history {
|
||
h := &history[i]
|
||
totalActualCost += h.ActualCost
|
||
totalStandardCost += h.Cost
|
||
totalRequests += h.Requests
|
||
totalTokens += h.Tokens
|
||
|
||
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
|
||
highestCostDay = h
|
||
}
|
||
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
|
||
highestRequestDay = h
|
||
}
|
||
}
|
||
|
||
actualDaysUsed := len(history)
|
||
if actualDaysUsed == 0 {
|
||
actualDaysUsed = 1
|
||
}
|
||
|
||
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
||
var avgDuration float64
|
||
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
summary := AccountUsageSummary{
|
||
Days: daysCount,
|
||
ActualDaysUsed: actualDaysUsed,
|
||
TotalCost: totalActualCost,
|
||
TotalStandardCost: totalStandardCost,
|
||
TotalRequests: totalRequests,
|
||
TotalTokens: totalTokens,
|
||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||
AvgDurationMs: avgDuration,
|
||
}
|
||
|
||
todayStr := timezone.Now().Format("2006-01-02")
|
||
for i := range history {
|
||
if history[i].Date == todayStr {
|
||
summary.Today = &struct {
|
||
Date string `json:"date"`
|
||
Cost float64 `json:"cost"`
|
||
Requests int64 `json:"requests"`
|
||
Tokens int64 `json:"tokens"`
|
||
}{
|
||
Date: history[i].Date,
|
||
Cost: history[i].ActualCost,
|
||
Requests: history[i].Requests,
|
||
Tokens: history[i].Tokens,
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
if highestCostDay != nil {
|
||
summary.HighestCostDay = &struct {
|
||
Date string `json:"date"`
|
||
Label string `json:"label"`
|
||
Cost float64 `json:"cost"`
|
||
Requests int64 `json:"requests"`
|
||
}{
|
||
Date: highestCostDay.Date,
|
||
Label: highestCostDay.Label,
|
||
Cost: highestCostDay.ActualCost,
|
||
Requests: highestCostDay.Requests,
|
||
}
|
||
}
|
||
|
||
if highestRequestDay != nil {
|
||
summary.HighestRequestDay = &struct {
|
||
Date string `json:"date"`
|
||
Label string `json:"label"`
|
||
Requests int64 `json:"requests"`
|
||
Cost float64 `json:"cost"`
|
||
}{
|
||
Date: highestRequestDay.Date,
|
||
Label: highestRequestDay.Label,
|
||
Requests: highestRequestDay.Requests,
|
||
Cost: highestRequestDay.ActualCost,
|
||
}
|
||
}
|
||
|
||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
||
if err != nil {
|
||
models = []ModelStat{}
|
||
}
|
||
|
||
resp = &AccountUsageStatsResponse{
|
||
History: history,
|
||
Summary: summary,
|
||
Models: models,
|
||
}
|
||
return resp, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
||
var total int64
|
||
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
limitPos := len(args) + 1
|
||
offsetPos := len(args) + 2
|
||
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
|
||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return logs, paginationResultFromTotal(total, params), nil
|
||
}
|
||
|
||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
logs = nil
|
||
}
|
||
}()
|
||
|
||
logs = make([]service.UsageLog, 0)
|
||
for rows.Next() {
|
||
var log *service.UsageLog
|
||
log, err = scanUsageLog(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
logs = append(logs, *log)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return logs, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, logs []service.UsageLog) error {
|
||
// 关联数据使用 Ent 批量加载,避免把复杂 SQL 继续膨胀。
|
||
if len(logs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
ids := collectUsageLogIDs(logs)
|
||
users, err := r.loadUsers(ctx, ids.userIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
accounts, err := r.loadAccounts(ctx, ids.accountIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
groups, err := r.loadGroups(ctx, ids.groupIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
subs, err := r.loadSubscriptions(ctx, ids.subscriptionIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
for i := range logs {
|
||
if user, ok := users[logs[i].UserID]; ok {
|
||
logs[i].User = user
|
||
}
|
||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||
logs[i].APIKey = key
|
||
}
|
||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||
logs[i].Account = acc
|
||
}
|
||
if logs[i].GroupID != nil {
|
||
if group, ok := groups[*logs[i].GroupID]; ok {
|
||
logs[i].Group = group
|
||
}
|
||
}
|
||
if logs[i].SubscriptionID != nil {
|
||
if sub, ok := subs[*logs[i].SubscriptionID]; ok {
|
||
logs[i].Subscription = sub
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
type usageLogIDs struct {
|
||
userIDs []int64
|
||
apiKeyIDs []int64
|
||
accountIDs []int64
|
||
groupIDs []int64
|
||
subscriptionIDs []int64
|
||
}
|
||
|
||
func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||
idSet := func() map[int64]struct{} { return make(map[int64]struct{}) }
|
||
|
||
userIDs := idSet()
|
||
apiKeyIDs := idSet()
|
||
accountIDs := idSet()
|
||
groupIDs := idSet()
|
||
subscriptionIDs := idSet()
|
||
|
||
for i := range logs {
|
||
userIDs[logs[i].UserID] = struct{}{}
|
||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||
accountIDs[logs[i].AccountID] = struct{}{}
|
||
if logs[i].GroupID != nil {
|
||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||
}
|
||
if logs[i].SubscriptionID != nil {
|
||
subscriptionIDs[*logs[i].SubscriptionID] = struct{}{}
|
||
}
|
||
}
|
||
|
||
return usageLogIDs{
|
||
userIDs: setToSlice(userIDs),
|
||
apiKeyIDs: setToSlice(apiKeyIDs),
|
||
accountIDs: setToSlice(accountIDs),
|
||
groupIDs: setToSlice(groupIDs),
|
||
subscriptionIDs: setToSlice(subscriptionIDs),
|
||
}
|
||
}
|
||
|
||
func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[int64]*service.User, error) {
|
||
out := make(map[int64]*service.User)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.User.Query().Where(dbuser.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = userEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||
out := make(map[int64]*service.APIKey)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = apiKeyEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadAccounts(ctx context.Context, ids []int64) (map[int64]*service.Account, error) {
|
||
out := make(map[int64]*service.Account)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.Account.Query().Where(dbaccount.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = accountEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadGroups(ctx context.Context, ids []int64) (map[int64]*service.Group, error) {
|
||
out := make(map[int64]*service.Group)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.Group.Query().Where(dbgroup.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = groupEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) (map[int64]*service.UserSubscription, error) {
|
||
out := make(map[int64]*service.UserSubscription)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.UserSubscription.Query().Where(dbusersub.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = userSubscriptionEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
|
||
var (
|
||
id int64
|
||
userID int64
|
||
apiKeyID int64
|
||
accountID int64
|
||
requestID sql.NullString
|
||
model string
|
||
groupID sql.NullInt64
|
||
subscriptionID sql.NullInt64
|
||
inputTokens int
|
||
outputTokens int
|
||
cacheCreationTokens int
|
||
cacheReadTokens int
|
||
cacheCreation5m int
|
||
cacheCreation1h int
|
||
inputCost float64
|
||
outputCost float64
|
||
cacheCreationCost float64
|
||
cacheReadCost float64
|
||
totalCost float64
|
||
actualCost float64
|
||
rateMultiplier float64
|
||
billingType int16
|
||
stream bool
|
||
durationMs sql.NullInt64
|
||
firstTokenMs sql.NullInt64
|
||
userAgent sql.NullString
|
||
ipAddress sql.NullString
|
||
imageCount int
|
||
imageSize sql.NullString
|
||
createdAt time.Time
|
||
)
|
||
|
||
if err := scanner.Scan(
|
||
&id,
|
||
&userID,
|
||
&apiKeyID,
|
||
&accountID,
|
||
&requestID,
|
||
&model,
|
||
&groupID,
|
||
&subscriptionID,
|
||
&inputTokens,
|
||
&outputTokens,
|
||
&cacheCreationTokens,
|
||
&cacheReadTokens,
|
||
&cacheCreation5m,
|
||
&cacheCreation1h,
|
||
&inputCost,
|
||
&outputCost,
|
||
&cacheCreationCost,
|
||
&cacheReadCost,
|
||
&totalCost,
|
||
&actualCost,
|
||
&rateMultiplier,
|
||
&billingType,
|
||
&stream,
|
||
&durationMs,
|
||
&firstTokenMs,
|
||
&userAgent,
|
||
&ipAddress,
|
||
&imageCount,
|
||
&imageSize,
|
||
&createdAt,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
log := &service.UsageLog{
|
||
ID: id,
|
||
UserID: userID,
|
||
APIKeyID: apiKeyID,
|
||
AccountID: accountID,
|
||
Model: model,
|
||
InputTokens: inputTokens,
|
||
OutputTokens: outputTokens,
|
||
CacheCreationTokens: cacheCreationTokens,
|
||
CacheReadTokens: cacheReadTokens,
|
||
CacheCreation5mTokens: cacheCreation5m,
|
||
CacheCreation1hTokens: cacheCreation1h,
|
||
InputCost: inputCost,
|
||
OutputCost: outputCost,
|
||
CacheCreationCost: cacheCreationCost,
|
||
CacheReadCost: cacheReadCost,
|
||
TotalCost: totalCost,
|
||
ActualCost: actualCost,
|
||
RateMultiplier: rateMultiplier,
|
||
BillingType: int8(billingType),
|
||
Stream: stream,
|
||
ImageCount: imageCount,
|
||
CreatedAt: createdAt,
|
||
}
|
||
|
||
if requestID.Valid {
|
||
log.RequestID = requestID.String
|
||
}
|
||
if groupID.Valid {
|
||
value := groupID.Int64
|
||
log.GroupID = &value
|
||
}
|
||
if subscriptionID.Valid {
|
||
value := subscriptionID.Int64
|
||
log.SubscriptionID = &value
|
||
}
|
||
if durationMs.Valid {
|
||
value := int(durationMs.Int64)
|
||
log.DurationMs = &value
|
||
}
|
||
if firstTokenMs.Valid {
|
||
value := int(firstTokenMs.Int64)
|
||
log.FirstTokenMs = &value
|
||
}
|
||
if userAgent.Valid {
|
||
log.UserAgent = &userAgent.String
|
||
}
|
||
if ipAddress.Valid {
|
||
log.IPAddress = &ipAddress.String
|
||
}
|
||
if imageSize.Valid {
|
||
log.ImageSize = &imageSize.String
|
||
}
|
||
|
||
return log, nil
|
||
}
|
||
|
||
func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||
results := make([]TrendDataPoint, 0)
|
||
for rows.Next() {
|
||
var row TrendDataPoint
|
||
if err := rows.Scan(
|
||
&row.Date,
|
||
&row.Requests,
|
||
&row.InputTokens,
|
||
&row.OutputTokens,
|
||
&row.CacheTokens,
|
||
&row.TotalTokens,
|
||
&row.Cost,
|
||
&row.ActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||
results := make([]ModelStat, 0)
|
||
for rows.Next() {
|
||
var row ModelStat
|
||
if err := rows.Scan(
|
||
&row.Model,
|
||
&row.Requests,
|
||
&row.InputTokens,
|
||
&row.OutputTokens,
|
||
&row.TotalTokens,
|
||
&row.Cost,
|
||
&row.ActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
func buildWhere(conditions []string) string {
|
||
if len(conditions) == 0 {
|
||
return ""
|
||
}
|
||
return "WHERE " + strings.Join(conditions, " AND ")
|
||
}
|
||
|
||
func nullInt64(v *int64) sql.NullInt64 {
|
||
if v == nil {
|
||
return sql.NullInt64{}
|
||
}
|
||
return sql.NullInt64{Int64: *v, Valid: true}
|
||
}
|
||
|
||
func nullInt(v *int) sql.NullInt64 {
|
||
if v == nil {
|
||
return sql.NullInt64{}
|
||
}
|
||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||
}
|
||
|
||
func nullString(v *string) sql.NullString {
|
||
if v == nil || *v == "" {
|
||
return sql.NullString{}
|
||
}
|
||
return sql.NullString{String: *v, Valid: true}
|
||
}
|
||
|
||
func setToSlice(set map[int64]struct{}) []int64 {
|
||
out := make([]int64, 0, len(set))
|
||
for id := range set {
|
||
out = append(out, id)
|
||
}
|
||
return out
|
||
}
|