fix: apply email first-bind defaults on legacy login
This commit is contained in:
@@ -807,37 +807,75 @@ func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context
|
|||||||
if s == nil || user == nil || user.ID <= 0 {
|
if s == nil || user == nil || user.ID <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.ensureEmailAuthIdentity(ctx, user)
|
if s.ensureEmailAuthIdentity(ctx, user) {
|
||||||
|
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
|
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) bool {
|
||||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||||
if email == "" || isReservedEmail(email) {
|
if email == "" || isReservedEmail(email) {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.entClient.AuthIdentity.Create().
|
client := s.entClient
|
||||||
SetUserID(user.ID).
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
SetProviderType("email").
|
client = tx.Client()
|
||||||
SetProviderKey("email").
|
|
||||||
SetProviderSubject(email).
|
|
||||||
SetVerifiedAt(time.Now().UTC()).
|
|
||||||
SetMetadata(map[string]any{
|
|
||||||
"source": "auth_service_dual_write",
|
|
||||||
}).
|
|
||||||
OnConflictColumns(
|
|
||||||
authidentity.FieldProviderType,
|
|
||||||
authidentity.FieldProviderKey,
|
|
||||||
authidentity.FieldProviderSubject,
|
|
||||||
).
|
|
||||||
DoNothing().
|
|
||||||
Exec(ctx); err != nil {
|
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buildQuery := func() *dbent.AuthIdentityQuery {
|
||||||
|
return client.AuthIdentity.Query().Where(
|
||||||
|
authidentity.ProviderTypeEQ("email"),
|
||||||
|
authidentity.ProviderKeyEQ("email"),
|
||||||
|
authidentity.ProviderSubjectEQ(email),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
existed, err := buildQuery().Exist(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !existed {
|
||||||
|
if err := client.AuthIdentity.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetProviderType("email").
|
||||||
|
SetProviderKey("email").
|
||||||
|
SetProviderSubject(email).
|
||||||
|
SetVerifiedAt(time.Now().UTC()).
|
||||||
|
SetMetadata(map[string]any{
|
||||||
|
"source": "auth_service_dual_write",
|
||||||
|
}).
|
||||||
|
OnConflictColumns(
|
||||||
|
authidentity.FieldProviderType,
|
||||||
|
authidentity.FieldProviderKey,
|
||||||
|
authidentity.FieldProviderSubject,
|
||||||
|
).
|
||||||
|
DoNothing().
|
||||||
|
Exec(ctx); err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
identity, err := buildQuery().Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if identity.UserID != user.ID {
|
||||||
|
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !existed
|
||||||
}
|
}
|
||||||
|
|
||||||
func inferLegacySignupSource(email string) string {
|
func inferLegacySignupSource(email string) string {
|
||||||
|
|||||||
@@ -21,6 +21,19 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type authIdentityDefaultSubAssignerStub struct {
|
||||||
|
calls []*service.AssignSubscriptionInput
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||||
|
_ context.Context,
|
||||||
|
input *service.AssignSubscriptionInput,
|
||||||
|
) (*service.UserSubscription, bool, error) {
|
||||||
|
cloned := *input
|
||||||
|
s.calls = append(s.calls, &cloned)
|
||||||
|
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
type authIdentitySettingRepoStub struct {
|
type authIdentitySettingRepoStub struct {
|
||||||
values map[string]string
|
values map[string]string
|
||||||
}
|
}
|
||||||
@@ -40,8 +53,14 @@ func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error
|
|||||||
panic("unexpected Set call")
|
panic("unexpected Set call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||||
panic("unexpected GetMultiple call")
|
out := make(map[string]string, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if v, ok := s.values[key]; ok {
|
||||||
|
out[key] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||||
@@ -56,7 +75,11 @@ func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
|
|||||||
panic("unexpected Delete call")
|
panic("unexpected Delete call")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
func newAuthServiceWithEnt(
|
||||||
|
t *testing.T,
|
||||||
|
settings map[string]string,
|
||||||
|
defaultSubAssigner service.DefaultSubscriptionAssigner,
|
||||||
|
) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
|
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
|
||||||
@@ -65,6 +88,16 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
|
|||||||
|
|
||||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
_, err = db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
provider_type TEXT NOT NULL,
|
||||||
|
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE(user_id, provider_type, grant_reason)
|
||||||
|
)`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
@@ -82,17 +115,17 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
|
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
|
||||||
values: map[string]string{
|
values: settings,
|
||||||
service.SettingKeyRegistrationEnabled: "true",
|
|
||||||
},
|
|
||||||
}, cfg)
|
}, cfg)
|
||||||
|
|
||||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
|
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
|
||||||
return svc, repo, client
|
return svc, repo, client
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
||||||
svc, _, client := newAuthServiceWithEnt(t)
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
}, nil)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
token, user, err := svc.Register(ctx, "user@example.com", "password")
|
token, user, err := svc.Register(ctx, "user@example.com", "password")
|
||||||
@@ -119,7 +152,9 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
||||||
svc, repo, client := newAuthServiceWithEnt(t)
|
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
}, nil)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
user := &service.User{
|
user := &service.User{
|
||||||
@@ -163,7 +198,9 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
|
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
|
||||||
svc, repo, client := newAuthServiceWithEnt(t)
|
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
}, nil)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
user := &service.User{
|
user := &service.User{
|
||||||
@@ -188,3 +225,135 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, user.ID, identity.UserID)
|
require.Equal(t, user.ID, identity.UserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) {
|
||||||
|
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||||
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||||
|
}, assigner)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := svc.HashPassword("password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("legacy@example.com").
|
||||||
|
SetUsername("legacy-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(1.5).
|
||||||
|
SetConcurrency(2).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 10.0, storedUser.Balance)
|
||||||
|
require.Equal(t, 6, storedUser.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 1)
|
||||||
|
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||||
|
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||||
|
|
||||||
|
identityCount, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("email"),
|
||||||
|
authidentity.ProviderKeyEQ("email"),
|
||||||
|
authidentity.ProviderSubjectEQ("legacy@example.com"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, identityCount)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
|
||||||
|
token, gotUser, err = svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
|
||||||
|
storedUser, err = client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 10.0, storedUser.Balance)
|
||||||
|
require.Equal(t, 6, storedUser.Concurrency)
|
||||||
|
require.Len(t, assigner.calls, 1)
|
||||||
|
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
|
||||||
|
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||||
|
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||||
|
service.SettingKeyRegistrationEnabled: "true",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||||
|
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||||
|
}, assigner)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
passwordHash, err := svc.HashPassword("password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("bound@example.com").
|
||||||
|
SetUsername("bound-user").
|
||||||
|
SetPasswordHash(passwordHash).
|
||||||
|
SetBalance(2).
|
||||||
|
SetConcurrency(3).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(user.ID).
|
||||||
|
SetProviderType("email").
|
||||||
|
SetProviderKey("email").
|
||||||
|
SetProviderSubject("bound@example.com").
|
||||||
|
SetVerifiedAt(time.Now().UTC()).
|
||||||
|
SetMetadata(map[string]any{"source": "preexisting"}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, token)
|
||||||
|
require.NotNil(t, gotUser)
|
||||||
|
|
||||||
|
storedUser, err := client.User.Get(ctx, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 2.0, storedUser.Balance)
|
||||||
|
require.Equal(t, 3, storedUser.Concurrency)
|
||||||
|
require.Empty(t, assigner.calls)
|
||||||
|
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func countProviderGrantRecords(
|
||||||
|
t *testing.T,
|
||||||
|
client *dbent.Client,
|
||||||
|
userID int64,
|
||||||
|
providerType string,
|
||||||
|
grantReason string,
|
||||||
|
) int {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var count int
|
||||||
|
rows, err := client.QueryContext(
|
||||||
|
context.Background(),
|
||||||
|
`SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
|
||||||
|
userID,
|
||||||
|
providerType,
|
||||||
|
grantReason,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rows.Close()
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
require.NoError(t, rows.Scan(&count))
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user