fix: unify email identity sync and retry first-bind defaults

This commit is contained in:
IanShaw027
2026-04-21 01:00:59 +08:00
parent 7a9488ff37
commit ea27ac6fd7
7 changed files with 154 additions and 78 deletions

View File

@@ -209,14 +209,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
func (r *userRepository) EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error {
return ensureEmailAuthIdentityWithClient(ctx, r.client, userID, email, "service_dual_write")
}
func (r *userRepository) ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error {
return replaceEmailAuthIdentityWithClient(ctx, r.client, userID, oldEmail, newEmail, "service_dual_write")
}
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
client = clientFromContext(ctx, client)
if client == nil || userID <= 0 {

View File

@@ -650,9 +650,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
if err := ensureEmailAuthIdentitySync(ctx, s.userRepo, user.ID, user.Email); err != nil {
return nil, fmt.Errorf("sync email auth identity: %w", err)
}
s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
@@ -688,7 +685,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
oldEmail := user.Email
if input.Email != "" {
user.Email = input.Email
@@ -721,9 +717,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
return nil, fmt.Errorf("sync email auth identity: %w", err)
}
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {

View File

@@ -31,6 +31,8 @@ type emailSyncRepoStub struct {
updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
ensureErr error
replaceErr error
}
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
@@ -125,7 +127,7 @@ func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return n
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
return nil
return s.ensureErr
}
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
@@ -134,11 +136,14 @@ func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID i
oldEmail: oldEmail,
newEmail: newEmail,
})
return nil
return s.replaceErr
}
func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
repo := &emailSyncRepoStub{nextID: 55}
func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
nextID: 55,
ensureErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
@@ -147,14 +152,12 @@ func TestAdminService_CreateUser_EnsuresEmailAuthIdentity(t *testing.T) {
})
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, []ensureEmailCall{{
userID: 55,
email: "admin-created@example.com",
}}, repo.ensureCalls)
require.Equal(t, int64(55), user.ID)
require.Empty(t, repo.ensureCalls)
require.Empty(t, repo.replaceCalls)
}
func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 91,
@@ -163,6 +166,7 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
Status: StatusActive,
Concurrency: 3,
},
replaceErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
@@ -172,10 +176,6 @@ func TestAdminService_UpdateUser_ReplacesEmailAuthIdentity(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "after@example.com", updated.Email)
require.Equal(t, []replaceEmailCall{{
userID: 91,
oldEmail: "before@example.com",
newEmail: "after@example.com",
}}, repo.replaceCalls)
require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}

View File

@@ -768,9 +768,6 @@ func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, sig
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if signupSource == "email" {
s.ensureEmailAuthIdentity(ctx, user)
}
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
@@ -807,21 +804,81 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
if s == nil || user == nil || user.ID <= 0 {
return
}
if s.ensureEmailAuthIdentity(ctx, user) {
identity, created := s.ensureEmailAuthIdentity(ctx, user)
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
}
}
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
ctx context.Context,
userID int64,
identity *dbent.AuthIdentity,
created bool,
) bool {
if created {
return true
}
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
return false
}
if emailAuthIdentitySource(identity.Metadata) != "auth_service_dual_write" {
return false
}
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
return false
}
return !hasGrant
}
func emailAuthIdentitySource(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
raw, ok := metadata["source"]
if !ok {
return ""
}
return strings.TrimSpace(fmt.Sprint(raw))
}
func (s *AuthService) hasProviderGrantRecord(
ctx context.Context,
userID int64,
providerType string,
grantReason string,
) (bool, error) {
if s == nil || s.entClient == nil || userID <= 0 {
return false, nil
}
rows, err := s.entClient.QueryContext(
ctx,
`SELECT 1 FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ? LIMIT 1`,
userID,
strings.TrimSpace(providerType),
strings.TrimSpace(grantReason),
)
if err != nil {
return false, err
}
defer rows.Close()
return rows.Next(), rows.Err()
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) (*dbent.AuthIdentity, bool) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return nil, false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return false
return nil, false
}
client := s.entClient
@@ -840,7 +897,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
existed, err := buildQuery().Exist(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return false
return nil, false
}
if !existed {
@@ -861,21 +918,21 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) b
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return false
return nil, false
}
}
identity, err := buildQuery().Only(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return false
return nil, false
}
if identity.UserID != user.ID {
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
return false
return nil, false
}
return !existed
return identity, !existed
}
func inferLegacySignupSource(email string) string {

View File

@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
"errors"
"testing"
"time"
@@ -34,6 +35,24 @@ func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
type flakyAuthIdentityDefaultSubAssignerStub struct {
failuresRemaining int
calls []*service.AssignSubscriptionInput
}
func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
if s.failuresRemaining > 0 {
s.failuresRemaining--
return nil, false, errors.New("temporary assign failure")
}
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
}
type authIdentitySettingRepoStub struct {
values map[string]string
}
@@ -333,6 +352,55 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *testing.T) {
assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, assigner)
ctx := context.Background()
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("retry-first-bind@example.com").
SetUsername("retry-user").
SetPasswordHash(passwordHash).
SetBalance(1.5).
SetConcurrency(2).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 1.5, storedUser.Balance)
require.Equal(t, 2, storedUser.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
token, gotUser, err = svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, 10.0, storedUser.Balance)
require.Equal(t, 6, storedUser.Concurrency)
require.Len(t, assigner.calls, 2)
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,

View File

@@ -161,33 +161,6 @@ type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
type emailAuthIdentitySynchronizer interface {
EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error
ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error
}
func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error {
syncer, ok := repo.(emailAuthIdentitySynchronizer)
if !ok {
return nil
}
return syncer.EnsureEmailAuthIdentity(ctx, userID, email)
}
func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error {
oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail))
newNormalized := strings.ToLower(strings.TrimSpace(newEmail))
if oldNormalized == newNormalized {
return nil
}
syncer, ok := repo.(emailAuthIdentitySynchronizer)
if !ok {
return nil
}
return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail)
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -281,7 +254,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
oldEmail := user.Email
// 更新字段
if req.Email != nil {
@@ -326,9 +298,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil {
return nil, fmt.Errorf("sync email auth identity: %w", err)
}
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 19,
@@ -17,6 +17,7 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
Username: "tester",
Concurrency: 2,
},
replaceErr: context.DeadlineExceeded,
}
svc := NewUserService(repo, nil, nil, nil)
@@ -28,10 +29,6 @@ func TestUpdateProfile_ReplacesEmailAuthIdentityWhenEmailChanges(t *testing.T) {
require.NotNil(t, updated)
require.Equal(t, newEmail, updated.Email)
require.Equal(t, 1, repo.updateCalls)
require.Equal(t, []replaceEmailCall{{
userID: 19,
oldEmail: "profile-before@example.com",
newEmail: "profile-after@example.com",
}}, repo.replaceCalls)
require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}