fix(仓储): 规范 rows.Close 错误回传

统一 usage_log_repo 查询的 Close 错误处理,避免\n成功路径吞掉关闭失败

scanSingleRow 使用 errors.Join 合并 Close 错误,\n保留 ErrNoRows 可判定

测试: make -C backend test-unit
This commit is contained in:
yangjianbo
2025-12-30 17:13:32 +08:00
parent 7e758b24c4
commit 8cb2d3b352
2 changed files with 125 additions and 42 deletions

View File

@@ -3,14 +3,16 @@ package repository
import (
"context"
"database/sql"
"errors"
)
type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow executes a query and scans the first row into dest.
// If no rows are returned, sql.ErrNoRows is returned.
// scanSingleRow 执行查询并扫描第一行到 dest
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
// 如果 Close 失败,会与原始错误合并返回。
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
@@ -19,8 +21,8 @@ func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any,
return err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
if closeErr := rows.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}()

View File

@@ -148,24 +148,31 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return nil
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
log = nil
}
}()
if !rows.Next() {
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
return nil, service.ErrUsageLogNotFound
}
log, err := scanUsageLog(rows)
log, err = scanUsageLog(rows)
if err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
return log, nil
@@ -535,7 +542,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -568,17 +575,24 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results := make([]ApiKeyUsageTrendPoint, 0)
results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() {
var row ApiKeyUsageTrendPoint
if err := rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
@@ -586,7 +600,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -621,17 +635,24 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results := make([]UserUsageTrendPoint, 0)
results = make([]UserUsageTrendPoint, 0)
for rows.Next() {
var row UserUsageTrendPoint
if err := rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
@@ -740,7 +761,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -766,13 +787,24 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanTrendRows(rows)
results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetUserModelStats 获取指定用户的模型统计
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
query := `
SELECT
model,
@@ -792,9 +824,20 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanModelStatsRows(rows)
results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// UsageLogFilters represents filters for usage log queries
@@ -994,7 +1037,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1029,13 +1072,24 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanTrendRows(rows)
results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
SELECT
model,
@@ -1068,9 +1122,20 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanModelStatsRows(rows)
results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
}
// GetGlobalStats gets usage statistics for all users within a time range
@@ -1118,7 +1183,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 {
daysCount = 30
@@ -1141,7 +1206,14 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
resp = nil
}
}()
history := make([]AccountUsageHistory, 0)
for rows.Next() {
@@ -1150,7 +1222,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
if err := rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
@@ -1163,7 +1235,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost: actualCost,
})
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
@@ -1261,11 +1333,12 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
models = []ModelStat{}
}
return &AccountUsageStatsResponse{
resp = &AccountUsageStatsResponse{
History: history,
Summary: summary,
Models: models,
}, nil
}
return resp, nil
}
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
@@ -1286,22 +1359,30 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
return logs, paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) ([]service.UsageLog, error) {
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
logs = nil
}
}()
logs := make([]service.UsageLog, 0)
logs = make([]service.UsageLog, 0)
for rows.Next() {
log, err := scanUsageLog(rows)
var log *service.UsageLog
log, err = scanUsageLog(rows)
if err != nil {
return nil, err
}
logs = append(logs, *log)
}
if err := rows.Err(); err != nil {
if err = rows.Err(); err != nil {
return nil, err
}
return logs, nil