fix: harden oidc compat email and email bind tx
This commit is contained in:
@@ -6,7 +6,10 @@ import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity(
|
||||
}
|
||||
|
||||
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
|
||||
if firstRealEmailBind && s.entClient != nil {
|
||||
if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return currentUser, nil
|
||||
}
|
||||
|
||||
currentUser.Email = normalizedEmail
|
||||
currentUser.PasswordHash = hashedPassword
|
||||
if err := s.userRepo.Update(ctx, currentUser); err != nil {
|
||||
@@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return normalized != "" && !isReservedEmail(normalized)
|
||||
}
|
||||
|
||||
func (s *AuthService) bindEmailIdentityWithDefaultsTx(
|
||||
ctx context.Context,
|
||||
currentUser *User,
|
||||
email string,
|
||||
hashedPassword string,
|
||||
) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) bindEmailIdentityWithDefaults(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
currentUser *User,
|
||||
email string,
|
||||
hashedPassword string,
|
||||
) error {
|
||||
if client == nil || currentUser == nil || currentUser.ID <= 0 {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
oldEmail := currentUser.Email
|
||||
if _, err := client.User.UpdateOneID(currentUser.ID).
|
||||
SetEmail(email).
|
||||
SetPasswordHash(hashedPassword).
|
||||
Save(ctx); err != nil {
|
||||
if dbent.IsConstraintError(err) {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
|
||||
return fmt.Errorf("apply email first bind defaults: %w", err)
|
||||
}
|
||||
|
||||
updatedUser, err := client.User.Get(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
currentUser.Email = updatedUser.Email
|
||||
currentUser.PasswordHash = updatedUser.PasswordHash
|
||||
currentUser.Balance = updatedUser.Balance
|
||||
currentUser.Concurrency = updatedUser.Concurrency
|
||||
currentUser.UpdatedAt = updatedUser.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceBoundEmailAuthIdentityWithClient(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
userID int64,
|
||||
oldEmail string,
|
||||
newEmail string,
|
||||
source string,
|
||||
) error {
|
||||
newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
|
||||
if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
|
||||
if oldSubject == "" || oldSubject == newSubject {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := client.AuthIdentity.Delete().
|
||||
Where(
|
||||
authidentity.UserIDEQ(userID),
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ(oldSubject),
|
||||
).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func ensureBoundEmailAuthIdentityWithClient(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
userID int64,
|
||||
subject string,
|
||||
source string,
|
||||
) error {
|
||||
if client == nil || userID <= 0 || subject == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.TrimSpace(source) == "" {
|
||||
source = "auth_service_email_bind"
|
||||
}
|
||||
|
||||
if err := client.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject(subject).
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
|
||||
OnConflictColumns(
|
||||
authidentity.FieldProviderType,
|
||||
authidentity.FieldProviderKey,
|
||||
authidentity.FieldProviderSubject,
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ(subject),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if identity.UserID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeBoundEmailAuthIdentitySubject(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
if normalized == "" || isReservedEmail(normalized) {
|
||||
return ""
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user