feat(auth): reclaim stale identities and refresh profile UI
This commit is contained in:
@@ -644,15 +644,17 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client,
|
||||
}
|
||||
|
||||
func userNormalizedEmailPredicate(email string) predicate.User {
|
||||
normalized := strings.TrimSpace(email)
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
return predicate.User(func(s *entsql.Selector) {
|
||||
s.Where(entsql.ExprP(
|
||||
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||||
normalized,
|
||||
))
|
||||
s.Where(entsql.P(func(b *entsql.Builder) {
|
||||
b.WriteString("LOWER(TRIM(").
|
||||
Ident(s.C(dbuser.FieldEmail)).
|
||||
WriteString(")) = ").
|
||||
Arg(normalized)
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -718,7 +720,16 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
|
||||
}
|
||||
if identity != nil {
|
||||
if identity.UserID != userID {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if activeOwner != nil {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
return client.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetUserID(userID).
|
||||
Save(ctx)
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
@@ -756,7 +767,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(identityRecords, userID, providerKey)
|
||||
identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,7 +784,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(legacyOpenIDRecords, userID, providerKey)
|
||||
legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -783,6 +794,9 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
|
||||
case identity != nil:
|
||||
update := client.AuthIdentity.UpdateOneID(identity.ID).
|
||||
SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
|
||||
if identity.UserID != userID {
|
||||
update = update.SetUserID(userID)
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
|
||||
update = update.SetProviderKey(providerKey)
|
||||
}
|
||||
@@ -838,7 +852,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(channelRecords, userID, providerKey)
|
||||
channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -872,7 +886,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx,
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
|
||||
func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
|
||||
var preferred *dbent.AuthIdentity
|
||||
var fallback *dbent.AuthIdentity
|
||||
hasCanonicalKey := false
|
||||
@@ -881,7 +895,13 @@ func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, pr
|
||||
continue
|
||||
}
|
||||
if record.UserID != userID {
|
||||
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
activeOwner, err := findActiveUserByID(ctx, client, record.UserID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if activeOwner != nil {
|
||||
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
|
||||
hasCanonicalKey = true
|
||||
@@ -900,7 +920,7 @@ func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, pr
|
||||
return fallback, hasCanonicalKey, nil
|
||||
}
|
||||
|
||||
func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
|
||||
func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
|
||||
var preferred *dbent.AuthIdentityChannel
|
||||
var fallback *dbent.AuthIdentityChannel
|
||||
hasCanonicalKey := false
|
||||
@@ -909,7 +929,13 @@ func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int
|
||||
continue
|
||||
}
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
|
||||
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||
activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if activeOwner != nil {
|
||||
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
|
||||
hasCanonicalKey = true
|
||||
@@ -928,6 +954,20 @@ func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int
|
||||
return fallback, hasCanonicalKey, nil
|
||||
}
|
||||
|
||||
func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) {
|
||||
if client == nil || userID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
userEntity, err := client.User.Get(ctx, userID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
|
||||
}
|
||||
return userEntity, nil
|
||||
}
|
||||
|
||||
func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
|
||||
if channel == nil {
|
||||
return map[string]any{}
|
||||
@@ -1343,7 +1383,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
respondPendingOAuthBindingApplyError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1363,6 +1403,14 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
|
||||
func respondPendingOAuthBindingApplyError(c *gin.Context, err error) {
|
||||
if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
|
||||
var req createPendingOAuthAccountRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -1480,7 +1528,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
respondPendingOAuthBindingApplyError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1514,7 +1562,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
respondPendingOAuthBindingApplyError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1358,6 +1358,80 @@ func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *test
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
oldOwnerHash, err := handler.authService.HashPassword("old-secret")
|
||||
require.NoError(t, err)
|
||||
oldOwner, err := client.User.Create().
|
||||
SetEmail("old-owner@example.com").
|
||||
SetUsername("old-owner").
|
||||
SetPasswordHash(oldOwnerHash).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(oldOwner.ID).
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-bind-soft-deleted-123").
|
||||
SetMetadata(map[string]any{"username": "old-owner"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
newOwnerHash, err := handler.authService.HashPassword("secret-123")
|
||||
require.NoError(t, err)
|
||||
newOwner, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash(newOwnerHash).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("bind-login-soft-deleted-owner-session-token").
|
||||
SetIntent("adopt_existing_user_by_email").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-bind-soft-deleted-123").
|
||||
SetTargetUserID(newOwner.ID).
|
||||
SetResolvedEmail(newOwner.Email).
|
||||
SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"suggested_display_name": "Recovered OIDC User",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.BindOIDCOAuthLogin(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
identity, err = client.AuthIdentity.Get(ctx, identity.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newOwner.ID, identity.UserID)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
|
||||
defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
@@ -292,13 +294,57 @@ func normalizeEmailAuthIdentitySubject(email string) string {
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
var txClient *dbent.Client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txClient = tx.Client()
|
||||
} else {
|
||||
txClient = r.client
|
||||
}
|
||||
|
||||
identityIDs, err := txClient.AuthIdentity.Query().
|
||||
Where(authidentity.UserIDEQ(id)).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if len(identityIDs) > 0 {
|
||||
if _, err := txClient.IdentityAdoptionDecision.Update().
|
||||
Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if _, err := txClient.AuthIdentityChannel.Delete().
|
||||
Where(authidentitychannel.IdentityIDIn(identityIDs...)).
|
||||
Exec(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if _, err := txClient.AuthIdentity.Delete().
|
||||
Where(authidentity.UserIDEQ(id)).
|
||||
Exec(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
}
|
||||
|
||||
affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -645,15 +691,17 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
|
||||
}
|
||||
|
||||
func userEmailLookupPredicate(email string) predicate.User {
|
||||
normalized := strings.TrimSpace(email)
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" {
|
||||
return dbuser.EmailEQ(email)
|
||||
}
|
||||
return predicate.User(func(s *entsql.Selector) {
|
||||
s.Where(entsql.ExprP(
|
||||
fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)),
|
||||
normalized,
|
||||
))
|
||||
s.Where(entsql.P(func(b *entsql.Builder) {
|
||||
b.WriteString("LOWER(TRIM(").
|
||||
Ident(s.C(dbuser.FieldEmail)).
|
||||
WriteString(")) = ").
|
||||
Arg(normalized)
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -124,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() {
|
||||
s.Require().Equal(user.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() {
|
||||
user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ")
|
||||
s.Require().NoError(err, "GetByEmail normalized lookup")
|
||||
s.Require().Equal(user.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
|
||||
s.Require().Error(err, "expected error for non-existent email")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() {
|
||||
s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ")
|
||||
s.Require().NoError(err, "ExistsByEmail normalized lookup")
|
||||
s.Require().True(exists)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
|
||||
|
||||
@@ -152,6 +170,39 @@ func (s *UserRepoSuite) TestDelete() {
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() {
|
||||
user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"})
|
||||
|
||||
identity, err := s.client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("linuxdo").
|
||||
SetProviderKey("linuxdo").
|
||||
SetProviderSubject("delete-oauth-subject").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
_, err = s.client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat").
|
||||
SetChannel("open").
|
||||
SetChannelAppID("app-id").
|
||||
SetChannelSubject("openid-123").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(identityCount)
|
||||
|
||||
channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(channelCount)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
|
||||
@@ -11,8 +11,11 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -271,6 +274,24 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := s.entClient.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||
))
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
existing, err := s.entClient.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
@@ -192,6 +193,139 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
|
||||
require.True(t, second.AdoptAvatar)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("adoption-reassign@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(RoleUser).
|
||||
SetStatus(StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-open").
|
||||
SetProviderSubject("union-reassign").
|
||||
SetMetadata(map[string]any{}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-reassign",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: firstSession.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, firstDecision.IdentityID)
|
||||
require.Equal(t, identity.ID, *firstDecision.IdentityID)
|
||||
|
||||
secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-reassign",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: secondSession.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: false,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, secondDecision.IdentityID)
|
||||
require.Equal(t, identity.ID, *secondDecision.IdentityID)
|
||||
|
||||
reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, reloadedFirst.IdentityID)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
|
||||
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
|
||||
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("legacy-null-session@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(RoleUser).
|
||||
SetStatus(StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-main").
|
||||
SetProviderSubject("legacy-null-session").
|
||||
SetMetadata(map[string]any{}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO identity_adoption_decisions
|
||||
(identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, NULL)`,
|
||||
identity.ID,
|
||||
true,
|
||||
false,
|
||||
time.Now().UTC(),
|
||||
time.Now().UTC(),
|
||||
time.Now().UTC(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
legacyDecision, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.IdentityIDEQ(identity.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, legacyDecision.IdentityID)
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "legacy-null-session",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: false,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
|
||||
reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, reloadedLegacy.IdentityID)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
|
||||
svc, _ := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
Reference in New Issue
Block a user