fix(profile): stabilize identity binding management
This commit is contained in:
@@ -301,17 +301,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
canonical := input.Canonical
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
identityRecords, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
|
||||
authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
|
||||
).
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identity != nil && identity.UserID != input.UserID {
|
||||
identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
|
||||
if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
|
||||
return ErrAuthIdentityOwnershipConflict
|
||||
}
|
||||
if identity == nil {
|
||||
@@ -346,20 +347,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
|
||||
var channel *dbent.AuthIdentityChannel
|
||||
if input.Channel != nil {
|
||||
channel, err = client.AuthIdentityChannel.Query().
|
||||
channelRecords, err := client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
|
||||
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
|
||||
authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
|
||||
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
|
||||
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
|
||||
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
|
||||
).
|
||||
WithIdentity().
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
|
||||
channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
|
||||
if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
|
||||
return ErrAuthIdentityChannelOwnershipConflict
|
||||
}
|
||||
if channel == nil {
|
||||
@@ -397,6 +399,61 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerKey == "" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
if providerType != "wechat" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
keys := []string{providerKey}
|
||||
if !strings.EqualFold(providerKey, "wechat-main") {
|
||||
keys = append(keys, "wechat-main")
|
||||
}
|
||||
if !strings.EqualFold(providerKey, "wechat") {
|
||||
keys = append(keys, "wechat")
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||
for _, record := range records {
|
||||
if record.UserID == userID {
|
||||
return record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
|
||||
return record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
|
||||
@@ -186,6 +186,79 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
|
||||
user := s.mustCreateUser("wechat-legacy-alias")
|
||||
|
||||
legacyIdentity, err := s.client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat").
|
||||
SetProviderSubject("union-legacy-123").
|
||||
SetMetadata(map[string]any{"source": "legacy-alias"}).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
legacyChannel, err := s.client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(legacyIdentity.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat").
|
||||
SetChannel("oa").
|
||||
SetChannelAppID("wx-app-legacy").
|
||||
SetChannelSubject("openid-legacy-123").
|
||||
SetMetadata(map[string]any{"scene": "legacy-alias"}).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "union-legacy-123",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
Channel: "oa",
|
||||
ChannelAppID: "wx-app-legacy",
|
||||
ChannelSubject: "openid-legacy-123",
|
||||
},
|
||||
Metadata: map[string]any{"source": "canonical-bind"},
|
||||
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(bound)
|
||||
s.Require().NotNil(bound.Identity)
|
||||
s.Require().NotNil(bound.Channel)
|
||||
s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
|
||||
s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
|
||||
s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
|
||||
s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
|
||||
s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
|
||||
s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
|
||||
|
||||
identityCount, err := s.client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
authidentity.ProviderTypeEQ("wechat"),
|
||||
authidentity.ProviderSubjectEQ("union-legacy-123"),
|
||||
).
|
||||
Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, identityCount)
|
||||
|
||||
channelCount, err := s.client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||
authidentitychannel.ChannelEQ("oa"),
|
||||
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
|
||||
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
|
||||
).
|
||||
Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, channelCount)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
|
||||
user := s.mustCreateUser("provider-mismatch-create")
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
|
||||
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
|
||||
@@ -146,6 +149,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
if userIn == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
@@ -704,6 +710,21 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
|
||||
matches, err := r.client.User.Query().
|
||||
Where(userEmailLookupPredicate(email)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, match := range matches {
|
||||
if match.ID != userID {
|
||||
return service.ErrEmailExists
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func userEmailLookupPredicate(email string) predicate.User {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" {
|
||||
|
||||
@@ -67,3 +67,80 @@ func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
|
||||
repo, _ := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := repo.Create(ctx, &service.User{
|
||||
Email: " Existing@Example.com ",
|
||||
Username: "existing-user",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.Create(ctx, &service.User{
|
||||
Email: "existing@example.com",
|
||||
Username: "duplicate-user",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
|
||||
repo, _ := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
first := &service.User{
|
||||
Email: " Existing@Example.com ",
|
||||
Username: "existing-user",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, first))
|
||||
|
||||
second := &service.User{
|
||||
Email: "second@example.com",
|
||||
Username: "second-user",
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, second))
|
||||
|
||||
second.Email = " existing@example.com "
|
||||
err := repo.Update(ctx, second)
|
||||
require.ErrorIs(t, err, service.ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
|
||||
repo, client := newUserEntRepo(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.User.Create().
|
||||
SetEmail("Conflict@Example.com").
|
||||
SetUsername("conflict-user-1").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.User.Create().
|
||||
SetEmail(" conflict@example.com ").
|
||||
SetUsername("conflict-user-2").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.GetByEmail(ctx, "conflict@example.com")
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user