Previously the /v1/usage endpoint aggregated usage stats (today/total tokens, cost, RPM/TPM) across all API Keys belonging to the user. This made it impossible to distinguish usage from different API Keys (e.g. balance vs subscription keys). Now the usage stats are filtered by the current request's API Key ID, so each key only sees its own usage data. The balance/remaining fields are unaffected and still reflect the user-level wallet balance. Changes: - Add GetAPIKeyDashboardStats to repository interface and implementation - Add getPerformanceStatsByAPIKey helper (also fixes TPM to include cache_creation_tokens and cache_read_tokens) - Add GetAPIKeyDashboardStats to UsageService - Update Usage handler to call GetAPIKeyDashboardStats(apiKey.ID) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2389 lines
70 KiB
Go
2389 lines
70 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||
dbaccount "github.com/Wei-Shaw/sub2api/ent/account"
|
||
dbapikey "github.com/Wei-Shaw/sub2api/ent/apikey"
|
||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
"github.com/lib/pq"
|
||
)
|
||
|
||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||
|
||
type usageLogRepository struct {
|
||
client *dbent.Client
|
||
sql sqlExecutor
|
||
}
|
||
|
||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||
return newUsageLogRepositoryWithSQL(client, sqlDB)
|
||
}
|
||
|
||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||
return &usageLogRepository{client: client, sql: sqlq}
|
||
}
|
||
|
||
// getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤)
|
||
func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64, err error) {
|
||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as request_count,
|
||
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
|
||
FROM usage_logs
|
||
WHERE created_at >= $1`
|
||
args := []any{fiveMinutesAgo}
|
||
if userID > 0 {
|
||
query += " AND user_id = $2"
|
||
args = append(args, userID)
|
||
}
|
||
|
||
var requestCount int64
|
||
var tokenCount int64
|
||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||
return 0, 0, err
|
||
}
|
||
return requestCount / 5, tokenCount / 5, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||
if log == nil {
|
||
return false, nil
|
||
}
|
||
|
||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||
sqlq := r.sql
|
||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||
sqlq = tx.Client()
|
||
}
|
||
|
||
createdAt := log.CreatedAt
|
||
if createdAt.IsZero() {
|
||
createdAt = time.Now()
|
||
}
|
||
|
||
requestID := strings.TrimSpace(log.RequestID)
|
||
log.RequestID = requestID
|
||
|
||
rateMultiplier := log.RateMultiplier
|
||
|
||
query := `
|
||
INSERT INTO usage_logs (
|
||
user_id,
|
||
api_key_id,
|
||
account_id,
|
||
request_id,
|
||
model,
|
||
group_id,
|
||
subscription_id,
|
||
input_tokens,
|
||
output_tokens,
|
||
cache_creation_tokens,
|
||
cache_read_tokens,
|
||
cache_creation_5m_tokens,
|
||
cache_creation_1h_tokens,
|
||
input_cost,
|
||
output_cost,
|
||
cache_creation_cost,
|
||
cache_read_cost,
|
||
total_cost,
|
||
actual_cost,
|
||
rate_multiplier,
|
||
account_rate_multiplier,
|
||
billing_type,
|
||
stream,
|
||
duration_ms,
|
||
first_token_ms,
|
||
user_agent,
|
||
ip_address,
|
||
image_count,
|
||
image_size,
|
||
reasoning_effort,
|
||
created_at
|
||
) VALUES (
|
||
$1, $2, $3, $4, $5,
|
||
$6, $7,
|
||
$8, $9, $10, $11,
|
||
$12, $13,
|
||
$14, $15, $16, $17, $18, $19,
|
||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
|
||
)
|
||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||
RETURNING id, created_at
|
||
`
|
||
|
||
groupID := nullInt64(log.GroupID)
|
||
subscriptionID := nullInt64(log.SubscriptionID)
|
||
duration := nullInt(log.DurationMs)
|
||
firstToken := nullInt(log.FirstTokenMs)
|
||
userAgent := nullString(log.UserAgent)
|
||
ipAddress := nullString(log.IPAddress)
|
||
imageSize := nullString(log.ImageSize)
|
||
reasoningEffort := nullString(log.ReasoningEffort)
|
||
|
||
var requestIDArg any
|
||
if requestID != "" {
|
||
requestIDArg = requestID
|
||
}
|
||
|
||
args := []any{
|
||
log.UserID,
|
||
log.APIKeyID,
|
||
log.AccountID,
|
||
requestIDArg,
|
||
log.Model,
|
||
groupID,
|
||
subscriptionID,
|
||
log.InputTokens,
|
||
log.OutputTokens,
|
||
log.CacheCreationTokens,
|
||
log.CacheReadTokens,
|
||
log.CacheCreation5mTokens,
|
||
log.CacheCreation1hTokens,
|
||
log.InputCost,
|
||
log.OutputCost,
|
||
log.CacheCreationCost,
|
||
log.CacheReadCost,
|
||
log.TotalCost,
|
||
log.ActualCost,
|
||
rateMultiplier,
|
||
log.AccountRateMultiplier,
|
||
log.BillingType,
|
||
log.Stream,
|
||
duration,
|
||
firstToken,
|
||
userAgent,
|
||
ipAddress,
|
||
log.ImageCount,
|
||
imageSize,
|
||
reasoningEffort,
|
||
createdAt,
|
||
}
|
||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||
return false, err
|
||
}
|
||
log.RateMultiplier = rateMultiplier
|
||
return false, nil
|
||
} else {
|
||
return false, err
|
||
}
|
||
}
|
||
log.RateMultiplier = rateMultiplier
|
||
return true, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
log = nil
|
||
}
|
||
}()
|
||
if !rows.Next() {
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return nil, service.ErrUsageLogNotFound
|
||
}
|
||
log, err = scanUsageLog(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return log, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params)
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params)
|
||
}
|
||
|
||
// UserStats 用户使用统计
|
||
type UserStats struct {
|
||
TotalRequests int64 `json:"total_requests"`
|
||
TotalTokens int64 `json:"total_tokens"`
|
||
TotalCost float64 `json:"total_cost"`
|
||
InputTokens int64 `json:"input_tokens"`
|
||
OutputTokens int64 `json:"output_tokens"`
|
||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||
}
|
||
|
||
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(actual_cost), 0) as total_cost,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
stats := &UserStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{userID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalTokens,
|
||
&stats.TotalCost,
|
||
&stats.InputTokens,
|
||
&stats.OutputTokens,
|
||
&stats.CacheReadTokens,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// DashboardStats 仪表盘统计
|
||
type DashboardStats = usagestats.DashboardStats
|
||
|
||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||
stats := &DashboardStats{}
|
||
now := timezone.Now()
|
||
todayStart := timezone.Today()
|
||
|
||
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||
startUTC := start.UTC()
|
||
endUTC := end.UTC()
|
||
if !endUTC.After(startUTC) {
|
||
return nil, errors.New("统计时间范围无效")
|
||
}
|
||
|
||
stats := &DashboardStats{}
|
||
now := timezone.Now()
|
||
todayStart := timezone.Today()
|
||
|
||
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||
userStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_users,
|
||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||
FROM users
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
userStatsQuery,
|
||
[]any{todayUTC},
|
||
&stats.TotalUsers,
|
||
&stats.TodayNewUsers,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
apiKeyStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_api_keys,
|
||
COUNT(CASE WHEN status = $1 THEN 1 END) as active_api_keys
|
||
FROM api_keys
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
apiKeyStatsQuery,
|
||
[]any{service.StatusActive},
|
||
&stats.TotalAPIKeys,
|
||
&stats.ActiveAPIKeys,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
accountStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_accounts,
|
||
COUNT(CASE WHEN status = $1 AND schedulable = true THEN 1 END) as normal_accounts,
|
||
COUNT(CASE WHEN status = $2 THEN 1 END) as error_accounts,
|
||
COUNT(CASE WHEN rate_limited_at IS NOT NULL AND rate_limit_reset_at > $3 THEN 1 END) as ratelimit_accounts,
|
||
COUNT(CASE WHEN overload_until IS NOT NULL AND overload_until > $4 THEN 1 END) as overload_accounts
|
||
FROM accounts
|
||
WHERE deleted_at IS NULL
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
accountStatsQuery,
|
||
[]any{service.StatusActive, service.StatusError, now, now},
|
||
&stats.TotalAccounts,
|
||
&stats.NormalAccounts,
|
||
&stats.ErrorAccounts,
|
||
&stats.RateLimitAccounts,
|
||
&stats.OverloadAccounts,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COALESCE(SUM(total_requests), 0) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||
FROM usage_dashboard_daily
|
||
`
|
||
var totalDurationMs int64
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
nil,
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&totalDurationMs,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
if stats.TotalRequests > 0 {
|
||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||
}
|
||
|
||
todayStatsQuery := `
|
||
SELECT
|
||
total_requests as today_requests,
|
||
input_tokens as today_input_tokens,
|
||
output_tokens as today_output_tokens,
|
||
cache_creation_tokens as today_cache_creation_tokens,
|
||
cache_read_tokens as today_cache_read_tokens,
|
||
total_cost as today_cost,
|
||
actual_cost as today_actual_cost,
|
||
active_users as active_users
|
||
FROM usage_dashboard_daily
|
||
WHERE bucket_date = $1::date
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{todayUTC},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
&stats.ActiveUsers,
|
||
); err != nil {
|
||
if err != sql.ErrNoRows {
|
||
return err
|
||
}
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
hourlyActiveQuery := `
|
||
SELECT active_users
|
||
FROM usage_dashboard_hourly
|
||
WHERE bucket_start = $1
|
||
`
|
||
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
|
||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||
if err != sql.ErrNoRows {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
var totalDurationMs int64
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
[]any{startUTC, endUTC},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&totalDurationMs,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
if stats.TotalRequests > 0 {
|
||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||
}
|
||
|
||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||
todayStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as today_requests,
|
||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{todayUTC, todayEnd},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
); err != nil {
|
||
return err
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
activeUsersQuery := `
|
||
SELECT COUNT(DISTINCT user_id) as active_users
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||
return err
|
||
}
|
||
|
||
hourStart := now.UTC().Truncate(time.Hour)
|
||
hourEnd := hourStart.Add(time.Hour)
|
||
hourlyActiveQuery := `
|
||
SELECT COUNT(DISTINCT user_id) as active_users
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`
|
||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
return r.listUsageLogsWithPagination(ctx, "WHERE account_id = $1", []any{accountID}, params)
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation
|
||
func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{userID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation
|
||
func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{apiKeyID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
|
||
//
|
||
// 性能优化说明:
|
||
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
|
||
// 1. 需要传输大量数据到应用层
|
||
// 2. 应用层循环计算增加 CPU 和内存开销
|
||
//
|
||
// 新实现使用 SQL 聚合函数:
|
||
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
|
||
// 2. 只返回单行聚合结果,大幅减少数据传输量
|
||
// 3. 利用数据库索引优化聚合查询性能
|
||
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
|
||
// 性能优化:数据库层聚合计算,避免应用层循环统计
|
||
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE model = $1 AND created_at >= $2 AND created_at < $3
|
||
`
|
||
|
||
var stats usagestats.UsageStats
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{modelName, startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
|
||
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
|
||
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
|
||
tzName := resolveUsageStatsTimezone()
|
||
query := `
|
||
SELECT
|
||
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
|
||
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY 1
|
||
ORDER BY 1
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
result = nil
|
||
}
|
||
}()
|
||
|
||
result = make([]map[string]any, 0)
|
||
for rows.Next() {
|
||
var (
|
||
date string
|
||
totalRequests int64
|
||
totalInputTokens int64
|
||
totalOutputTokens int64
|
||
totalCacheTokens int64
|
||
totalCost float64
|
||
totalActualCost float64
|
||
avgDurationMs float64
|
||
)
|
||
if err = rows.Scan(
|
||
&date,
|
||
&totalRequests,
|
||
&totalInputTokens,
|
||
&totalOutputTokens,
|
||
&totalCacheTokens,
|
||
&totalCost,
|
||
&totalActualCost,
|
||
&avgDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
result = append(result, map[string]any{
|
||
"date": date,
|
||
"total_requests": totalRequests,
|
||
"total_input_tokens": totalInputTokens,
|
||
"total_output_tokens": totalOutputTokens,
|
||
"total_cache_tokens": totalCacheTokens,
|
||
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
|
||
"total_cost": totalCost,
|
||
"total_actual_cost": totalActualCost,
|
||
"average_duration_ms": avgDurationMs,
|
||
})
|
||
}
|
||
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
|
||
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
|
||
func resolveUsageStatsTimezone() string {
|
||
tzName := timezone.Name()
|
||
if tzName != "" && tzName != "Local" {
|
||
return tzName
|
||
}
|
||
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
|
||
return envTZ
|
||
}
|
||
return "UTC"
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||
return logs, nil, err
|
||
}
|
||
|
||
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
|
||
_, err := r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE id = $1", id)
|
||
return err
|
||
}
|
||
|
||
// GetAccountTodayStats 获取账号今日统计
|
||
func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||
today := timezone.Today()
|
||
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2
|
||
`
|
||
|
||
stats := &usagestats.AccountStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, today},
|
||
&stats.Requests,
|
||
&stats.Tokens,
|
||
&stats.Cost,
|
||
&stats.StandardCost,
|
||
&stats.UserCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// GetAccountWindowStats 获取账号时间窗口内的统计
|
||
func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2
|
||
`
|
||
|
||
stats := &usagestats.AccountStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{accountID, startTime},
|
||
&stats.Requests,
|
||
&stats.Tokens,
|
||
&stats.Cost,
|
||
&stats.StandardCost,
|
||
&stats.UserCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
return stats, nil
|
||
}
|
||
|
||
// TrendDataPoint represents a single point in trend data
|
||
type TrendDataPoint = usagestats.TrendDataPoint
|
||
|
||
// ModelStat represents usage statistics for a single model
|
||
type ModelStat = usagestats.ModelStat
|
||
|
||
// UserUsageTrendPoint represents user usage trend data point
|
||
type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||
|
||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||
type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||
|
||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
WITH top_keys AS (
|
||
SELECT api_key_id
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
GROUP BY api_key_id
|
||
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
|
||
LIMIT $3
|
||
)
|
||
SELECT
|
||
TO_CHAR(u.created_at, '%s') as date,
|
||
u.api_key_id,
|
||
COALESCE(k.name, '') as key_name,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens
|
||
FROM usage_logs u
|
||
LEFT JOIN api_keys k ON u.api_key_id = k.id
|
||
WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys)
|
||
AND u.created_at >= $4 AND u.created_at < $5
|
||
GROUP BY date, u.api_key_id, k.name
|
||
ORDER BY date ASC, tokens DESC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results = make([]APIKeyUsageTrendPoint, 0)
|
||
for rows.Next() {
|
||
var row APIKeyUsageTrendPoint
|
||
if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
WITH top_users AS (
|
||
SELECT user_id
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
GROUP BY user_id
|
||
ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC
|
||
LIMIT $3
|
||
)
|
||
SELECT
|
||
TO_CHAR(u.created_at, '%s') as date,
|
||
u.user_id,
|
||
COALESCE(us.email, '') as email,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(u.total_cost), 0) as cost,
|
||
COALESCE(SUM(u.actual_cost), 0) as actual_cost
|
||
FROM usage_logs u
|
||
LEFT JOIN users us ON u.user_id = us.id
|
||
WHERE u.user_id IN (SELECT user_id FROM top_users)
|
||
AND u.created_at >= $4 AND u.created_at < $5
|
||
GROUP BY date, u.user_id, us.email
|
||
ORDER BY date ASC, tokens DESC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results = make([]UserUsageTrendPoint, 0)
|
||
for rows.Next() {
|
||
var row UserUsageTrendPoint
|
||
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// UserDashboardStats 用户仪表盘统计
|
||
type UserDashboardStats = usagestats.UserDashboardStats
|
||
|
||
// GetUserDashboardStats 获取用户专属的仪表盘统计
|
||
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
|
||
stats := &UserDashboardStats{}
|
||
today := timezone.Today()
|
||
|
||
// API Key 统计
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||
[]any{userID},
|
||
&stats.TotalAPIKeys,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||
[]any{userID, service.StatusActive},
|
||
&stats.ActiveAPIKeys,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 累计 Token 统计
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE user_id = $1
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
[]any{userID},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
|
||
// 今日 Token 统计
|
||
todayStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as today_requests,
|
||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{userID, today},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
// 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求)
|
||
rpm, tpm, err := r.getPerformanceStats(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值)
|
||
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
|
||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as request_count,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND api_key_id = $2`
|
||
args := []any{fiveMinutesAgo, apiKeyID}
|
||
|
||
var requestCount int64
|
||
var tokenCount int64
|
||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||
return 0, 0, err
|
||
}
|
||
return requestCount / 5, tokenCount / 5, nil
|
||
}
|
||
|
||
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
|
||
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
|
||
stats := &UserDashboardStats{}
|
||
today := timezone.Today()
|
||
|
||
// API Key 维度不需要统计 key 数量,设为 1
|
||
stats.TotalAPIKeys = 1
|
||
stats.ActiveAPIKeys = 1
|
||
|
||
// 累计 Token 统计
|
||
totalStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE api_key_id = $1
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
totalStatsQuery,
|
||
[]any{apiKeyID},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheCreationTokens,
|
||
&stats.TotalCacheReadTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||
|
||
// 今日 Token 统计
|
||
todayStatsQuery := `
|
||
SELECT
|
||
COUNT(*) as today_requests,
|
||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||
FROM usage_logs
|
||
WHERE api_key_id = $1 AND created_at >= $2
|
||
`
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
todayStatsQuery,
|
||
[]any{apiKeyID, today},
|
||
&stats.TodayRequests,
|
||
&stats.TodayInputTokens,
|
||
&stats.TodayOutputTokens,
|
||
&stats.TodayCacheCreationTokens,
|
||
&stats.TodayCacheReadTokens,
|
||
&stats.TodayCost,
|
||
&stats.TodayActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||
|
||
// 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤)
|
||
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
stats.Rpm = rpm
|
||
stats.Tpm = tpm
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
TO_CHAR(created_at, '%s') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY date
|
||
ORDER BY date ASC
|
||
`, dateFormat)
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanTrendRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetUserModelStats 获取指定用户的模型统计
|
||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
|
||
query := `
|
||
SELECT
|
||
model,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY model
|
||
ORDER BY total_tokens DESC
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanModelStatsRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// UsageLogFilters represents filters for usage log queries
|
||
type UsageLogFilters = usagestats.UsageLogFilters
|
||
|
||
// ListWithFilters lists usage logs with optional filters (for admin)
|
||
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
conditions := make([]string, 0, 8)
|
||
args := make([]any, 0, 8)
|
||
|
||
if filters.UserID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||
args = append(args, filters.UserID)
|
||
}
|
||
if filters.APIKeyID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||
args = append(args, filters.APIKeyID)
|
||
}
|
||
if filters.AccountID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||
args = append(args, filters.AccountID)
|
||
}
|
||
if filters.GroupID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||
args = append(args, filters.GroupID)
|
||
}
|
||
if filters.Model != "" {
|
||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||
args = append(args, filters.Model)
|
||
}
|
||
if filters.Stream != nil {
|
||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||
args = append(args, *filters.Stream)
|
||
}
|
||
if filters.BillingType != nil {
|
||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||
args = append(args, int16(*filters.BillingType))
|
||
}
|
||
if filters.StartTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||
args = append(args, *filters.StartTime)
|
||
}
|
||
if filters.EndTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||
args = append(args, *filters.EndTime)
|
||
}
|
||
|
||
whereClause := buildWhere(conditions)
|
||
logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
if err := r.hydrateUsageLogAssociations(ctx, logs); err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return logs, page, nil
|
||
}
|
||
|
||
// UsageStats represents usage statistics
|
||
type UsageStats = usagestats.UsageStats
|
||
|
||
// BatchUserUsageStats represents usage stats for a single user
|
||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||
|
||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||
result := make(map[int64]*BatchUserUsageStats)
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
for _, id := range userIDs {
|
||
result[id] = &BatchUserUsageStats{UserID: id}
|
||
}
|
||
|
||
query := `
|
||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||
FROM usage_logs
|
||
WHERE user_id = ANY($1)
|
||
GROUP BY user_id
|
||
`
|
||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var userID int64
|
||
var total float64
|
||
if err := rows.Scan(&userID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[userID]; ok {
|
||
stats.TotalActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
today := timezone.Today()
|
||
todayQuery := `
|
||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||
FROM usage_logs
|
||
WHERE user_id = ANY($1) AND created_at >= $2
|
||
GROUP BY user_id
|
||
`
|
||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var userID int64
|
||
var total float64
|
||
if err := rows.Scan(&userID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[userID]; ok {
|
||
stats.TodayActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||
|
||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||
if len(apiKeyIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
for _, id := range apiKeyIDs {
|
||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||
}
|
||
|
||
query := `
|
||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||
FROM usage_logs
|
||
WHERE api_key_id = ANY($1)
|
||
GROUP BY api_key_id
|
||
`
|
||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var apiKeyID int64
|
||
var total float64
|
||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[apiKeyID]; ok {
|
||
stats.TotalActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
today := timezone.Today()
|
||
todayQuery := `
|
||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost
|
||
FROM usage_logs
|
||
WHERE api_key_id = ANY($1) AND created_at >= $2
|
||
GROUP BY api_key_id
|
||
`
|
||
rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for rows.Next() {
|
||
var apiKeyID int64
|
||
var total float64
|
||
if err := rows.Scan(&apiKeyID, &total); err != nil {
|
||
_ = rows.Close()
|
||
return nil, err
|
||
}
|
||
if stats, ok := result[apiKeyID]; ok {
|
||
stats.TodayActualCost = total
|
||
}
|
||
}
|
||
if err := rows.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||
dateFormat := "YYYY-MM-DD"
|
||
if granularity == "hour" {
|
||
dateFormat = "YYYY-MM-DD HH24:00"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
TO_CHAR(created_at, '%s') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as cache_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`, dateFormat)
|
||
|
||
args := []any{startTime, endTime}
|
||
if userID > 0 {
|
||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||
args = append(args, userID)
|
||
}
|
||
if apiKeyID > 0 {
|
||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||
args = append(args, apiKeyID)
|
||
}
|
||
if accountID > 0 {
|
||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||
args = append(args, accountID)
|
||
}
|
||
if groupID > 0 {
|
||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||
args = append(args, groupID)
|
||
}
|
||
if model != "" {
|
||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||
args = append(args, model)
|
||
}
|
||
if stream != nil {
|
||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||
args = append(args, *stream)
|
||
}
|
||
if billingType != nil {
|
||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||
args = append(args, int16(*billingType))
|
||
}
|
||
query += " GROUP BY date ORDER BY date ASC"
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanTrendRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
|
||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
model,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
%s
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at < $2
|
||
`, actualCostExpr)
|
||
|
||
args := []any{startTime, endTime}
|
||
if userID > 0 {
|
||
query += fmt.Sprintf(" AND user_id = $%d", len(args)+1)
|
||
args = append(args, userID)
|
||
}
|
||
if apiKeyID > 0 {
|
||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||
args = append(args, apiKeyID)
|
||
}
|
||
if accountID > 0 {
|
||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||
args = append(args, accountID)
|
||
}
|
||
if groupID > 0 {
|
||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||
args = append(args, groupID)
|
||
}
|
||
if stream != nil {
|
||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||
args = append(args, *stream)
|
||
}
|
||
if billingType != nil {
|
||
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
|
||
args = append(args, int16(*billingType))
|
||
}
|
||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
results = nil
|
||
}
|
||
}()
|
||
|
||
results, err = scanModelStatsRows(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
// GetGlobalStats gets usage statistics for all users within a time range
|
||
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
|
||
query := `
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
WHERE created_at >= $1 AND created_at <= $2
|
||
`
|
||
|
||
stats := &UsageStats{}
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
[]any{startTime, endTime},
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return stats, nil
|
||
}
|
||
|
||
// GetStatsWithFilters gets usage statistics with optional filters
|
||
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
|
||
conditions := make([]string, 0, 9)
|
||
args := make([]any, 0, 9)
|
||
|
||
if filters.UserID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
|
||
args = append(args, filters.UserID)
|
||
}
|
||
if filters.APIKeyID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
|
||
args = append(args, filters.APIKeyID)
|
||
}
|
||
if filters.AccountID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
|
||
args = append(args, filters.AccountID)
|
||
}
|
||
if filters.GroupID > 0 {
|
||
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
|
||
args = append(args, filters.GroupID)
|
||
}
|
||
if filters.Model != "" {
|
||
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
|
||
args = append(args, filters.Model)
|
||
}
|
||
if filters.Stream != nil {
|
||
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
|
||
args = append(args, *filters.Stream)
|
||
}
|
||
if filters.BillingType != nil {
|
||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
|
||
args = append(args, int16(*filters.BillingType))
|
||
}
|
||
if filters.StartTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
|
||
args = append(args, *filters.StartTime)
|
||
}
|
||
if filters.EndTime != nil {
|
||
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
|
||
args = append(args, *filters.EndTime)
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT
|
||
COUNT(*) as total_requests,
|
||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||
FROM usage_logs
|
||
%s
|
||
`, buildWhere(conditions))
|
||
|
||
stats := &UsageStats{}
|
||
var totalAccountCost float64
|
||
if err := scanSingleRow(
|
||
ctx,
|
||
r.sql,
|
||
query,
|
||
args,
|
||
&stats.TotalRequests,
|
||
&stats.TotalInputTokens,
|
||
&stats.TotalOutputTokens,
|
||
&stats.TotalCacheTokens,
|
||
&stats.TotalCost,
|
||
&stats.TotalActualCost,
|
||
&totalAccountCost,
|
||
&stats.AverageDurationMs,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
if filters.AccountID > 0 {
|
||
stats.TotalAccountCost = &totalAccountCost
|
||
}
|
||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||
return stats, nil
|
||
}
|
||
|
||
// AccountUsageHistory represents daily usage history for an account
|
||
type AccountUsageHistory = usagestats.AccountUsageHistory
|
||
|
||
// AccountUsageSummary represents summary statistics for an account
|
||
type AccountUsageSummary = usagestats.AccountUsageSummary
|
||
|
||
// AccountUsageStatsResponse represents the full usage statistics response for an account
|
||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||
|
||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||
if daysCount <= 0 {
|
||
daysCount = 30
|
||
}
|
||
|
||
query := `
|
||
SELECT
|
||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||
COUNT(*) as requests,
|
||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||
COALESCE(SUM(total_cost), 0) as cost,
|
||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||
FROM usage_logs
|
||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||
GROUP BY date
|
||
ORDER BY date ASC
|
||
`
|
||
|
||
rows, err := r.sql.QueryContext(ctx, query, accountID, startTime, endTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
resp = nil
|
||
}
|
||
}()
|
||
|
||
history := make([]AccountUsageHistory, 0)
|
||
for rows.Next() {
|
||
var date string
|
||
var requests int64
|
||
var tokens int64
|
||
var cost float64
|
||
var actualCost float64
|
||
var userCost float64
|
||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
|
||
return nil, err
|
||
}
|
||
t, _ := time.Parse("2006-01-02", date)
|
||
history = append(history, AccountUsageHistory{
|
||
Date: date,
|
||
Label: t.Format("01/02"),
|
||
Requests: requests,
|
||
Tokens: tokens,
|
||
Cost: cost,
|
||
ActualCost: actualCost,
|
||
UserCost: userCost,
|
||
})
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var totalAccountCost, totalUserCost, totalStandardCost float64
|
||
var totalRequests, totalTokens int64
|
||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||
|
||
for i := range history {
|
||
h := &history[i]
|
||
totalAccountCost += h.ActualCost
|
||
totalUserCost += h.UserCost
|
||
totalStandardCost += h.Cost
|
||
totalRequests += h.Requests
|
||
totalTokens += h.Tokens
|
||
|
||
if highestCostDay == nil || h.ActualCost > highestCostDay.ActualCost {
|
||
highestCostDay = h
|
||
}
|
||
if highestRequestDay == nil || h.Requests > highestRequestDay.Requests {
|
||
highestRequestDay = h
|
||
}
|
||
}
|
||
|
||
actualDaysUsed := len(history)
|
||
if actualDaysUsed == 0 {
|
||
actualDaysUsed = 1
|
||
}
|
||
|
||
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
||
var avgDuration float64
|
||
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
summary := AccountUsageSummary{
|
||
Days: daysCount,
|
||
ActualDaysUsed: actualDaysUsed,
|
||
TotalCost: totalAccountCost,
|
||
TotalUserCost: totalUserCost,
|
||
TotalStandardCost: totalStandardCost,
|
||
TotalRequests: totalRequests,
|
||
TotalTokens: totalTokens,
|
||
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
|
||
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
|
||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||
AvgDurationMs: avgDuration,
|
||
}
|
||
|
||
todayStr := timezone.Now().Format("2006-01-02")
|
||
for i := range history {
|
||
if history[i].Date == todayStr {
|
||
summary.Today = &struct {
|
||
Date string `json:"date"`
|
||
Cost float64 `json:"cost"`
|
||
UserCost float64 `json:"user_cost"`
|
||
Requests int64 `json:"requests"`
|
||
Tokens int64 `json:"tokens"`
|
||
}{
|
||
Date: history[i].Date,
|
||
Cost: history[i].ActualCost,
|
||
UserCost: history[i].UserCost,
|
||
Requests: history[i].Requests,
|
||
Tokens: history[i].Tokens,
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
if highestCostDay != nil {
|
||
summary.HighestCostDay = &struct {
|
||
Date string `json:"date"`
|
||
Label string `json:"label"`
|
||
Cost float64 `json:"cost"`
|
||
UserCost float64 `json:"user_cost"`
|
||
Requests int64 `json:"requests"`
|
||
}{
|
||
Date: highestCostDay.Date,
|
||
Label: highestCostDay.Label,
|
||
Cost: highestCostDay.ActualCost,
|
||
UserCost: highestCostDay.UserCost,
|
||
Requests: highestCostDay.Requests,
|
||
}
|
||
}
|
||
|
||
if highestRequestDay != nil {
|
||
summary.HighestRequestDay = &struct {
|
||
Date string `json:"date"`
|
||
Label string `json:"label"`
|
||
Requests int64 `json:"requests"`
|
||
Cost float64 `json:"cost"`
|
||
UserCost float64 `json:"user_cost"`
|
||
}{
|
||
Date: highestRequestDay.Date,
|
||
Label: highestRequestDay.Label,
|
||
Requests: highestRequestDay.Requests,
|
||
Cost: highestRequestDay.ActualCost,
|
||
UserCost: highestRequestDay.UserCost,
|
||
}
|
||
}
|
||
|
||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
|
||
if err != nil {
|
||
models = []ModelStat{}
|
||
}
|
||
|
||
resp = &AccountUsageStatsResponse{
|
||
History: history,
|
||
Summary: summary,
|
||
Models: models,
|
||
}
|
||
return resp, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
||
var total int64
|
||
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||
return nil, nil, err
|
||
}
|
||
|
||
limitPos := len(args) + 1
|
||
offsetPos := len(args) + 2
|
||
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
|
||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||
if err != nil {
|
||
return nil, nil, err
|
||
}
|
||
return logs, paginationResultFromTotal(total, params), nil
|
||
}
|
||
|
||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() {
|
||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||
// 同时清空返回值,避免误用不完整结果。
|
||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||
err = closeErr
|
||
logs = nil
|
||
}
|
||
}()
|
||
|
||
logs = make([]service.UsageLog, 0)
|
||
for rows.Next() {
|
||
var log *service.UsageLog
|
||
log, err = scanUsageLog(rows)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
logs = append(logs, *log)
|
||
}
|
||
if err = rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return logs, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, logs []service.UsageLog) error {
|
||
// 关联数据使用 Ent 批量加载,避免把复杂 SQL 继续膨胀。
|
||
if len(logs) == 0 {
|
||
return nil
|
||
}
|
||
|
||
ids := collectUsageLogIDs(logs)
|
||
users, err := r.loadUsers(ctx, ids.userIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
accounts, err := r.loadAccounts(ctx, ids.accountIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
groups, err := r.loadGroups(ctx, ids.groupIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
subs, err := r.loadSubscriptions(ctx, ids.subscriptionIDs)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
for i := range logs {
|
||
if user, ok := users[logs[i].UserID]; ok {
|
||
logs[i].User = user
|
||
}
|
||
if key, ok := apiKeys[logs[i].APIKeyID]; ok {
|
||
logs[i].APIKey = key
|
||
}
|
||
if acc, ok := accounts[logs[i].AccountID]; ok {
|
||
logs[i].Account = acc
|
||
}
|
||
if logs[i].GroupID != nil {
|
||
if group, ok := groups[*logs[i].GroupID]; ok {
|
||
logs[i].Group = group
|
||
}
|
||
}
|
||
if logs[i].SubscriptionID != nil {
|
||
if sub, ok := subs[*logs[i].SubscriptionID]; ok {
|
||
logs[i].Subscription = sub
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
type usageLogIDs struct {
|
||
userIDs []int64
|
||
apiKeyIDs []int64
|
||
accountIDs []int64
|
||
groupIDs []int64
|
||
subscriptionIDs []int64
|
||
}
|
||
|
||
func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs {
|
||
idSet := func() map[int64]struct{} { return make(map[int64]struct{}) }
|
||
|
||
userIDs := idSet()
|
||
apiKeyIDs := idSet()
|
||
accountIDs := idSet()
|
||
groupIDs := idSet()
|
||
subscriptionIDs := idSet()
|
||
|
||
for i := range logs {
|
||
userIDs[logs[i].UserID] = struct{}{}
|
||
apiKeyIDs[logs[i].APIKeyID] = struct{}{}
|
||
accountIDs[logs[i].AccountID] = struct{}{}
|
||
if logs[i].GroupID != nil {
|
||
groupIDs[*logs[i].GroupID] = struct{}{}
|
||
}
|
||
if logs[i].SubscriptionID != nil {
|
||
subscriptionIDs[*logs[i].SubscriptionID] = struct{}{}
|
||
}
|
||
}
|
||
|
||
return usageLogIDs{
|
||
userIDs: setToSlice(userIDs),
|
||
apiKeyIDs: setToSlice(apiKeyIDs),
|
||
accountIDs: setToSlice(accountIDs),
|
||
groupIDs: setToSlice(groupIDs),
|
||
subscriptionIDs: setToSlice(subscriptionIDs),
|
||
}
|
||
}
|
||
|
||
func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[int64]*service.User, error) {
|
||
out := make(map[int64]*service.User)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.User.Query().Where(dbuser.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = userEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) {
|
||
out := make(map[int64]*service.APIKey)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = apiKeyEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadAccounts(ctx context.Context, ids []int64) (map[int64]*service.Account, error) {
|
||
out := make(map[int64]*service.Account)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.Account.Query().Where(dbaccount.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = accountEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadGroups(ctx context.Context, ids []int64) (map[int64]*service.Group, error) {
|
||
out := make(map[int64]*service.Group)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.Group.Query().Where(dbgroup.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = groupEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) (map[int64]*service.UserSubscription, error) {
|
||
out := make(map[int64]*service.UserSubscription)
|
||
if len(ids) == 0 {
|
||
return out, nil
|
||
}
|
||
models, err := r.client.UserSubscription.Query().Where(dbusersub.IDIn(ids...)).All(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, m := range models {
|
||
out[m.ID] = userSubscriptionEntityToService(m)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
|
||
var (
|
||
id int64
|
||
userID int64
|
||
apiKeyID int64
|
||
accountID int64
|
||
requestID sql.NullString
|
||
model string
|
||
groupID sql.NullInt64
|
||
subscriptionID sql.NullInt64
|
||
inputTokens int
|
||
outputTokens int
|
||
cacheCreationTokens int
|
||
cacheReadTokens int
|
||
cacheCreation5m int
|
||
cacheCreation1h int
|
||
inputCost float64
|
||
outputCost float64
|
||
cacheCreationCost float64
|
||
cacheReadCost float64
|
||
totalCost float64
|
||
actualCost float64
|
||
rateMultiplier float64
|
||
accountRateMultiplier sql.NullFloat64
|
||
billingType int16
|
||
stream bool
|
||
durationMs sql.NullInt64
|
||
firstTokenMs sql.NullInt64
|
||
userAgent sql.NullString
|
||
ipAddress sql.NullString
|
||
imageCount int
|
||
imageSize sql.NullString
|
||
reasoningEffort sql.NullString
|
||
createdAt time.Time
|
||
)
|
||
|
||
if err := scanner.Scan(
|
||
&id,
|
||
&userID,
|
||
&apiKeyID,
|
||
&accountID,
|
||
&requestID,
|
||
&model,
|
||
&groupID,
|
||
&subscriptionID,
|
||
&inputTokens,
|
||
&outputTokens,
|
||
&cacheCreationTokens,
|
||
&cacheReadTokens,
|
||
&cacheCreation5m,
|
||
&cacheCreation1h,
|
||
&inputCost,
|
||
&outputCost,
|
||
&cacheCreationCost,
|
||
&cacheReadCost,
|
||
&totalCost,
|
||
&actualCost,
|
||
&rateMultiplier,
|
||
&accountRateMultiplier,
|
||
&billingType,
|
||
&stream,
|
||
&durationMs,
|
||
&firstTokenMs,
|
||
&userAgent,
|
||
&ipAddress,
|
||
&imageCount,
|
||
&imageSize,
|
||
&reasoningEffort,
|
||
&createdAt,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
log := &service.UsageLog{
|
||
ID: id,
|
||
UserID: userID,
|
||
APIKeyID: apiKeyID,
|
||
AccountID: accountID,
|
||
Model: model,
|
||
InputTokens: inputTokens,
|
||
OutputTokens: outputTokens,
|
||
CacheCreationTokens: cacheCreationTokens,
|
||
CacheReadTokens: cacheReadTokens,
|
||
CacheCreation5mTokens: cacheCreation5m,
|
||
CacheCreation1hTokens: cacheCreation1h,
|
||
InputCost: inputCost,
|
||
OutputCost: outputCost,
|
||
CacheCreationCost: cacheCreationCost,
|
||
CacheReadCost: cacheReadCost,
|
||
TotalCost: totalCost,
|
||
ActualCost: actualCost,
|
||
RateMultiplier: rateMultiplier,
|
||
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||
BillingType: int8(billingType),
|
||
Stream: stream,
|
||
ImageCount: imageCount,
|
||
CreatedAt: createdAt,
|
||
}
|
||
|
||
if requestID.Valid {
|
||
log.RequestID = requestID.String
|
||
}
|
||
if groupID.Valid {
|
||
value := groupID.Int64
|
||
log.GroupID = &value
|
||
}
|
||
if subscriptionID.Valid {
|
||
value := subscriptionID.Int64
|
||
log.SubscriptionID = &value
|
||
}
|
||
if durationMs.Valid {
|
||
value := int(durationMs.Int64)
|
||
log.DurationMs = &value
|
||
}
|
||
if firstTokenMs.Valid {
|
||
value := int(firstTokenMs.Int64)
|
||
log.FirstTokenMs = &value
|
||
}
|
||
if userAgent.Valid {
|
||
log.UserAgent = &userAgent.String
|
||
}
|
||
if ipAddress.Valid {
|
||
log.IPAddress = &ipAddress.String
|
||
}
|
||
if imageSize.Valid {
|
||
log.ImageSize = &imageSize.String
|
||
}
|
||
if reasoningEffort.Valid {
|
||
log.ReasoningEffort = &reasoningEffort.String
|
||
}
|
||
|
||
return log, nil
|
||
}
|
||
|
||
func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
|
||
results := make([]TrendDataPoint, 0)
|
||
for rows.Next() {
|
||
var row TrendDataPoint
|
||
if err := rows.Scan(
|
||
&row.Date,
|
||
&row.Requests,
|
||
&row.InputTokens,
|
||
&row.OutputTokens,
|
||
&row.CacheTokens,
|
||
&row.TotalTokens,
|
||
&row.Cost,
|
||
&row.ActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
||
results := make([]ModelStat, 0)
|
||
for rows.Next() {
|
||
var row ModelStat
|
||
if err := rows.Scan(
|
||
&row.Model,
|
||
&row.Requests,
|
||
&row.InputTokens,
|
||
&row.OutputTokens,
|
||
&row.TotalTokens,
|
||
&row.Cost,
|
||
&row.ActualCost,
|
||
); err != nil {
|
||
return nil, err
|
||
}
|
||
results = append(results, row)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
func buildWhere(conditions []string) string {
|
||
if len(conditions) == 0 {
|
||
return ""
|
||
}
|
||
return "WHERE " + strings.Join(conditions, " AND ")
|
||
}
|
||
|
||
func nullInt64(v *int64) sql.NullInt64 {
|
||
if v == nil {
|
||
return sql.NullInt64{}
|
||
}
|
||
return sql.NullInt64{Int64: *v, Valid: true}
|
||
}
|
||
|
||
func nullInt(v *int) sql.NullInt64 {
|
||
if v == nil {
|
||
return sql.NullInt64{}
|
||
}
|
||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||
}
|
||
|
||
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
|
||
if !v.Valid {
|
||
return nil
|
||
}
|
||
out := v.Float64
|
||
return &out
|
||
}
|
||
|
||
func nullString(v *string) sql.NullString {
|
||
if v == nil || *v == "" {
|
||
return sql.NullString{}
|
||
}
|
||
return sql.NullString{String: *v, Valid: true}
|
||
}
|
||
|
||
func setToSlice(set map[int64]struct{}) []int64 {
|
||
out := make([]int64, 0, len(set))
|
||
for id := range set {
|
||
out = append(out, id)
|
||
}
|
||
return out
|
||
}
|