fix: tighten pending oauth email routing and binding state

This commit is contained in:
IanShaw027
2026-04-21 10:41:29 +08:00
parent dcd5c43da4
commit 7e89bca5e6
9 changed files with 685 additions and 54 deletions

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
@@ -35,6 +36,8 @@ const (
oauthCompletionResponseKey = "completion_response"
)
var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
@@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
email := strings.TrimSpace(strings.ToLower(req.Email))
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{})
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
} else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
if err != nil {
response.ErrorFrom(c, err)
@@ -946,11 +969,46 @@ func applyPendingOAuthBinding(
return nil
}
if tx := dbent.TxFromContext(ctx); tx != nil {
return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthBindingTx(
ctx context.Context,
tx *dbent.Tx,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
forceBind bool,
applyFirstBindDefaults bool,
) error {
if tx == nil || session == nil {
return nil
}
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
if err != nil {
return err
}
@@ -974,22 +1032,15 @@ func applyPendingOAuthBinding(
}
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(txCtx); err != nil {
Exec(ctx); err != nil {
return err
}
}
identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
@@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding(
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
if _, err := updateIdentity.Save(txCtx); err != nil {
if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(txCtx); err != nil {
Save(ctx); err != nil {
return err
}
}
if applyFirstBindDefaults && authService != nil {
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
return err
}
}
if shouldAdoptAvatar && userService != nil {
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
}
return tx.Commit()
return nil
}
func consumePendingOAuthBrowserSessionTx(
ctx context.Context,
tx *dbent.Tx,
session *dbent.PendingAuthSession,
) error {
if tx == nil || session == nil {
return service.ErrPendingAuthSessionNotFound
}
storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPendingAuthSessionNotFound
}
return err
}
now := time.Now().UTC()
if storedSession.ConsumedAt != nil {
return service.ErrPendingAuthSessionConsumed
}
if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
return service.ErrPendingAuthSessionExpired
}
if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return service.ErrPendingAuthBrowserMismatch
}
if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx); err != nil {
return err
}
return nil
}
func applyPendingOAuthAdoption(
@@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
tx, err := client.Tx(c.Request.Context())
if err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
if err := h.authService.FinalizeOAuthEmailAccount(
c.Request.Context(),
txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
@@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
return
}
if pendingOAuthCreateAccountPreCommitHook != nil {
if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
}
if err := tx.Commit(); err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)