fix: apply email first-bind defaults on legacy login
This commit is contained in:
@@ -21,6 +21,19 @@ import (
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type authIdentityDefaultSubAssignerStub struct {
|
||||
calls []*service.AssignSubscriptionInput
|
||||
}
|
||||
|
||||
func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
|
||||
_ context.Context,
|
||||
input *service.AssignSubscriptionInput,
|
||||
) (*service.UserSubscription, bool, error) {
|
||||
cloned := *input
|
||||
s.calls = append(s.calls, &cloned)
|
||||
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
|
||||
}
|
||||
|
||||
type authIdentitySettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
@@ -40,8 +53,14 @@ func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if v, ok := s.values[key]; ok {
|
||||
out[key] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
@@ -56,7 +75,11 @@ func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
func newAuthServiceWithEnt(
|
||||
t *testing.T,
|
||||
settings map[string]string,
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner,
|
||||
) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
|
||||
@@ -65,6 +88,16 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
provider_type TEXT NOT NULL,
|
||||
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, provider_type, grant_reason)
|
||||
)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
@@ -82,17 +115,17 @@ func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepo
|
||||
},
|
||||
}
|
||||
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
},
|
||||
values: settings,
|
||||
}, cfg)
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
||||
svc, _, client := newAuthServiceWithEnt(t)
|
||||
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
token, user, err := svc.Register(ctx, "user@example.com", "password")
|
||||
@@ -119,7 +152,9 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
||||
svc, repo, client := newAuthServiceWithEnt(t)
|
||||
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &service.User{
|
||||
@@ -163,7 +198,9 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
|
||||
svc, repo, client := newAuthServiceWithEnt(t)
|
||||
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &service.User{
|
||||
@@ -188,3 +225,135 @@ func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, identity.UserID)
|
||||
}
|
||||
|
||||
func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNew(t *testing.T) {
|
||||
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||
}, assigner)
|
||||
ctx := context.Background()
|
||||
|
||||
passwordHash, err := svc.HashPassword("password")
|
||||
require.NoError(t, err)
|
||||
user, err := client.User.Create().
|
||||
SetEmail("legacy@example.com").
|
||||
SetUsername("legacy-user").
|
||||
SetPasswordHash(passwordHash).
|
||||
SetBalance(1.5).
|
||||
SetConcurrency(2).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 10.0, storedUser.Balance)
|
||||
require.Equal(t, 6, storedUser.Concurrency)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ("legacy@example.com"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, identityCount)
|
||||
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
|
||||
token, gotUser, err = svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err = client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 10.0, storedUser.Balance)
|
||||
require.Equal(t, 6, storedUser.Concurrency)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
|
||||
assigner := &authIdentityDefaultSubAssignerStub{}
|
||||
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
|
||||
}, assigner)
|
||||
ctx := context.Background()
|
||||
|
||||
passwordHash, err := svc.HashPassword("password")
|
||||
require.NoError(t, err)
|
||||
user, err := client.User.Create().
|
||||
SetEmail("bound@example.com").
|
||||
SetUsername("bound-user").
|
||||
SetPasswordHash(passwordHash).
|
||||
SetBalance(2).
|
||||
SetConcurrency(3).
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
_, err = client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject("bound@example.com").
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{"source": "preexisting"}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2.0, storedUser.Balance)
|
||||
require.Equal(t, 3, storedUser.Concurrency)
|
||||
require.Empty(t, assigner.calls)
|
||||
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
|
||||
}
|
||||
|
||||
func countProviderGrantRecords(
|
||||
t *testing.T,
|
||||
client *dbent.Client,
|
||||
userID int64,
|
||||
providerType string,
|
||||
grantReason string,
|
||||
) int {
|
||||
t.Helper()
|
||||
|
||||
var count int
|
||||
rows, err := client.QueryContext(
|
||||
context.Background(),
|
||||
`SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
|
||||
userID,
|
||||
providerType,
|
||||
grantReason,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
require.True(t, rows.Next())
|
||||
require.NoError(t, rows.Scan(&count))
|
||||
require.NoError(t, rows.Err())
|
||||
return count
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user