fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
@@ -10,26 +11,16 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type sqlBeginner interface {
|
||||
sqlExecutor
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type groupRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
||||
@@ -37,11 +28,7 @@ func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupReposi
|
||||
}
|
||||
|
||||
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &groupRepository{client: client, sql: sqlq, begin: beginner}
|
||||
return &groupRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
||||
@@ -214,7 +201,7 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := r.sql.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", groupID).Scan(&count); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
@@ -236,31 +223,44 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
}
|
||||
groupSvc := groupEntityToService(g)
|
||||
|
||||
exec := r.sql
|
||||
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
|
||||
// 同时保证级联删除的原子性。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, err
|
||||
}
|
||||
exec := r.client
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
|
||||
if r.begin != nil {
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
exec = tx.Client()
|
||||
txClient = exec
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 参与同一事务。
|
||||
}
|
||||
|
||||
// Lock the group row to avoid concurrent writes while we cascade.
|
||||
var lockedID int64
|
||||
if err := exec.QueryRowContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id).Scan(&lockedID); err != nil {
|
||||
if errorsIsNoRows(err) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分“未找到”与其他错误。
|
||||
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lockedID int64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&lockedID); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lockedID == 0 {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if groupSvc.IsSubscriptionType() {
|
||||
@@ -319,8 +319,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -359,11 +359,6 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func entClientFromSQLTx(tx *sql.Tx) *dbent.Client {
|
||||
drv := entsql.NewDriver(dialect.Postgres, entsql.Conn{ExecQuerier: tx})
|
||||
return dbent.NewClient(dbent.Driver(drv))
|
||||
}
|
||||
|
||||
func errorsIsNoRows(err error) bool {
|
||||
return err == sql.ErrNoRows
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user