fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录 仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题 同步更新集成测试以覆盖事务与统计变动
This commit is contained in:
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -17,7 +18,6 @@ import (
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
begin sqlBeginner
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
@@ -25,11 +25,7 @@ func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserReposito
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
var beginner sqlBeginner
|
||||
if b, ok := sqlq.(sqlBeginner); ok {
|
||||
beginner = b
|
||||
}
|
||||
return &userRepository{client: client, sql: sqlq, begin: beginner}
|
||||
return &userRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
@@ -37,22 +33,20 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
@@ -70,12 +64,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, created.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -121,22 +115,19 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := r.sql
|
||||
txClient := r.client
|
||||
var sqlTx *sql.Tx
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.begin != nil {
|
||||
var err error
|
||||
sqlTx, err = r.begin.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec = sqlTx
|
||||
txClient = entClientFromSQLTx(sqlTx)
|
||||
// 注意:不能调用 txClient.Close(),因为基于事务的 ent client
|
||||
// 在 Close() 时会尝试将 ExecQuerier 断言为 *sql.DB,但实际是 *sql.Tx
|
||||
// 事务的清理通过 sqlTx.Rollback() 和 sqlTx.Commit() 完成
|
||||
defer func() { _ = sqlTx.Rollback() }()
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
||||
@@ -154,12 +145,12 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroups(ctx, txClient, exec, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlTx != nil {
|
||||
if err := sqlTx.Commit(); err != nil {
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -289,8 +280,10 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
if r.sql == nil {
|
||||
return 0, nil
|
||||
exec := r.sql
|
||||
if exec == nil {
|
||||
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext(支持事务)。
|
||||
exec = r.client
|
||||
}
|
||||
|
||||
joinAffected, err := r.client.UserAllowedGroup.Delete().
|
||||
@@ -300,7 +293,7 @@ func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
|
||||
return 0, err
|
||||
}
|
||||
|
||||
arrayRes, err := r.sql.ExecContext(
|
||||
arrayRes, err := exec.ExecContext(
|
||||
ctx,
|
||||
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
|
||||
groupID,
|
||||
@@ -362,6 +355,56 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
|
||||
// 2) 额外更新 users.allowed_groups(历史字段)以保持兼容。
|
||||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
legacyGroups := make([]int64, 0, len(unique))
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
legacyGroups = append(legacyGroups, groupID)
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 1 兼容:保持 users.allowed_groups(数组字段)同步,避免旧查询路径读取到过期数据。
|
||||
var legacy any
|
||||
if len(legacyGroups) > 0 {
|
||||
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
|
||||
legacy = pq.Array(legacyGroups)
|
||||
}
|
||||
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) syncUserAllowedGroups(ctx context.Context, client *dbent.Client, exec sqlExecutor, userID int64, groupIDs []int64) error {
|
||||
if client == nil || exec == nil {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user