From 8cb2d3b3525636ba70dac4c0700078bb9e30396e Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Tue, 30 Dec 2025 17:13:32 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E4=BB=93=E5=82=A8):=20=E8=A7=84=E8=8C=83?= =?UTF-8?q?=20rows.Close=20=E9=94=99=E8=AF=AF=E5=9B=9E=E4=BC=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 统一 usage_log_repo 查询的 Close 错误处理,避免\n成功路径吞掉关闭失败 scanSingleRow 使用 errors.Join 合并 Close 错误,\n保留 ErrNoRows 可判定 测试: make -C backend test-unit --- backend/internal/repository/sql_scan.go | 10 +- backend/internal/repository/usage_log_repo.go | 157 +++++++++++++----- 2 files changed, 125 insertions(+), 42 deletions(-) diff --git a/backend/internal/repository/sql_scan.go b/backend/internal/repository/sql_scan.go index e734ea82..91b6c9c4 100644 --- a/backend/internal/repository/sql_scan.go +++ b/backend/internal/repository/sql_scan.go @@ -3,14 +3,16 @@ package repository import ( "context" "database/sql" + "errors" ) type sqlQueryer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } -// scanSingleRow executes a query and scans the first row into dest. -// If no rows are returned, sql.ErrNoRows is returned. +// scanSingleRow 执行查询并扫描第一行到 dest。 +// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。 +// 如果 Close 失败,会与原始错误合并返回。 // 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定, // 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。 func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) { @@ -19,8 +21,8 @@ func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, return err } defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr + if closeErr := rows.Close(); closeErr != nil { + err = errors.Join(err, closeErr) } }() diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 4b9694c1..9341f20e 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -148,24 +148,31 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return nil } -func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { +func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" rows, err := r.sql.QueryContext(ctx, query, id) if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + log = nil + } + }() if !rows.Next() { - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } return nil, service.ErrUsageLogNotFound } - log, err := scanUsageLog(rows) + log, err = scanUsageLog(rows) if err != nil { return nil, err } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } return log, nil @@ -535,7 +542,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint // GetApiKeyUsageTrend returns usage trend data grouped by API key and date -func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { +func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -568,17 +575,24 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - results := make([]ApiKeyUsageTrendPoint, 0) + results = make([]ApiKeyUsageTrendPoint, 0) for rows.Next() { var row ApiKeyUsageTrendPoint - if err := rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { + if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { return nil, err } results = append(results, row) } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } @@ -586,7 +600,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, } // GetUserUsageTrend returns usage trend data grouped by user and date -func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { +func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -621,17 +635,24 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - results := make([]UserUsageTrendPoint, 0) + results = make([]UserUsageTrendPoint, 0) for rows.Next() { var row UserUsageTrendPoint - if err := rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { + if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { return nil, err } results = append(results, row) } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } @@ -740,7 +761,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i } // GetUserUsageTrendByUserID 获取指定用户的使用趋势 -func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { +func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -766,13 +787,24 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - return scanTrendRows(rows) + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil } // GetUserModelStats 获取指定用户的模型统计 -func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { +func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) { query := ` SELECT model, @@ -792,9 +824,20 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - return scanModelStatsRows(rows) + results, err = scanModelStatsRows(rows) + if err != nil { + return nil, err + } + return results, nil } // UsageLogFilters represents filters for usage log queries @@ -994,7 +1037,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) { dateFormat := "YYYY-MM-DD" if granularity == "hour" { dateFormat = "YYYY-MM-DD HH24:00" @@ -1029,13 +1072,24 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - return scanTrendRows(rows) + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil } // GetModelStatsWithFilters returns model statistics with optional user/api_key filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) { query := ` SELECT model, @@ -1068,9 +1122,20 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() - return scanModelStatsRows(rows) + results, err = scanModelStatsRows(rows) + if err != nil { + return nil, err + } + return results, nil } // GetGlobalStats gets usage statistics for all users within a time range @@ -1118,7 +1183,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range -func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { +func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) { daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 if daysCount <= 0 { daysCount = 30 @@ -1141,7 +1206,14 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + resp = nil + } + }() history := make([]AccountUsageHistory, 0) for rows.Next() { @@ -1150,7 +1222,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID var tokens int64 var cost float64 var actualCost float64 - if err := rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil { + if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil { return nil, err } t, _ := time.Parse("2006-01-02", date) @@ -1163,7 +1235,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ActualCost: actualCost, }) } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } @@ -1261,11 +1333,12 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID models = []ModelStat{} } - return &AccountUsageStatsResponse{ + resp = &AccountUsageStatsResponse{ History: history, Summary: summary, Models: models, - }, nil + } + return resp, nil } func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { @@ -1286,22 +1359,30 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh return logs, paginationResultFromTotal(total, params), nil } -func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) ([]service.UsageLog, error) { +func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer func() { _ = rows.Close() }() + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + logs = nil + } + }() - logs := make([]service.UsageLog, 0) + logs = make([]service.UsageLog, 0) for rows.Next() { - log, err := scanUsageLog(rows) + var log *service.UsageLog + log, err = scanUsageLog(rows) if err != nil { return nil, err } logs = append(logs, *log) } - if err := rows.Err(); err != nil { + if err = rows.Err(); err != nil { return nil, err } return logs, nil