feat: complete email binding and pending oauth verification flows

This commit is contained in:
IanShaw027
2026-04-21 10:00:06 +08:00
parent 6da08262d7
commit dcd5c43da4
29 changed files with 2117 additions and 107 deletions

View File

@@ -0,0 +1,128 @@
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// BindEmailIdentity verifies and binds a local email/password identity to the current user.
func (s *AuthService) BindEmailIdentity(
ctx context.Context,
userID int64,
email string,
verifyCode string,
password string,
) (*User, error) {
if s == nil {
return nil, ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return nil, err
}
if isReservedEmail(normalizedEmail) {
return nil, ErrEmailReserved
}
if strings.TrimSpace(password) == "" {
return nil, ErrPasswordRequired
}
if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
return nil, err
}
currentUser, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return nil, ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, ErrEmailExists
}
return nil, ErrServiceUnavailable
}
if firstRealEmailBind {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
return nil, fmt.Errorf("apply email first bind defaults: %w", err)
}
}
return currentUser, nil
}
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
if s == nil {
return ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return err
}
if isReservedEmail(normalizedEmail) {
return ErrEmailReserved
}
if s.emailService == nil {
return ErrServiceUnavailable
}
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrUserNotFound
}
return ErrServiceUnavailable
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
}
func normalizeEmailForIdentityBinding(email string) (string, error) {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" || len(normalized) > 255 {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(normalized); err != nil {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
return normalized, nil
}
func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}

View File

@@ -4,9 +4,71 @@ import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
"time"
)
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
// account-creation flows without relying on the public registration gate.
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
if s == nil || s.emailService == nil {
return nil, ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
return nil, err
}
return &SendVerifyCodeResult{
Countdown: int(verifyCodeCooldown / time.Second),
}, nil
}
func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil, nil
}
if s.redeemRepo == nil {
return nil, ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" {
return nil, ErrInvitationCodeRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, ErrInvitationCodeInvalid
}
return redeemCode, nil
}
// VerifyOAuthEmailCode verifies the locally entered email verification code for
// third-party signup and binding flows. This is intentionally independent from
// the global registration email verification toggle.
@@ -54,19 +116,8 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, err
}
var invitationRedeemCode *RedeemCode
if s.settingService.IsInvitationCodeEnabled(ctx) {
if invitationCode == "" {
return nil, nil, ErrInvitationCodeRequired
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
@@ -104,22 +155,91 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(
ctx context.Context,
user *User,
invitationCode string,
signupSource string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
}
signupSource = normalizeOAuthSignupSource(signupSource)
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return err
}
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return ErrInvitationCodeInvalid
}
}
s.postAuthUserBootstrap(ctx, user, signupSource, false)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
return nil
}
// RollbackOAuthEmailAccountCreation removes a partially-created local account
// and restores any invitation code already consumed by that account.
func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
if s == nil || s.userRepo == nil || userID <= 0 {
return ErrServiceUnavailable
}
if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
return err
}
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete created oauth user: %w", err)
}
return nil
}
func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil
}
if s.redeemRepo == nil {
return ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" || userID <= 0 {
return nil
}
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) {
return nil
}
return fmt.Errorf("load invitation code: %w", err)
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
return nil
}
redeemCode.Status = StatusUnused
redeemCode.UsedBy = nil
redeemCode.UsedAt = nil
if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
return fmt.Errorf("restore invitation code: %w", err)
}
return nil
}
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.

View File

@@ -0,0 +1,251 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type redeemCodeRepoStub struct {
codesByCode map[string]*RedeemCode
useCalls []struct {
id int64
userID int64
}
updateCalls []*RedeemCode
}
func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
panic("unexpected Create call")
}
func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
panic("unexpected CreateBatch call")
}
func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
panic("unexpected GetByID call")
}
func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
if s.codesByCode == nil {
return nil, ErrRedeemCodeNotFound
}
redeemCode, ok := s.codesByCode[code]
if !ok {
return nil, ErrRedeemCodeNotFound
}
cloned := *redeemCode
return &cloned, nil
}
func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
cloned := *code
s.updateCalls = append(s.updateCalls, &cloned)
if s.codesByCode == nil {
s.codesByCode = make(map[string]*RedeemCode)
}
s.codesByCode[cloned.Code] = &cloned
return nil
}
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
for code, redeemCode := range s.codesByCode {
if redeemCode.ID != id {
continue
}
now := time.Now().UTC()
redeemCode.Status = StatusUsed
redeemCode.UsedBy = &userID
redeemCode.UsedAt = &now
s.codesByCode[code] = redeemCode
s.useCalls = append(s.useCalls, struct {
id int64
userID int64
}{id: id, userID: userID})
return nil
}
return ErrRedeemCodeNotFound
}
func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
panic("unexpected ListByUser call")
}
func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
func newOAuthEmailFlowAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
settings map[string]string,
emailCache EmailCache,
) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
return NewAuthService(
nil,
userRepo,
redeemRepo,
refreshTokenCache,
cfg,
settingService,
emailService,
nil,
nil,
nil,
nil,
)
}
func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUnused,
},
},
}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
nil,
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"INVITE123",
"oidc",
)
require.Nil(t, tokenPair)
require.Nil(t, user)
require.Error(t, err)
require.Contains(t, err.Error(), "generate token pair")
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, userRepo.created, 1)
require.Empty(t, redeemRepo.useCalls)
require.Empty(t, redeemRepo.updateCalls)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUsed,
UsedBy: func() *int64 {
v := int64(42)
return &v
}(),
UsedAt: func() *time.Time {
v := time.Now().UTC()
return &v
}(),
},
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
require.NoError(t, err)
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, redeemRepo.updateCalls, 1)
require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
}
func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
require.Error(t, err)
require.Contains(t, err.Error(), "delete created oauth user")
}

View File

@@ -0,0 +1,316 @@
//go:build unit
package service_test
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type emailBindDefaultSubAssignerStub struct {
calls []*service.AssignSubscriptionInput
}
func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_provider_default_grants (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider_type TEXT NOT NULL,
grant_reason TEXT NOT NULL DEFAULT 'first_bind',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingRepo := &emailBindSettingRepoStub{values: settings}
settingSvc := service.NewSettingService(settingRepo, cfg)
var emailSvc *service.EmailService
if emailCache != nil {
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
assigner := &emailBindDefaultSubAssignerStub{}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
SetUsername("legacy-user").
SetPasswordHash("old-hash").
SetBalance(2.5).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
require.Equal(t, "newemail@example.com", updatedUser.Email)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "newemail@example.com", storedUser.Email)
require.Equal(t, 11.0, storedUser.Balance)
require.Equal(t, 5, storedUser.Concurrency)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("newemail@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
require.Len(t, assigner.calls, 1)
require.Equal(t, user.ID, assigner.calls[0].UserID)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
sourceUser, err := client.User.Create().
SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.User.Create().
SetEmail("taken@example.com").
SetUsername("taken-user").
SetPasswordHash("hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailExists)
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, sourceUser.ID)
require.NoError(t, err)
require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("source-user@example.com").
SetUsername("source-user").
SetPasswordHash("old-hash").
SetBalance(1).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
require.ErrorIs(t, err, service.ErrEmailReserved)
require.Nil(t, updatedUser)
}
type emailBindSettingRepoStub struct {
values map[string]string
}
func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if v, ok := s.values[key]; ok {
out[key] = v
}
}
return out, nil
}
func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
type emailBindCacheStub struct {
data *service.VerificationCodeData
err error
}
func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
}
return s.data, nil
}
func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
return nil
}
func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
return false
}
func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
return nil
}
func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}