fix(auth): harden oauth identity upgrade paths

This commit is contained in:
IanShaw027
2026-04-22 14:56:56 +08:00
parent 3d29f7c2fa
commit 36aed35957
32 changed files with 2365 additions and 262 deletions

View File

@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if userIn == nil {
return nil
}
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
return err
}
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
}
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
return err
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn == nil {
return nil
}
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
return err
}
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
tx, err := r.client.Tx(ctx)
@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中ErrTxStarted复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
return err
}
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(ctx)
updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
matches, err := r.client.User.Query().
return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email)
}
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
client = clientFromContext(ctx, client)
if client == nil {
return nil
}
matches, err := client.User.Query().
Where(userEmailLookupPredicate(email)).
All(ctx)
if err != nil {
@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
}
func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email))
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return dbuser.EmailEQ(email)
}
@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
})
}
func normalizeEmailLookupValue(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
func normalizedEmailUniquenessLockKey(email string) string {
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return ""
}
return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
err := client.UserAllowedGroup.Create().
@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
}
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
switch strings.TrimSpace(strings.ToLower(signupSource)) {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.