merge: 合并main分支最新改动
解决冲突: - backend/internal/config/config.go: 合并Ops和Dashboard配置 - backend/internal/server/api_contract_test.go: 合并handler初始化 - backend/internal/service/openai_gateway_service.go: 保留Ops错误追踪逻辑 - backend/internal/service/wire.go: 合并Ops和APIKeyAuth provider 主要合并内容: - Dashboard缓存和预聚合功能 - API Key认证缓存优化 - Codex转换支持 - 使用日志分区表
This commit is contained in:
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func apiKeyAuthCacheKey(key string) string {
|
||||
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entry service.APIKeyAuthCacheEntry
|
||||
if err := json.Unmarshal(val, &entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
return "", 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
return "", 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
return m.Key, m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
@@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
Select(
|
||||
apikey.FieldID,
|
||||
apikey.FieldUserID,
|
||||
apikey.FieldGroupID,
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
user.FieldID,
|
||||
user.FieldStatus,
|
||||
user.FieldRole,
|
||||
user.FieldBalance,
|
||||
user.FieldConcurrency,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
q.Select(
|
||||
group.FieldID,
|
||||
group.FieldName,
|
||||
group.FieldPlatform,
|
||||
group.FieldStatus,
|
||||
group.FieldSubscriptionType,
|
||||
group.FieldRateMultiplier,
|
||||
group.FieldDailyLimitUsd,
|
||||
group.FieldWeeklyLimitUsd,
|
||||
group.FieldMonthlyLimitUsd,
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
@@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.UserIDEQ(userID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
|
||||
363
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
363
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
return newDashboardAggregationRepositoryWithSQL(sqlDB)
|
||||
}
|
||||
|
||||
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||
return &dashboardAggregationRepository{sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startUTC.Truncate(time.Hour)
|
||||
hourEnd := endUTC.Truncate(time.Hour)
|
||||
if endUTC.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDayUTC(startUTC)
|
||||
dayEnd := truncateToDayUTC(endUTC)
|
||||
if endUTC.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return time.Unix(0, 0).UTC(), nil
|
||||
}
|
||||
return time.Time{}, err
|
||||
}
|
||||
return ts.UTC(), nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||
VALUES (1, $1, NOW())
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1;
|
||||
DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1;
|
||||
DELETE FROM usage_dashboard_daily WHERE bucket_date < $2::date;
|
||||
DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $2::date;
|
||||
`, hourlyCutoff.UTC(), dailyCutoff.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil || !isPartitioned {
|
||||
return err
|
||||
}
|
||||
monthStart := truncateToMonthUTC(now)
|
||||
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||
SELECT DISTINCT
|
||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
user_id
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||
SELECT DISTINCT
|
||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||
user_id
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
query := `
|
||||
WITH hourly AS (
|
||||
SELECT
|
||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
COUNT(*) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
GROUP BY 1
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_start, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY bucket_start
|
||||
)
|
||||
INSERT INTO usage_dashboard_hourly (
|
||||
bucket_start,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
hourly.bucket_start,
|
||||
hourly.total_requests,
|
||||
hourly.input_tokens,
|
||||
hourly.output_tokens,
|
||||
hourly.cache_creation_tokens,
|
||||
hourly.cache_read_tokens,
|
||||
hourly.total_cost,
|
||||
hourly.actual_cost,
|
||||
hourly.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM hourly
|
||||
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
|
||||
ON CONFLICT (bucket_start)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
query := `
|
||||
WITH daily AS (
|
||||
SELECT
|
||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_date, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_daily_users
|
||||
WHERE bucket_date >= $3::date AND bucket_date < $4::date
|
||||
GROUP BY bucket_date
|
||||
)
|
||||
INSERT INTO usage_dashboard_daily (
|
||||
bucket_date,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
daily.bucket_date,
|
||||
daily.total_requests,
|
||||
daily.input_tokens,
|
||||
daily.output_tokens,
|
||||
daily.cache_creation_tokens,
|
||||
daily.cache_read_tokens,
|
||||
daily.total_cost,
|
||||
daily.actual_cost,
|
||||
daily.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM daily
|
||||
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
|
||||
ON CONFLICT (bucket_date)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_partitioned_table pt
|
||||
JOIN pg_class c ON c.oid = pt.partrelid
|
||||
WHERE c.relname = 'usage_logs'
|
||||
)
|
||||
`
|
||||
var partitioned bool
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return partitioned, nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT c.relname
|
||||
FROM pg_inherits
|
||||
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||
WHERE p.relname = 'usage_logs'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.HasPrefix(name, "usage_logs_") {
|
||||
continue
|
||||
}
|
||||
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||
month, err := time.Parse("200601", suffix)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
month = month.UTC()
|
||||
if month.Before(cutoffMonth) {
|
||||
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||
monthStart := truncateToMonthUTC(month)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||
query := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||
pq.QuoteIdentifier(name),
|
||||
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||
)
|
||||
_, err := r.sql.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
func truncateToDayUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func truncateToMonthUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
58
backend/internal/repository/dashboard_cache.go
Normal file
58
backend/internal/repository/dashboard_cache.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const dashboardStatsCacheKey = "dashboard:stats:v1"
|
||||
|
||||
type dashboardCache struct {
|
||||
rdb *redis.Client
|
||||
keyPrefix string
|
||||
}
|
||||
|
||||
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||
prefix := "sub2api:"
|
||||
if cfg != nil {
|
||||
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||
}
|
||||
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||
prefix += ":"
|
||||
}
|
||||
return &dashboardCache{
|
||||
rdb: rdb,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", service.ErrDashboardStatsCacheMiss
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *dashboardCache) buildKey() string {
|
||||
if c.keyPrefix == "" {
|
||||
return dashboardStatsCacheKey
|
||||
}
|
||||
return c.keyPrefix + dashboardStatsCacheKey
|
||||
}
|
||||
|
||||
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||
}
|
||||
28
backend/internal/repository/dashboard_cache_test.go
Normal file
28
backend/internal/repository/dashboard_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||
cache := NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "prod",
|
||||
},
|
||||
})
|
||||
impl, ok := cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "prod:", impl.keyPrefix)
|
||||
|
||||
cache = NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "staging:",
|
||||
},
|
||||
})
|
||||
impl, ok = cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "staging:", impl.keyPrefix)
|
||||
}
|
||||
@@ -269,16 +269,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
type DashboardStats = usagestats.DashboardStats
|
||||
|
||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
today := timezone.Today()
|
||||
now := time.Now()
|
||||
stats := &DashboardStats{}
|
||||
now := time.Now().UTC()
|
||||
todayUTC := truncateToDayUTC(now)
|
||||
|
||||
// 合并用户统计查询
|
||||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return nil, errors.New("统计时间范围无效")
|
||||
}
|
||||
|
||||
stats := &DashboardStats{}
|
||||
now := time.Now().UTC()
|
||||
todayUTC := truncateToDayUTC(now)
|
||||
|
||||
if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||
userStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_users,
|
||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
|
||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
|
||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||||
FROM users
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
@@ -286,15 +330,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
ctx,
|
||||
r.sql,
|
||||
userStatsQuery,
|
||||
[]any{today, today},
|
||||
[]any{todayUTC},
|
||||
&stats.TotalUsers,
|
||||
&stats.TodayNewUsers,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 合并API Key统计查询
|
||||
apiKeyStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_api_keys,
|
||||
@@ -310,10 +352,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 合并账户统计查询
|
||||
accountStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_accounts,
|
||||
@@ -335,22 +376,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.RateLimitAccounts,
|
||||
&stats.OverloadAccounts,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 累计 Token 统计
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_requests), 0) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||||
FROM usage_dashboard_daily
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
@@ -363,13 +408,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
// 今日 Token 统计
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
total_requests as today_requests,
|
||||
input_tokens as today_input_tokens,
|
||||
output_tokens as today_output_tokens,
|
||||
cache_creation_tokens as today_cache_creation_tokens,
|
||||
cache_read_tokens as today_cache_read_tokens,
|
||||
total_cost as today_cost,
|
||||
actual_cost as today_actual_cost,
|
||||
active_users as active_users
|
||||
FROM usage_dashboard_daily
|
||||
WHERE bucket_date = $1::date
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{todayUTC},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
hourlyActiveQuery := `
|
||||
SELECT active_users
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start = $1
|
||||
`
|
||||
hourStart := now.UTC().Truncate(time.Hour)
|
||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{startUTC, endUTC},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
@@ -380,13 +512,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{today},
|
||||
[]any{todayUTC, todayEnd},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
@@ -395,19 +527,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
activeUsersQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return &stats, nil
|
||||
hourStart := now.UTC().Truncate(time.Hour)
|
||||
hourEnd := hourStart.Add(time.Hour)
|
||||
hourlyActiveQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -198,8 +197,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
// --- GetDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
now := time.Now().UTC()
|
||||
todayStart := truncateToDayUTC(now)
|
||||
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats base")
|
||||
|
||||
@@ -273,6 +272,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
_, err = s.repo.Create(s.ctx, logPerf)
|
||||
s.Require().NoError(err, "Create logPerf")
|
||||
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||
aggStart := todayStart.Add(-2 * time.Hour)
|
||||
aggEnd := now.Add(2 * time.Minute)
|
||||
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange")
|
||||
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
|
||||
@@ -303,6 +307,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
|
||||
now := time.Now().UTC()
|
||||
todayStart := truncateToDayUTC(now)
|
||||
rangeStart := todayStart.Add(-24 * time.Hour)
|
||||
rangeEnd := now.Add(1 * time.Second)
|
||||
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"})
|
||||
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logOutside := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 7,
|
||||
OutputTokens: 8,
|
||||
TotalCost: 0.8,
|
||||
ActualCost: 0.7,
|
||||
DurationMs: &d3,
|
||||
CreatedAt: rangeStart.Add(-1 * time.Hour),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, logOutside)
|
||||
s.Require().NoError(err)
|
||||
|
||||
logRange := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 0.9,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: rangeStart.Add(2 * time.Hour),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, logRange)
|
||||
s.Require().NoError(err)
|
||||
|
||||
logToday := &service.UsageLog{
|
||||
UserID: user2.ID,
|
||||
APIKeyID: apiKey2.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 6,
|
||||
CacheReadTokens: 1,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
DurationMs: &d2,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, logToday)
|
||||
s.Require().NoError(err)
|
||||
|
||||
stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||
s.Require().Equal(int64(15), stats.TotalInputTokens)
|
||||
s.Require().Equal(int64(26), stats.TotalOutputTokens)
|
||||
s.Require().Equal(int64(1), stats.TotalCacheCreationTokens)
|
||||
s.Require().Equal(int64(3), stats.TotalCacheReadTokens)
|
||||
s.Require().Equal(int64(45), stats.TotalTokens)
|
||||
s.Require().Equal(1.5, stats.TotalCost)
|
||||
s.Require().Equal(1.4, stats.TotalActualCost)
|
||||
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
||||
}
|
||||
|
||||
// --- GetUserDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
@@ -333,6 +411,151 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
s.Require().Equal(int64(30), stats.Tokens)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
hour1 := now.Add(-90 * time.Minute).Truncate(time.Hour)
|
||||
hour2 := now.Add(-30 * time.Minute).Truncate(time.Hour)
|
||||
dayStart := truncateToDayUTC(now)
|
||||
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"})
|
||||
|
||||
d1, d2, d3 := 100, 200, 150
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 2,
|
||||
CacheReadTokens: 1,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 0.9,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: hour1.Add(5 * time.Minute),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
DurationMs: &d2,
|
||||
CreatedAt: hour1.Add(20 * time.Minute),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user2.ID,
|
||||
APIKeyID: apiKey2.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 7,
|
||||
OutputTokens: 8,
|
||||
TotalCost: 0.7,
|
||||
ActualCost: 0.7,
|
||||
DurationMs: &d3,
|
||||
CreatedAt: hour2.Add(10 * time.Minute),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, log3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||
aggStart := hour1.Add(-5 * time.Minute)
|
||||
aggEnd := now.Add(5 * time.Minute)
|
||||
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd))
|
||||
|
||||
type hourlyRow struct {
|
||||
totalRequests int64
|
||||
inputTokens int64
|
||||
outputTokens int64
|
||||
cacheCreationTokens int64
|
||||
cacheReadTokens int64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
totalDurationMs int64
|
||||
activeUsers int64
|
||||
}
|
||||
fetchHourly := func(bucketStart time.Time) hourlyRow {
|
||||
var row hourlyRow
|
||||
err := scanSingleRow(s.ctx, s.tx, `
|
||||
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||
total_cost, actual_cost, total_duration_ms, active_users
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start = $1
|
||||
`, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens,
|
||||
&row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost,
|
||||
&row.totalDurationMs, &row.activeUsers,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
return row
|
||||
}
|
||||
|
||||
hour1Row := fetchHourly(hour1)
|
||||
s.Require().Equal(int64(2), hour1Row.totalRequests)
|
||||
s.Require().Equal(int64(15), hour1Row.inputTokens)
|
||||
s.Require().Equal(int64(25), hour1Row.outputTokens)
|
||||
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens)
|
||||
s.Require().Equal(int64(1), hour1Row.cacheReadTokens)
|
||||
s.Require().Equal(1.5, hour1Row.totalCost)
|
||||
s.Require().Equal(1.4, hour1Row.actualCost)
|
||||
s.Require().Equal(int64(300), hour1Row.totalDurationMs)
|
||||
s.Require().Equal(int64(1), hour1Row.activeUsers)
|
||||
|
||||
hour2Row := fetchHourly(hour2)
|
||||
s.Require().Equal(int64(1), hour2Row.totalRequests)
|
||||
s.Require().Equal(int64(7), hour2Row.inputTokens)
|
||||
s.Require().Equal(int64(8), hour2Row.outputTokens)
|
||||
s.Require().Equal(int64(0), hour2Row.cacheCreationTokens)
|
||||
s.Require().Equal(int64(0), hour2Row.cacheReadTokens)
|
||||
s.Require().Equal(0.7, hour2Row.totalCost)
|
||||
s.Require().Equal(0.7, hour2Row.actualCost)
|
||||
s.Require().Equal(int64(150), hour2Row.totalDurationMs)
|
||||
s.Require().Equal(int64(1), hour2Row.activeUsers)
|
||||
|
||||
var daily struct {
|
||||
totalRequests int64
|
||||
inputTokens int64
|
||||
outputTokens int64
|
||||
cacheCreationTokens int64
|
||||
cacheReadTokens int64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
totalDurationMs int64
|
||||
activeUsers int64
|
||||
}
|
||||
err = scanSingleRow(s.ctx, s.tx, `
|
||||
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||
total_cost, actual_cost, total_duration_ms, active_users
|
||||
FROM usage_dashboard_daily
|
||||
WHERE bucket_date = $1::date
|
||||
`, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens,
|
||||
&daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost,
|
||||
&daily.totalDurationMs, &daily.activeUsers,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), daily.totalRequests)
|
||||
s.Require().Equal(int64(22), daily.inputTokens)
|
||||
s.Require().Equal(int64(33), daily.outputTokens)
|
||||
s.Require().Equal(int64(2), daily.cacheCreationTokens)
|
||||
s.Require().Equal(int64(1), daily.cacheReadTokens)
|
||||
s.Require().Equal(2.2, daily.totalCost)
|
||||
s.Require().Equal(2.1, daily.actualCost)
|
||||
s.Require().Equal(int64(450), daily.totalDurationMs)
|
||||
s.Require().Equal(int64(2), daily.activeUsers)
|
||||
}
|
||||
|
||||
// --- GetBatchUserUsageStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
|
||||
@@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
NewSettingRepository,
|
||||
NewOpsRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
@@ -59,6 +60,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
|
||||
Reference in New Issue
Block a user