fix(auth): harden oauth identity upgrade paths
This commit is contained in:
@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
}
|
||||
|
||||
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
txClient,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseEmailLock()
|
||||
|
||||
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
created, err := txClient.User.Create().
|
||||
SetEmail(userIn.Email).
|
||||
SetUsername(userIn.Username).
|
||||
@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
Save(ctx)
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||||
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else {
|
||||
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
|
||||
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
txClient = r.client
|
||||
}
|
||||
}
|
||||
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
|
||||
|
||||
releaseEmailLock, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
txClient,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
normalizedEmailUniquenessLockKey(userIn.Email),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseEmailLock()
|
||||
|
||||
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
if userIn.BalanceNotifyThreshold == nil {
|
||||
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||||
}
|
||||
updated, err := updateOp.Save(ctx)
|
||||
updated, err := updateOp.Save(txCtx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||||
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
}
|
||||
|
||||
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
|
||||
matches, err := r.client.User.Query().
|
||||
return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email)
|
||||
}
|
||||
|
||||
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
|
||||
client = clientFromContext(ctx, client)
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
matches, err := client.User.Query().
|
||||
Where(userEmailLookupPredicate(email)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
|
||||
}
|
||||
|
||||
func userEmailLookupPredicate(email string) predicate.User {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
normalized := normalizeEmailLookupValue(email)
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeEmailLookupValue(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
func normalizedEmailUniquenessLockKey(email string) string {
|
||||
normalized := normalizeEmailLookupValue(email)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return "users:normalized-email:" + normalized
|
||||
}
|
||||
|
||||
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
err := client.UserAllowedGroup.Create().
|
||||
@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
}
|
||||
|
||||
func userSignupSourceOrDefault(signupSource string) string {
|
||||
signupSource = strings.TrimSpace(signupSource)
|
||||
if signupSource == "" {
|
||||
switch strings.TrimSpace(strings.ToLower(signupSource)) {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
return strings.TrimSpace(strings.ToLower(signupSource))
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
return signupSource
|
||||
}
|
||||
|
||||
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||||
|
||||
Reference in New Issue
Block a user