fix: tighten pending oauth email routing and binding state
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
@@ -35,6 +36,8 @@ const (
|
||||
oauthCompletionResponseKey = "completion_response"
|
||||
)
|
||||
|
||||
var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
|
||||
|
||||
type oauthPendingSessionPayload struct {
|
||||
Intent string
|
||||
Identity service.PendingAuthIdentityKey
|
||||
@@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
|
||||
session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -946,11 +969,46 @@ func applyPendingOAuthBinding(
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
|
||||
}
|
||||
|
||||
tx, err := client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func applyPendingOAuthBindingTx(
|
||||
ctx context.Context,
|
||||
tx *dbent.Tx,
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
session *dbent.PendingAuthSession,
|
||||
decision *dbent.IdentityAdoptionDecision,
|
||||
overrideUserID *int64,
|
||||
forceBind bool,
|
||||
applyFirstBindDefaults bool,
|
||||
) error {
|
||||
if tx == nil || session == nil {
|
||||
return nil
|
||||
}
|
||||
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetUserID := int64(0)
|
||||
if overrideUserID != nil && *overrideUserID > 0 {
|
||||
targetUserID = *overrideUserID
|
||||
} else {
|
||||
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
|
||||
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -974,22 +1032,15 @@ func applyPendingOAuthBinding(
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
if err := tx.Client().User.UpdateOneID(targetUserID).
|
||||
SetUsername(adoptedDisplayName).
|
||||
Exec(txCtx); err != nil {
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
|
||||
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding(
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
if _, err := updateIdentity.Save(txCtx); err != nil {
|
||||
if _, err := updateIdentity.Save(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
||||
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
||||
SetIdentityID(identity.ID).
|
||||
Save(txCtx); err != nil {
|
||||
Save(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if applyFirstBindDefaults && authService != nil {
|
||||
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
|
||||
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if shouldAdoptAvatar && userService != nil {
|
||||
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
|
||||
if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func consumePendingOAuthBrowserSessionTx(
|
||||
ctx context.Context,
|
||||
tx *dbent.Tx,
|
||||
session *dbent.PendingAuthSession,
|
||||
) error {
|
||||
if tx == nil || session == nil {
|
||||
return service.ErrPendingAuthSessionNotFound
|
||||
}
|
||||
|
||||
storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPendingAuthSessionNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if storedSession.ConsumedAt != nil {
|
||||
return service.ErrPendingAuthSessionConsumed
|
||||
}
|
||||
if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
|
||||
return service.ErrPendingAuthSessionExpired
|
||||
}
|
||||
if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
|
||||
strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
|
||||
return service.ErrPendingAuthBrowserMismatch
|
||||
}
|
||||
|
||||
if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
|
||||
SetConsumedAt(now).
|
||||
SetCompletionCodeHash("").
|
||||
ClearCompletionCodeExpiresAt().
|
||||
Save(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPendingOAuthAdoption(
|
||||
@@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
return
|
||||
}
|
||||
|
||||
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
|
||||
|
||||
tx, err := client.Tx(c.Request.Context())
|
||||
if err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
|
||||
|
||||
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
@@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
}
|
||||
|
||||
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||
c.Request.Context(),
|
||||
txCtx,
|
||||
user,
|
||||
strings.TrimSpace(req.InvitationCode),
|
||||
strings.TrimSpace(session.ProviderType),
|
||||
); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
@@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
|
||||
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
@@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
return
|
||||
}
|
||||
|
||||
if pendingOAuthCreateAccountPreCommitHook != nil {
|
||||
if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
|
||||
_ = tx.Rollback()
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if rollbackCreatedUser(err) {
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
|
||||
@@ -903,6 +903,63 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
|
||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||
}
|
||||
|
||||
func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||
ctx := context.Background()
|
||||
|
||||
existingUser, err := client.User.Create().
|
||||
SetEmail("owner@example.com").
|
||||
SetUsername("owner-user").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("existing-email-send-code-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-existing-send-code-123").
|
||||
SetBrowserSessionKey("existing-email-send-code-browser-session-key").
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"step": "email_required",
|
||||
},
|
||||
}).
|
||||
SetRedirectTo("/dashboard").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.SendPendingOAuthVerifyCode(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
require.Equal(t, "pending_session", payload["auth_result"])
|
||||
require.Equal(t, "bind_login_required", payload["step"])
|
||||
require.Equal(t, "owner@example.com", payload["email"])
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
|
||||
require.NotNil(t, storedSession.TargetUserID)
|
||||
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
emailVerifyEnabled: true,
|
||||
@@ -1032,6 +1089,78 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
emailVerifyEnabled: true,
|
||||
emailCache: &oauthPendingFlowEmailCacheStub{
|
||||
verificationCodes: map[string]*service.VerificationCodeData{
|
||||
"fresh@example.com": {
|
||||
Code: "246810",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
},
|
||||
},
|
||||
userRepoOptions: oauthPendingFlowUserRepoOptions{
|
||||
rejectDeleteWhileAuthIdentityExists: true,
|
||||
},
|
||||
})
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("create-account-finalize-failure-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example").
|
||||
SetProviderSubject("oidc-finalize-failure-123").
|
||||
SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
}).
|
||||
SetRedirectTo("/profile").
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
|
||||
return errors.New("forced post-bind failure")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
pendingOAuthCreateAccountPreCommitHook = nil
|
||||
})
|
||||
|
||||
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
|
||||
ginCtx.Request = req
|
||||
|
||||
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||
authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
|
||||
).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, identityCount)
|
||||
|
||||
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
@@ -1618,7 +1747,6 @@ type oauthPendingFlowTestHandlerOptions struct {
|
||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||
totpCache service.TotpCache
|
||||
totpEncryptor service.SecretEncryptor
|
||||
redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
|
||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
@@ -1685,13 +1813,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
client: client,
|
||||
options: options.userRepoOptions,
|
||||
}
|
||||
redeemRepo := service.RedeemCodeRepository(nil)
|
||||
if options.redeemRepoFactory != nil {
|
||||
redeemRepo = options.redeemRepoFactory(client)
|
||||
}
|
||||
if redeemRepo == nil {
|
||||
redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
|
||||
}
|
||||
redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
|
||||
var emailService *service.EmailService
|
||||
if options.emailCache != nil {
|
||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||
@@ -2011,14 +2133,6 @@ func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Contex
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
type oauthPendingFlowFailingUseRedeemRepo struct {
|
||||
*oauthPendingFlowRedeemCodeRepo
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
|
||||
return errors.New("forced invitation use failure")
|
||||
}
|
||||
|
||||
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
@@ -2093,7 +2207,7 @@ func countProviderGrantRecords(
|
||||
}
|
||||
|
||||
type oauthPendingFlowUserRepo struct {
|
||||
client *dbent.Client
|
||||
client *dbent.Client
|
||||
options oauthPendingFlowUserRepoOptions
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user