diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index dbba364d..25990c52 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -9,10 +9,12 @@ import ( "time" "unsafe" + entsql "entgo.io/ent/dialect/sql" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -379,6 +381,24 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { client := clientFromContext(ctx, r.client) + if input.IdentityID != nil && *input.IdentityID > 0 { + if _, err := client.IdentityAdoptionDecision.Update(). + Where( + identityadoptiondecision.IdentityIDEQ(*input.IdentityID), + dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { + col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.NEQ(col, input.PendingAuthSessionID), + )) + }), + ). + ClearIdentityID(). + Save(ctx); err != nil { + return nil, err + } + } + current, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). Only(ctx) diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go index d24a7d83..c5d0d897 100644 --- a/backend/internal/repository/user_profile_identity_repo_contract_test.go +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -353,6 +353,45 @@ func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_Persis s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) } +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() { + user := s.mustCreateUser("adoption-reassign") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption-reassign", + }, + }) + s.Require().NoError(err) + + firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: firstSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().NotNil(firstDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID) + + secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef()) + secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: secondSession.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: false, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().NotNil(secondDecision.IdentityID) + s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID) + + reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID) + s.Require().NoError(err) + s.Require().Nil(reloadedFirst.IdentityID) +} + func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { user := s.mustCreateUser("avatar")