fix(仓储): 规范 rows.Close 错误回传

统一 usage_log_repo 查询的 Close 错误处理,避免\n成功路径吞掉关闭失败

scanSingleRow 使用 errors.Join 合并 Close 错误,\n保留 ErrNoRows 可判定

测试: make -C backend test-unit
This commit is contained in:
yangjianbo
2025-12-30 17:13:32 +08:00
parent 7e758b24c4
commit 8cb2d3b352
2 changed files with 125 additions and 42 deletions

View File

@@ -3,14 +3,16 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
) )
type sqlQueryer interface { type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
} }
// scanSingleRow executes a query and scans the first row into dest. // scanSingleRow 执行查询并扫描第一行到 dest
// If no rows are returned, sql.ErrNoRows is returned. // 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
// 如果 Close 失败,会与原始错误合并返回。
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定, // 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。 // 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) { func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
@@ -19,8 +21,8 @@ func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any,
return err return err
} }
defer func() { defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil { if closeErr := rows.Close(); closeErr != nil {
err = closeErr err = errors.Join(err, closeErr)
} }
}() }()

View File

@@ -148,24 +148,31 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return nil return nil
} }
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id) rows, err := r.sql.QueryContext(ctx, query, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
log = nil
}
}()
if !rows.Next() { if !rows.Next() {
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
return nil, service.ErrUsageLogNotFound return nil, service.ErrUsageLogNotFound
} }
log, err := scanUsageLog(rows) log, err = scanUsageLog(rows)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
return log, nil return log, nil
@@ -535,7 +542,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -568,17 +575,24 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results := make([]ApiKeyUsageTrendPoint, 0) results = make([]ApiKeyUsageTrendPoint, 0)
for rows.Next() { for rows.Next() {
var row ApiKeyUsageTrendPoint var row ApiKeyUsageTrendPoint
if err := rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
return nil, err return nil, err
} }
results = append(results, row) results = append(results, row)
} }
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
@@ -586,7 +600,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
} }
// GetUserUsageTrend returns usage trend data grouped by user and date // GetUserUsageTrend returns usage trend data grouped by user and date
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -621,17 +635,24 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
results := make([]UserUsageTrendPoint, 0) results = make([]UserUsageTrendPoint, 0)
for rows.Next() { for rows.Next() {
var row UserUsageTrendPoint var row UserUsageTrendPoint
if err := rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
return nil, err return nil, err
} }
results = append(results, row) results = append(results, row)
} }
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
@@ -740,7 +761,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
} }
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -766,13 +787,24 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanTrendRows(rows) results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
} }
// GetUserModelStats 获取指定用户的模型统计 // GetUserModelStats 获取指定用户的模型统计
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
query := ` query := `
SELECT SELECT
model, model,
@@ -792,9 +824,20 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanModelStatsRows(rows) results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
} }
// UsageLogFilters represents filters for usage log queries // UsageLogFilters represents filters for usage log queries
@@ -994,7 +1037,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD" dateFormat := "YYYY-MM-DD"
if granularity == "hour" { if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00" dateFormat = "YYYY-MM-DD HH24:00"
@@ -1029,13 +1072,24 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanTrendRows(rows) results, err = scanTrendRows(rows)
if err != nil {
return nil, err
}
return results, nil
} }
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := ` query := `
SELECT SELECT
model, model,
@@ -1068,9 +1122,20 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
results = nil
}
}()
return scanModelStatsRows(rows) results, err = scanModelStatsRows(rows)
if err != nil {
return nil, err
}
return results, nil
} }
// GetGlobalStats gets usage statistics for all users within a time range // GetGlobalStats gets usage statistics for all users within a time range
@@ -1118,7 +1183,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 { if daysCount <= 0 {
daysCount = 30 daysCount = 30
@@ -1141,7 +1206,14 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
resp = nil
}
}()
history := make([]AccountUsageHistory, 0) history := make([]AccountUsageHistory, 0)
for rows.Next() { for rows.Next() {
@@ -1150,7 +1222,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64 var tokens int64
var cost float64 var cost float64
var actualCost float64 var actualCost float64
if err := rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil { if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
return nil, err return nil, err
} }
t, _ := time.Parse("2006-01-02", date) t, _ := time.Parse("2006-01-02", date)
@@ -1163,7 +1235,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost: actualCost, ActualCost: actualCost,
}) })
} }
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
@@ -1261,11 +1333,12 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
models = []ModelStat{} models = []ModelStat{}
} }
return &AccountUsageStatsResponse{ resp = &AccountUsageStatsResponse{
History: history, History: history,
Summary: summary, Summary: summary,
Models: models, Models: models,
}, nil }
return resp, nil
} }
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
@@ -1286,22 +1359,30 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
return logs, paginationResultFromTotal(total, params), nil return logs, paginationResultFromTotal(total, params), nil
} }
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) ([]service.UsageLog, error) { func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...) rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() {
// 保持主错误优先;仅在无错误时回传 Close 失败。
// 同时清空返回值,避免误用不完整结果。
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
logs = nil
}
}()
logs := make([]service.UsageLog, 0) logs = make([]service.UsageLog, 0)
for rows.Next() { for rows.Next() {
log, err := scanUsageLog(rows) var log *service.UsageLog
log, err = scanUsageLog(rows)
if err != nil { if err != nil {
return nil, err return nil, err
} }
logs = append(logs, *log) logs = append(logs, *log)
} }
if err := rows.Err(); err != nil { if err = rows.Err(); err != nil {
return nil, err return nil, err
} }
return logs, nil return logs, nil