diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 25d3f1d6..195776a3 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error return nil } -func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error { - return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write") -} - -func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error { - return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write") -} - func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error { client = clientFromContext(ctx, client) if client == nil || userID <= 0 { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 10b85f76..ce1c1a77 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } - if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil { - return nil, fmt.Errorf("sync email auth identity: %w", err) - } s.assignDefaultSubscriptions(ctx, user.ID) return user, nil } @@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda oldConcurrency := user.Concurrency oldStatus := user.Status oldRole := user.Role - oldEmail := user.Email if input.Email != "" { user.Email = input.Email @@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } - if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil { - return nil, fmt.Errorf("sync email auth identity: %w", err) - } // 同步用户专属分组倍率 if input.GroupRates != nil && s.userGroupRateRepo != nil { diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go index d555d609..d6a7af9a 100644 --- a/backend/internal/service/admin_service_email_identity_sync_test.go +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -31,6 +31,8 @@ type emailSyncRepoStub struct { updated []*User ensureCalls []ensureEmailCall replaceCalls []replaceEmailCall + ensureErr error + replaceErr error } func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error { @@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error { s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email}) - return nil + return s.ensureErr } func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error { @@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i oldEmail: oldEmail, newEmail: newEmail, }) - return nil + return s.replaceErr } -func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) { - repo := &emailSyncRepoStub{nextID: 55} +func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { + repo := &emailSyncRepoStub{ + nextID: 55, + ensureErr: fmt.Errorf("unexpected email resync"), + } svc := &adminServiceImpl{userRepo: repo} user, err := svc.CreateUser(context.Background(), &CreateUserInput{ @@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) { }) require.NoError(t, err) require.NotNil(t, user) - require.Equal(t, []ensureEmailCall{{ - userID: 55, - email: "admin-created@example.com", - }}, repo.ensureCalls) + require.Equal(t, int64(55), user.ID) + require.Empty(t, repo.ensureCalls) require.Empty(t, repo.replaceCalls) } -func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) { +func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { repo := &emailSyncRepoStub{ user: &User{ ID: 91, @@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) { Status: StatusActive, Concurrency: 3, }, + replaceErr: fmt.Errorf("unexpected email resync"), } svc := &adminServiceImpl{userRepo: repo} @@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) { require.NoError(t, err) require.NotNil(t, updated) require.Equal(t, "after@example.com", updated.Email) - require.Equal(t, []replaceEmailCall{{ - userID: 91, - oldEmail: "before@example.com", - newEmail: "after@example.com", - }}, repo.replaceCalls) + require.Empty(t, repo.replaceCalls) require.Empty(t, repo.ensureCalls) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index d0d5e4e3..00fefd82 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig } s.updateUserSignupSource(ctx, user.ID, signupSource) - if signupSource == "email" { - s.ensureEmailAuthIdentity(ctx, user) - } if touchLogin { s.touchUserLogin(ctx, user.ID) } @@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context if s == nil || user == nil || user.ID <= 0 { return } - if s.ensureEmailAuthIdentity(ctx, user) { + identity, created := s.ensureEmailAuthIdentity(ctx, user) + if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) { if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err) } } } -func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool { - if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { +func (s *AuthService) shouldApplyEmailFirstBindDefaults( + ctx context.Context, + userID int64, + identity *dbent.AuthIdentity, + created bool, +) bool { + if created { + return true + } + if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID { return false } + if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" { + return false + } + + hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind") + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err) + return false + } + return !hasGrant +} + +func emailAuthIdentitySource(metadata map[string]any) string { + if len(metadata) == 0 { + return "" + } + raw, ok := metadata["source"] + if !ok { + return "" + } + return strings.TrimSpace(fmt.Sprint(raw)) +} + +func (s *AuthService) hasProviderGrantRecord( + ctx context.Context, + userID int64, + providerType string, + grantReason string, +) (bool, error) { + if s == nil || s.entClient == nil || userID <= 0 { + return false, nil + } + + rows, err := s.entClient.QueryContext( + ctx, + `SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`, + userID, + strings.TrimSpace(providerType), + strings.TrimSpace(grantReason), + ) + if err != nil { + return false, err + } + defer rows.Close() + return rows.Next(), rows.Err() +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return nil, false + } email := strings.ToLower(strings.TrimSpace(user.Email)) if email == "" || isReservedEmail(email) { - return false + return nil, false } client := s.entClient @@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b existed, err := buildQuery().Exist(ctx) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) - return false + return nil, false } if !existed { @@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b DoNothing(). Exec(ctx); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) - return false + return nil, false } } identity, err := buildQuery().Only(ctx) if err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) - return false + return nil, false } if identity.UserID != user.ID { logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID) - return false + return nil, false } - return !existed + return identity, !existed } func inferLegacySignupSource(email string) string { diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index e2a94b13..95c9c933 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -5,6 +5,7 @@ package service_test import ( "context" "database/sql" + "errors" "testing" "time" @@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil } +type flakyAuthIdentityDefaultSubAssignerStub struct { + failuresRemaining int + calls []*service.AssignSubscriptionInput +} + +func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + if s.failuresRemaining > 0 { + s.failuresRemaining-- + return nil, false, errors.New("temporary assign failure") + } + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + type authIdentitySettingRepoStub struct { values map[string]string } @@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) } +func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) { + assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("retry-first-bind@example.com"). + SetUsername("retry-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 1.5, storedUser.Balance) + require.Equal(t, 2, storedUser.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 10.0, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Len(t, assigner.calls, 2) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + func countProviderGrantRecords( t *testing.T, client *dbent.Client, diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index cd1bc2bb..7c2ca2d0 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -161,33 +161,6 @@ type userAuthIdentityReader interface { ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) } -type emailAuthIdentitySynchronizer interface { - EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error - ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error -} - -func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error { - syncer, ok := repo.(emailAuthIdentitySynchronizer) - if !ok { - return nil - } - return syncer.EnsureEmailAuthIdentity(ctx, userID, email) -} - -func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error { - oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail)) - newNormalized := strings.ToLower(strings.TrimSpace(newEmail)) - if oldNormalized == newNormalized { - return nil - } - - syncer, ok := repo.(emailAuthIdentitySynchronizer) - if !ok { - return nil - } - return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail) -} - // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return nil, fmt.Errorf("get user: %w", err) } oldConcurrency := user.Concurrency - oldEmail := user.Email // 更新字段 if req.Email != nil { @@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } - if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil { - return nil, fmt.Errorf("sync email auth identity: %w", err) - } if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go index 8109b368..702b3b1a 100644 --- a/backend/internal/service/user_service_email_identity_sync_test.go +++ b/backend/internal/service/user_service_email_identity_sync_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) { +func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) { repo := &emailSyncRepoStub{ user: &User{ ID: 19, @@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) { Username: "tester", Concurrency: 2, }, + replaceErr: context.DeadlineExceeded, } svc := NewUserService(repo, nil, nil, nil) @@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) { require.NotNil(t, updated) require.Equal(t, newEmail, updated.Email) require.Equal(t, 1, repo.updateCalls) - require.Equal(t, []replaceEmailCall{{ - userID: 19, - oldEmail: "profile-before@example.com", - newEmail: "profile-after@example.com", - }}, repo.replaceCalls) + require.Empty(t, repo.replaceCalls) require.Empty(t, repo.ensureCalls) }