From 1521d503990f2b3ab6d474958b64c1e4f5fb3baf Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:31:52 +0800 Subject: [PATCH] fix: apply email first-bind defaults on legacy login --- backend/internal/service/auth_service.go | 80 ++++++-- .../auth_service_identity_sync_test.go | 189 +++++++++++++++++- 2 files changed, 238 insertions(+), 31 deletions(-) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index dda6df04..d0d5e4e3 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -807,37 +807,75 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context if s == nil || user == nil || user.ID <= 0 { return } - s.ensureEmailAuthIdentity(ctx, user) + if s.ensureEmailAuthIdentity(ctx, user) { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err) + } + } } -func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool { if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { - return + return false } email := strings.ToLower(strings.TrimSpace(user.Email)) if email == "" || isReservedEmail(email) { - return + return false } - if err := s.entClient.AuthIdentity.Create(). - SetUserID(user.ID). - SetProviderType("email"). - SetProviderKey("email"). - SetProviderSubject(email). - SetVerifiedAt(time.Now().UTC()). - SetMetadata(map[string]any{ - "source": "auth_service_dual_write", - }). - OnConflictColumns( - authidentity.FieldProviderType, - authidentity.FieldProviderKey, - authidentity.FieldProviderSubject, - ). - DoNothing(). - Exec(ctx); err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() } + + buildQuery := func() *dbent.AuthIdentityQuery { + return client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(email), + ) + } + + existed, err := buildQuery().Exist(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return false + } + + if !existed { + if err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": "auth_service_dual_write", + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return false + } + } + + identity, err := buildQuery().Only(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + return false + } + if identity.UserID != user.ID { + logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID) + return false + } + + return !existed } func inferLegacySignupSource(email string) string { diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index fcb4813b..e2a94b13 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -21,6 +21,19 @@ import ( _ "modernc.org/sqlite" ) +type authIdentityDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil +} + type authIdentitySettingRepoStub struct { values map[string]string } @@ -40,8 +53,14 @@ func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error panic("unexpected Set call") } -func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { - panic("unexpected GetMultiple call") +func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil } func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { @@ -56,7 +75,11 @@ func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error { panic("unexpected Delete call") } -func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) { +func newAuthServiceWithEnt( + t *testing.T, + settings map[string]string, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { t.Helper() db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") @@ -65,6 +88,16 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) @@ -82,17 +115,17 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo }, } settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - }, + values: settings, }, cfg) - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil) + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner) return svc, repo, client } func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { - svc, _, client := newAuthServiceWithEnt(t) + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) ctx := context.Background() token, user, err := svc.Register(ctx, "user@example.com", "password") @@ -119,7 +152,9 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { } func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { - svc, repo, client := newAuthServiceWithEnt(t) + svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) ctx := context.Background() user := &service.User{ @@ -163,7 +198,9 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { } func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { - svc, repo, client := newAuthServiceWithEnt(t) + svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, nil) ctx := context.Background() user := &service.User{ @@ -188,3 +225,135 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { require.NoError(t, err) require.Equal(t, user.ID, identity.UserID) } + +func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 10.0, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("legacy@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) + + token, gotUser, err = svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err = client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 10.0, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) { + assigner := &authIdentityDefaultSubAssignerStub{} + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, assigner) + ctx := context.Background() + + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("bound@example.com"). + SetUsername("bound-user"). + SetPasswordHash(passwordHash). + SetBalance(2). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("bound@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "preexisting"}). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, 2.0, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var count int + rows, err := client.QueryContext( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, + userID, + providerType, + grantReason, + ) + require.NoError(t, err) + defer rows.Close() + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&count)) + require.NoError(t, rows.Err()) + return count +}