fix(仓储): 修复软删除过滤与事务测试

修复软删除拦截器使用错误,确保默认查询过滤已删记录
仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题
同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
yangjianbo
2025-12-29 19:23:49 +08:00
parent b436da7249
commit ae191f72a4
20 changed files with 565 additions and 326 deletions

View File

@@ -3,7 +3,6 @@ package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
@@ -33,6 +32,7 @@ func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLog
}
func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageLogRepository {
// 使用 scanSingleRow 替代 QueryRowContext保证 ent.Tx 作为 sqlExecutor 可用。
return &usageLogRepository{client: client, sql: sqlq}
}
@@ -53,7 +53,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
var requestCount int64
var tokenCount int64
if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&requestCount, &tokenCount); err != nil {
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
return 0, 0, err
}
return requestCount / 5, tokenCount / 5, nil
@@ -114,9 +114,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
row := r.sql.QueryRowContext(
ctx,
query,
args := []any{
log.UserID,
log.ApiKeyID,
log.AccountID,
@@ -142,9 +140,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration,
firstToken,
createdAt,
)
if err := row.Scan(&log.ID, &log.CreatedAt); err != nil {
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
return err
}
log.RateMultiplier = rateMultiplier
@@ -153,11 +150,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
log, err := scanUsageLog(r.sql.QueryRowContext(ctx, query, id))
rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, service.ErrUsageLogNotFound
return nil, err
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUsageLogNotFound
}
log, err := scanUsageLog(rows)
if err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return log, nil
@@ -195,8 +203,18 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
`
stats := &UserStats{}
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
Scan(&stats.TotalRequests, &stats.TotalTokens, &stats.TotalCost, &stats.InputTokens, &stats.OutputTokens, &stats.CacheReadTokens); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalTokens,
&stats.TotalCost,
&stats.InputTokens,
&stats.OutputTokens,
&stats.CacheReadTokens,
); err != nil {
return nil, err
}
return stats, nil
@@ -219,8 +237,15 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM users
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, userStatsQuery, today, today).
Scan(&stats.TotalUsers, &stats.TodayNewUsers, &stats.ActiveUsers); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
userStatsQuery,
[]any{today, today},
&stats.TotalUsers,
&stats.TodayNewUsers,
&stats.ActiveUsers,
); err != nil {
return nil, err
}
@@ -232,8 +257,14 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM api_keys
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, apiKeyStatsQuery, service.StatusActive).
Scan(&stats.TotalApiKeys, &stats.ActiveApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
apiKeyStatsQuery,
[]any{service.StatusActive},
&stats.TotalApiKeys,
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -248,8 +279,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM accounts
WHERE deleted_at IS NULL
`
if err := r.sql.QueryRowContext(ctx, accountStatsQuery, service.StatusActive, service.StatusError, now, now).
Scan(&stats.TotalAccounts, &stats.NormalAccounts, &stats.ErrorAccounts, &stats.RateLimitAccounts, &stats.OverloadAccounts); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
accountStatsQuery,
[]any{service.StatusActive, service.StatusError, now, now},
&stats.TotalAccounts,
&stats.NormalAccounts,
&stats.ErrorAccounts,
&stats.RateLimitAccounts,
&stats.OverloadAccounts,
); err != nil {
return nil, err
}
@@ -266,17 +306,20 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
`
if err := r.sql.QueryRowContext(ctx, totalStatsQuery).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
nil,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -294,16 +337,19 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
FROM usage_logs
WHERE created_at >= $1
`
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, today).
Scan(
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -345,16 +391,19 @@ func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID
`
var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, userID, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{userID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -377,16 +426,19 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
`
var stats usagestats.UsageStats
if err := r.sql.QueryRowContext(ctx, query, apiKeyID, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{apiKeyID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -430,8 +482,15 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
`
stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, today).
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, today},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
@@ -449,8 +508,15 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
`
stats := &usagestats.AccountStats{}
if err := r.sql.QueryRowContext(ctx, query, accountID, startTime).
Scan(&stats.Requests, &stats.Tokens, &stats.Cost); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime},
&stats.Requests,
&stats.Tokens,
&stats.Cost,
); err != nil {
return nil, err
}
return stats, nil
@@ -581,12 +647,22 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today()
// API Key 统计
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", userID).
Scan(&stats.TotalApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL",
[]any{userID},
&stats.TotalApiKeys,
); err != nil {
return nil, err
}
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", userID, service.StatusActive).
Scan(&stats.ActiveApiKeys); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
"SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL",
[]any{userID, service.StatusActive},
&stats.ActiveApiKeys,
); err != nil {
return nil, err
}
@@ -604,17 +680,20 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs
WHERE user_id = $1
`
if err := r.sql.QueryRowContext(ctx, totalStatsQuery, userID).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{userID},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
@@ -632,16 +711,19 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2
`
if err := r.sql.QueryRowContext(ctx, todayStatsQuery, userID, today).
Scan(
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{userID, today},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
@@ -1007,16 +1089,19 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
`
stats := &UsageStats{}
if err := r.sql.QueryRowContext(ctx, query, startTime, endTime).
Scan(
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
@@ -1108,7 +1193,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
avgQuery := "SELECT COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3"
var avgDuration float64
if err := r.sql.QueryRowContext(ctx, avgQuery, accountID, startTime, endTime).Scan(&avgDuration); err != nil {
if err := scanSingleRow(ctx, r.sql, avgQuery, []any{accountID, startTime, endTime}, &avgDuration); err != nil {
return nil, err
}
@@ -1186,7 +1271,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause
var total int64
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil {
return nil, nil, err
}