fix: unify email identity sync and retry first-bind defaults
This commit is contained in:
@@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
|
||||
return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
|
||||
}
|
||||
|
||||
func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
|
||||
return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
|
||||
}
|
||||
|
||||
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
|
||||
client = clientFromContext(ctx, client)
|
||||
if client == nil || userID <= 0 {
|
||||
|
||||
@@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
|
||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
||||
}
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
return user, nil
|
||||
}
|
||||
@@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
oldConcurrency := user.Concurrency
|
||||
oldStatus := user.Status
|
||||
oldRole := user.Role
|
||||
oldEmail := user.Email
|
||||
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
@@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
|
||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
||||
}
|
||||
|
||||
// 同步用户专属分组倍率
|
||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||
|
||||
@@ -31,6 +31,8 @@ type emailSyncRepoStub struct {
|
||||
updated []*User
|
||||
ensureCalls []ensureEmailCall
|
||||
replaceCalls []replaceEmailCall
|
||||
ensureErr error
|
||||
replaceErr error
|
||||
}
|
||||
|
||||
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
|
||||
@@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n
|
||||
|
||||
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
|
||||
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
|
||||
return nil
|
||||
return s.ensureErr
|
||||
}
|
||||
|
||||
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
|
||||
@@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i
|
||||
oldEmail: oldEmail,
|
||||
newEmail: newEmail,
|
||||
})
|
||||
return nil
|
||||
return s.replaceErr
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
|
||||
repo := &emailSyncRepoStub{nextID: 55}
|
||||
func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||
repo := &emailSyncRepoStub{
|
||||
nextID: 55,
|
||||
ensureErr: fmt.Errorf("unexpected email resync"),
|
||||
}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
@@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, []ensureEmailCall{{
|
||||
userID: 55,
|
||||
email: "admin-created@example.com",
|
||||
}}, repo.ensureCalls)
|
||||
require.Equal(t, int64(55), user.ID)
|
||||
require.Empty(t, repo.ensureCalls)
|
||||
require.Empty(t, repo.replaceCalls)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
||||
func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||
repo := &emailSyncRepoStub{
|
||||
user: &User{
|
||||
ID: 91,
|
||||
@@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
||||
Status: StatusActive,
|
||||
Concurrency: 3,
|
||||
},
|
||||
replaceErr: fmt.Errorf("unexpected email resync"),
|
||||
}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
@@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, "after@example.com", updated.Email)
|
||||
require.Equal(t, []replaceEmailCall{{
|
||||
userID: 91,
|
||||
oldEmail: "before@example.com",
|
||||
newEmail: "after@example.com",
|
||||
}}, repo.replaceCalls)
|
||||
require.Empty(t, repo.replaceCalls)
|
||||
require.Empty(t, repo.ensureCalls)
|
||||
}
|
||||
|
||||
@@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
|
||||
}
|
||||
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
||||
|
||||
if signupSource == "email" {
|
||||
s.ensureEmailAuthIdentity(ctx, user)
|
||||
}
|
||||
if touchLogin {
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
}
|
||||
@@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
|
||||
if s == nil || user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
if s.ensureEmailAuthIdentity(ctx, user) {
|
||||
identity, created := s.ensureEmailAuthIdentity(ctx, user)
|
||||
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
|
||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
identity *dbent.AuthIdentity,
|
||||
created bool,
|
||||
) bool {
|
||||
if created {
|
||||
return true
|
||||
}
|
||||
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
|
||||
return false
|
||||
}
|
||||
if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
|
||||
return false
|
||||
}
|
||||
|
||||
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
|
||||
return false
|
||||
}
|
||||
return !hasGrant
|
||||
}
|
||||
|
||||
func emailAuthIdentitySource(metadata map[string]any) string {
|
||||
if len(metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := metadata["source"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(raw))
|
||||
}
|
||||
|
||||
func (s *AuthService) hasProviderGrantRecord(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
providerType string,
|
||||
grantReason string,
|
||||
) (bool, error) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
rows, err := s.entClient.QueryContext(
|
||||
ctx,
|
||||
`SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`,
|
||||
userID,
|
||||
strings.TrimSpace(providerType),
|
||||
strings.TrimSpace(grantReason),
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return rows.Next(), rows.Err()
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
|
||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" || isReservedEmail(email) {
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
client := s.entClient
|
||||
@@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
|
||||
existed, err := buildQuery().Exist(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if !existed {
|
||||
@@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := buildQuery().Only(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
if identity.UserID != user.ID {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
|
||||
return false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return !existed
|
||||
return identity, !existed
|
||||
}
|
||||
|
||||
func inferLegacySignupSource(email string) string {
|
||||
|
||||
@@ -5,6 +5,7 @@ package service_test
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||
}
|
||||
|
||||
type flakyAuthIdentityDefaultSubAssignerStub struct {
|
||||
failuresRemaining int
|
||||
calls []*service.AssignSubscriptionInput
|
||||
}
|
||||
|
||||
func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
_ context.Context,
|
||||
input *service.AssignSubscriptionInput,
|
||||
) (*service.UserSubscription, bool, error) {
|
||||
cloned := *input
|
||||
s.calls = append(s.calls, &cloned)
|
||||
if s.failuresRemaining > 0 {
|
||||
s.failuresRemaining--
|
||||
return nil, false, errors.New("temporary assign failure")
|
||||
}
|
||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||
}
|
||||
|
||||
type authIdentitySettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
@@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
|
||||
assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
|
||||
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||
}, assigner)
|
||||
ctx := context.Background()
|
||||
|
||||
passwordHash, err := svc.HashPassword("password")
|
||||
require.NoError(t, err)
|
||||
user, err := client.User.Create().
|
||||
SetEmail("retry-first-bind@example.com").
|
||||
SetUsername("retry-user").
|
||||
SetPasswordHash(passwordHash).
|
||||
SetBalance(1.5).
|
||||
SetConcurrency(2).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1.5, storedUser.Balance)
|
||||
require.Equal(t, 2, storedUser.Concurrency)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
|
||||
token, gotUser, err = svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err = client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 10.0, storedUser.Balance)
|
||||
require.Equal(t, 6, storedUser.Concurrency)
|
||||
require.Len(t, assigner.calls, 2)
|
||||
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func countProviderGrantRecords(
|
||||
t *testing.T,
|
||||
client *dbent.Client,
|
||||
|
||||
@@ -161,33 +161,6 @@ type userAuthIdentityReader interface {
|
||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||
}
|
||||
|
||||
type emailAuthIdentitySynchronizer interface {
|
||||
EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
|
||||
ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
|
||||
}
|
||||
|
||||
func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
|
||||
syncer, ok := repo.(emailAuthIdentitySynchronizer)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
|
||||
}
|
||||
|
||||
func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
|
||||
oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
|
||||
newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
|
||||
if oldNormalized == newNormalized {
|
||||
return nil
|
||||
}
|
||||
|
||||
syncer, ok := repo.(emailAuthIdentitySynchronizer)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
|
||||
}
|
||||
|
||||
// ChangePasswordRequest 修改密码请求
|
||||
type ChangePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
@@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
oldConcurrency := user.Concurrency
|
||||
oldEmail := user.Email
|
||||
|
||||
// 更新字段
|
||||
if req.Email != nil {
|
||||
@@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
|
||||
return nil, fmt.Errorf("sync email auth identity: %w", err)
|
||||
}
|
||||
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
||||
func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
|
||||
repo := &emailSyncRepoStub{
|
||||
user: &User{
|
||||
ID: 19,
|
||||
@@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
||||
Username: "tester",
|
||||
Concurrency: 2,
|
||||
},
|
||||
replaceErr: context.DeadlineExceeded,
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
@@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, newEmail, updated.Email)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, []replaceEmailCall{{
|
||||
userID: 19,
|
||||
oldEmail: "profile-before@example.com",
|
||||
newEmail: "profile-after@example.com",
|
||||
}}, repo.replaceCalls)
|
||||
require.Empty(t, repo.replaceCalls)
|
||||
require.Empty(t, repo.ensureCalls)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user