fix(仓储): 规范 rows.Close 错误回传
统一 usage_log_repo 查询的 Close 错误处理,避免\n成功路径吞掉关闭失败 scanSingleRow 使用 errors.Join 合并 Close 错误,\n保留 ErrNoRows 可判定 测试: make -C backend test-unit
This commit is contained in:
@@ -3,14 +3,16 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type sqlQueryer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// scanSingleRow executes a query and scans the first row into dest.
|
||||
// If no rows are returned, sql.ErrNoRows is returned.
|
||||
// scanSingleRow 执行查询并扫描第一行到 dest。
|
||||
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
|
||||
// 如果 Close 失败,会与原始错误合并返回。
|
||||
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
|
||||
@@ -19,8 +21,8 @@ func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any,
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
if closeErr := rows.Close(); closeErr != nil {
|
||||
err = errors.Join(err, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -148,24 +148,31 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
log = nil
|
||||
}
|
||||
}()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrUsageLogNotFound
|
||||
}
|
||||
log, err := scanUsageLog(rows)
|
||||
log, err = scanUsageLog(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return log, nil
|
||||
@@ -535,7 +542,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||
|
||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -568,17 +575,24 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results := make([]ApiKeyUsageTrendPoint, 0)
|
||||
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row ApiKeyUsageTrendPoint
|
||||
if err := rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -586,7 +600,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
||||
}
|
||||
|
||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -621,17 +635,24 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
results := make([]UserUsageTrendPoint, 0)
|
||||
results = make([]UserUsageTrendPoint, 0)
|
||||
for rows.Next() {
|
||||
var row UserUsageTrendPoint
|
||||
if err := rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -740,7 +761,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
}
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -766,13 +787,24 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
return scanTrendRows(rows)
|
||||
results, err = scanTrendRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserModelStats 获取指定用户的模型统计
|
||||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
|
||||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
|
||||
query := `
|
||||
SELECT
|
||||
model,
|
||||
@@ -792,9 +824,20 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
return scanModelStatsRows(rows)
|
||||
results, err = scanModelStatsRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
@@ -994,7 +1037,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -1029,13 +1072,24 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
return scanTrendRows(rows)
|
||||
results, err = scanTrendRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
||||
query := `
|
||||
SELECT
|
||||
model,
|
||||
@@ -1068,9 +1122,20 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
results = nil
|
||||
}
|
||||
}()
|
||||
|
||||
return scanModelStatsRows(rows)
|
||||
results, err = scanModelStatsRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetGlobalStats gets usage statistics for all users within a time range
|
||||
@@ -1118,7 +1183,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||
|
||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||
if daysCount <= 0 {
|
||||
daysCount = 30
|
||||
@@ -1141,7 +1206,14 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
resp = nil
|
||||
}
|
||||
}()
|
||||
|
||||
history := make([]AccountUsageHistory, 0)
|
||||
for rows.Next() {
|
||||
@@ -1150,7 +1222,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
var tokens int64
|
||||
var cost float64
|
||||
var actualCost float64
|
||||
if err := rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
||||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t, _ := time.Parse("2006-01-02", date)
|
||||
@@ -1163,7 +1235,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
ActualCost: actualCost,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1261,11 +1333,12 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
models = []ModelStat{}
|
||||
}
|
||||
|
||||
return &AccountUsageStatsResponse{
|
||||
resp = &AccountUsageStatsResponse{
|
||||
History: history,
|
||||
Summary: summary,
|
||||
Models: models,
|
||||
}, nil
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
@@ -1286,22 +1359,30 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) ([]service.UsageLog, error) {
|
||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
defer func() {
|
||||
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||
// 同时清空返回值,避免误用不完整结果。
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
logs = nil
|
||||
}
|
||||
}()
|
||||
|
||||
logs := make([]service.UsageLog, 0)
|
||||
logs = make([]service.UsageLog, 0)
|
||||
for rows.Next() {
|
||||
log, err := scanUsageLog(rows)
|
||||
var log *service.UsageLog
|
||||
log, err = scanUsageLog(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logs = append(logs, *log)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logs, nil
|
||||
|
||||
Reference in New Issue
Block a user