fix(仓储): 规范 rows.Close 错误回传
统一 usage_log_repo 查询的 Close 错误处理,避免\n成功路径吞掉关闭失败 scanSingleRow 使用 errors.Join 合并 Close 错误,\n保留 ErrNoRows 可判定 测试: make -C backend test-unit
This commit is contained in:
@@ -3,14 +3,16 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlQueryer interface {
|
type sqlQueryer interface {
|
||||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// scanSingleRow executes a query and scans the first row into dest.
|
// scanSingleRow 执行查询并扫描第一行到 dest。
|
||||||
// If no rows are returned, sql.ErrNoRows is returned.
|
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
|
||||||
|
// 如果 Close 失败,会与原始错误合并返回。
|
||||||
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||||
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
|
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
|
||||||
@@ -19,8 +21,8 @@ func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
if closeErr := rows.Close(); closeErr != nil {
|
||||||
err = closeErr
|
err = errors.Join(err, closeErr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -148,24 +148,31 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
log = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, service.ErrUsageLogNotFound
|
return nil, service.ErrUsageLogNotFound
|
||||||
}
|
}
|
||||||
log, err := scanUsageLog(rows)
|
log, err = scanUsageLog(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return log, nil
|
return log, nil
|
||||||
@@ -535,7 +542,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
|
|||||||
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
|
||||||
|
|
||||||
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
|
||||||
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
|
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []ApiKeyUsageTrendPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -568,17 +575,24 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
results := make([]ApiKeyUsageTrendPoint, 0)
|
results = make([]ApiKeyUsageTrendPoint, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row ApiKeyUsageTrendPoint
|
var row ApiKeyUsageTrendPoint
|
||||||
if err := rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
if err = rows.Scan(&row.Date, &row.ApiKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
results = append(results, row)
|
results = append(results, row)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,7 +600,7 @@ func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
|
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -621,17 +635,24 @@ func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
results := make([]UserUsageTrendPoint, 0)
|
results = make([]UserUsageTrendPoint, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row UserUsageTrendPoint
|
var row UserUsageTrendPoint
|
||||||
if err := rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
results = append(results, row)
|
results = append(results, row)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -740,7 +761,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
|
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -766,13 +787,24 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return scanTrendRows(rows)
|
results, err = scanTrendRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserModelStats 获取指定用户的模型统计
|
// GetUserModelStats 获取指定用户的模型统计
|
||||||
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
|
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT
|
SELECT
|
||||||
model,
|
model,
|
||||||
@@ -792,9 +824,20 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return scanModelStatsRows(rows)
|
results, err = scanModelStatsRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UsageLogFilters represents filters for usage log queries
|
// UsageLogFilters represents filters for usage log queries
|
||||||
@@ -994,7 +1037,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
|
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := "YYYY-MM-DD"
|
||||||
if granularity == "hour" {
|
if granularity == "hour" {
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
dateFormat = "YYYY-MM-DD HH24:00"
|
||||||
@@ -1029,13 +1072,24 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return scanTrendRows(rows)
|
results, err = scanTrendRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
|
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT
|
SELECT
|
||||||
model,
|
model,
|
||||||
@@ -1068,9 +1122,20 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
results = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return scanModelStatsRows(rows)
|
results, err = scanModelStatsRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGlobalStats gets usage statistics for all users within a time range
|
// GetGlobalStats gets usage statistics for all users within a time range
|
||||||
@@ -1118,7 +1183,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
|
|||||||
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
|
||||||
|
|
||||||
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
|
||||||
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
|
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (resp *AccountUsageStatsResponse, err error) {
|
||||||
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
|
||||||
if daysCount <= 0 {
|
if daysCount <= 0 {
|
||||||
daysCount = 30
|
daysCount = 30
|
||||||
@@ -1141,7 +1206,14 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
resp = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
history := make([]AccountUsageHistory, 0)
|
history := make([]AccountUsageHistory, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -1150,7 +1222,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
var tokens int64
|
var tokens int64
|
||||||
var cost float64
|
var cost float64
|
||||||
var actualCost float64
|
var actualCost float64
|
||||||
if err := rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t, _ := time.Parse("2006-01-02", date)
|
t, _ := time.Parse("2006-01-02", date)
|
||||||
@@ -1163,7 +1235,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
ActualCost: actualCost,
|
ActualCost: actualCost,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1261,11 +1333,12 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
models = []ModelStat{}
|
models = []ModelStat{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AccountUsageStatsResponse{
|
resp = &AccountUsageStatsResponse{
|
||||||
History: history,
|
History: history,
|
||||||
Summary: summary,
|
Summary: summary,
|
||||||
Models: models,
|
Models: models,
|
||||||
}, nil
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
@@ -1286,22 +1359,30 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
|||||||
return logs, paginationResultFromTotal(total, params), nil
|
return logs, paginationResultFromTotal(total, params), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) ([]service.UsageLog, error) {
|
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() {
|
||||||
|
// 保持主错误优先;仅在无错误时回传 Close 失败。
|
||||||
|
// 同时清空返回值,避免误用不完整结果。
|
||||||
|
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
logs = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
logs := make([]service.UsageLog, 0)
|
logs = make([]service.UsageLog, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
log, err := scanUsageLog(rows)
|
var log *service.UsageLog
|
||||||
|
log, err = scanUsageLog(rows)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logs = append(logs, *log)
|
logs = append(logs, *log)
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err = rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return logs, nil
|
return logs, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user