fix(profile): stabilize identity binding management

This commit is contained in:
IanShaw027
2026-04-22 13:19:28 +08:00
parent 83cad63ce0
commit 81c827ee51
13 changed files with 584 additions and 39 deletions

View File

@@ -301,17 +301,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
identityRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if identity != nil && identity.UserID != input.UserID {
identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
@@ -346,20 +347,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
channelRecords, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
@@ -397,6 +399,61 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return result, nil
}
func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
for _, record := range records {
if record.UserID == userID {
return record
}
}
return nil
}
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
return record
}
}
return nil
}
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {