fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -3,7 +3,6 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
|
||||
}
|
||||
|
||||
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
|
||||
// 使用 scanSingleRow 替代 QueryRowContext,保证 ent.Tx 作为 sqlExecutor 可用。
|
||||
return &usageLogRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
|
||||
var requestCount int64
|
||||
var tokenCount int64
|
||||
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
|
||||
row := r.sql.QueryRowContext(
|
||||
ctx,
|
||||
query,
|
||||
args := []any{
|
||||
log.UserID,
|
||||
log.ApiKeyID,
|
||||
log.AccountID,
|
||||
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration,
|
||||
firstToken,
|
||||
createdAt,
|
||||
)
|
||||
|
||||
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id))
|
||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, service.ErrUsageLogNotFound
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUsageLogNotFound
|
||||
}
|
||||
log, err := scanUsageLog(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return log, nil
|
||||
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
`
|
||||
|
||||
stats := &UserStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
||||
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{userID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.InputTokens,
|
||||
&stats.OutputTokens,
|
||||
&stats.CacheReadTokens,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM users
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today).
|
||||
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
userStatsQuery,
|
||||
[]any{today, today},
|
||||
&stats.TotalUsers,
|
||||
&stats.TodayNewUsers,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM api_keys
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive).
|
||||
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
apiKeyStatsQuery,
|
||||
[]any{service.StatusActive},
|
||||
&stats.TotalApiKeys,
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM accounts
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now).
|
||||
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
accountStatsQuery,
|
||||
[]any{service.StatusActive, service.StatusError, now, now},
|
||||
&stats.TotalAccounts,
|
||||
&stats.NormalAccounts,
|
||||
&stats.ErrorAccounts,
|
||||
&stats.RateLimitAccounts,
|
||||
&stats.OverloadAccounts,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
nil,
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today).
|
||||
Scan(
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{userID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{apiKeyID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
`
|
||||
|
||||
stats := &usagestats.AccountStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, accountID, today).
|
||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, today},
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
`
|
||||
|
||||
stats := &usagestats.AccountStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime).
|
||||
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, startTime},
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stats, nil
|
||||
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
today := timezone.Today()
|
||||
|
||||
// API Key 统计
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID).
|
||||
Scan(&stats.TotalApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
|
||||
[]any{userID},
|
||||
&stats.TotalApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive).
|
||||
Scan(&stats.ActiveApiKeys); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
|
||||
[]any{userID, service.StatusActive},
|
||||
&stats.ActiveApiKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{userID},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1 AND created_at >= $2
|
||||
`
|
||||
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today).
|
||||
Scan(
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{userID, today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
|
||||
`
|
||||
|
||||
stats := &UsageStats{}
|
||||
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime).
|
||||
Scan(
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
|
||||
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
|
||||
var avgDuration float64
|
||||
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
|
||||
var total int64
|
||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user