fix(profile): stabilize identity binding management
This commit is contained in:
@@ -301,17 +301,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
canonical := input.Canonical
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
identityRecords, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
|
||||
authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
|
||||
).
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identity != nil && identity.UserID != input.UserID {
|
||||
identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
|
||||
if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
|
||||
return ErrAuthIdentityOwnershipConflict
|
||||
}
|
||||
if identity == nil {
|
||||
@@ -346,20 +347,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
|
||||
var channel *dbent.AuthIdentityChannel
|
||||
if input.Channel != nil {
|
||||
channel, err = client.AuthIdentityChannel.Query().
|
||||
channelRecords, err := client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
|
||||
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
|
||||
authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
|
||||
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
|
||||
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
|
||||
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
|
||||
).
|
||||
WithIdentity().
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
All(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
|
||||
channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
|
||||
if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
|
||||
return ErrAuthIdentityChannelOwnershipConflict
|
||||
}
|
||||
if channel == nil {
|
||||
@@ -397,6 +399,61 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerKey == "" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
if providerType != "wechat" {
|
||||
return []string{providerKey}
|
||||
}
|
||||
keys := []string{providerKey}
|
||||
if !strings.EqualFold(providerKey, "wechat-main") {
|
||||
keys = append(keys, "wechat-main")
|
||||
}
|
||||
if !strings.EqualFold(providerKey, "wechat") {
|
||||
keys = append(keys, "wechat")
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||
for _, record := range records {
|
||||
if record.UserID == userID {
|
||||
return record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
|
||||
return record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
|
||||
Reference in New Issue
Block a user