fix(仓储): 修复软删除过滤与事务测试

修复软删除拦截器使用错误,确保默认查询过滤已删记录
仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题
同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
yangjianbo
2025-12-29 19:23:49 +08:00
parent b436da7249
commit ae191f72a4
20 changed files with 565 additions and 326 deletions

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"errors"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -10,26 +11,16 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
)
type sqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type sqlBeginner interface {
sqlExecutor
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type groupRepository struct {
client *dbent.Client
sql sqlExecutor
begin sqlBeginner
}
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
}
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
var beginner sqlBeginner
if b, ok := sqlq.(sqlBeginner); ok {
beginner = b
}
return &groupRepository{client: client, sql: sqlq, begin: beginner}
return &groupRepository{client: client, sql: sqlq}
}
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
}
groupSvc := groupEntityToService(g)
exec := r.sql
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
// 同时保证级联删除的原子性。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
exec := r.client
txClient := r.client
var sqlTx *sql.Tx
if r.begin != nil {
sqlTx, err = r.begin.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
exec = sqlTx
txClient = entClientFromSQLTx(sqlTx)
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB但实际是 *sql.Tx
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
defer func() { _ = sqlTx.Rollback() }()
if err == nil {
defer func() { _ = tx.Rollback() }()
exec = tx.Client()
txClient = exec
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 参与同一事务。
}
// Lock the group row to avoid concurrent writes while we cascade.
var lockedID int64
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
if errorsIsNoRows(err) {
return nil, service.ErrGroupNotFound
}
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
if err != nil {
return nil, err
}
var lockedID int64
if rows.Next() {
if err := rows.Scan(&lockedID); err != nil {
_ = rows.Close()
return nil, err
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
if lockedID == 0 {
return nil, service.ErrGroupNotFound
}
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if sqlTx != nil {
if err := sqlTx.Commit(); err != nil {
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
return counts, nil
}
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
return dbent.NewClient(dbent.Driver(drv))
}
func errorsIsNoRows(err error) bool {
return err == sql.ErrNoRows
}