fix: tighten pending oauth email routing and binding state
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
@@ -35,6 +36,8 @@ const (
|
|||||||
oauthCompletionResponseKey = "completion_response"
|
oauthCompletionResponseKey = "completion_response"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
|
||||||
|
|
||||||
type oauthPendingSessionPayload struct {
|
type oauthPendingSessionPayload struct {
|
||||||
Intent string
|
Intent string
|
||||||
Identity service.PendingAuthIdentityKey
|
Identity service.PendingAuthIdentityKey
|
||||||
@@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client := h.entClient()
|
||||||
|
if client == nil {
|
||||||
|
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email := strings.TrimSpace(strings.ToLower(req.Email))
|
||||||
|
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
|
||||||
|
session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{})
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
|
||||||
|
return
|
||||||
|
} else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -946,11 +969,46 @@ func applyPendingOAuthBinding(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
|
return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := client.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
txCtx := dbent.NewTxContext(ctx, tx)
|
||||||
|
if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPendingOAuthBindingTx(
|
||||||
|
ctx context.Context,
|
||||||
|
tx *dbent.Tx,
|
||||||
|
authService *service.AuthService,
|
||||||
|
userService *service.UserService,
|
||||||
|
session *dbent.PendingAuthSession,
|
||||||
|
decision *dbent.IdentityAdoptionDecision,
|
||||||
|
overrideUserID *int64,
|
||||||
|
forceBind bool,
|
||||||
|
applyFirstBindDefaults bool,
|
||||||
|
) error {
|
||||||
|
if tx == nil || session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
targetUserID := int64(0)
|
targetUserID := int64(0)
|
||||||
if overrideUserID != nil && *overrideUserID > 0 {
|
if overrideUserID != nil && *overrideUserID > 0 {
|
||||||
targetUserID = *overrideUserID
|
targetUserID = *overrideUserID
|
||||||
} else {
|
} else {
|
||||||
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
|
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -974,22 +1032,15 @@ func applyPendingOAuthBinding(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := client.Tx(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = tx.Rollback() }()
|
|
||||||
txCtx := dbent.NewTxContext(ctx, tx)
|
|
||||||
|
|
||||||
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||||
if err := tx.Client().User.UpdateOneID(targetUserID).
|
if err := tx.Client().User.UpdateOneID(targetUserID).
|
||||||
SetUsername(adoptedDisplayName).
|
SetUsername(adoptedDisplayName).
|
||||||
Exec(txCtx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
|
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding(
|
|||||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||||
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
||||||
}
|
}
|
||||||
if _, err := updateIdentity.Save(txCtx); err != nil {
|
if _, err := updateIdentity.Save(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
|
||||||
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
||||||
SetIdentityID(identity.ID).
|
SetIdentityID(identity.ID).
|
||||||
Save(txCtx); err != nil {
|
Save(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if applyFirstBindDefaults && authService != nil {
|
if applyFirstBindDefaults && authService != nil {
|
||||||
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
|
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldAdoptAvatar && userService != nil {
|
if shouldAdoptAvatar && userService != nil {
|
||||||
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
|
if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func consumePendingOAuthBrowserSessionTx(
|
||||||
|
ctx context.Context,
|
||||||
|
tx *dbent.Tx,
|
||||||
|
session *dbent.PendingAuthSession,
|
||||||
|
) error {
|
||||||
|
if tx == nil || session == nil {
|
||||||
|
return service.ErrPendingAuthSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return service.ErrPendingAuthSessionNotFound
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if storedSession.ConsumedAt != nil {
|
||||||
|
return service.ErrPendingAuthSessionConsumed
|
||||||
|
}
|
||||||
|
if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
|
||||||
|
return service.ErrPendingAuthSessionExpired
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
|
||||||
|
strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
|
||||||
|
return service.ErrPendingAuthBrowserMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
|
||||||
|
SetConsumedAt(now).
|
||||||
|
SetCompletionCodeHash("").
|
||||||
|
ClearCompletionCodeExpiresAt().
|
||||||
|
Save(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPendingOAuthAdoption(
|
func applyPendingOAuthAdoption(
|
||||||
@@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
|
|
||||||
|
tx, err := client.Tx(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
if rollbackCreatedUser(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
|
||||||
|
|
||||||
|
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
if rollbackCreatedUser(err) {
|
if rollbackCreatedUser(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.authService.FinalizeOAuthEmailAccount(
|
if err := h.authService.FinalizeOAuthEmailAccount(
|
||||||
c.Request.Context(),
|
txCtx,
|
||||||
user,
|
user,
|
||||||
strings.TrimSpace(req.InvitationCode),
|
strings.TrimSpace(req.InvitationCode),
|
||||||
strings.TrimSpace(session.ProviderType),
|
strings.TrimSpace(session.ProviderType),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
if rollbackCreatedUser(err) {
|
if rollbackCreatedUser(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
|
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
if rollbackCreatedUser(err) {
|
if rollbackCreatedUser(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pendingOAuthCreateAccountPreCommitHook != nil {
|
||||||
|
if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
if rollbackCreatedUser(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
if rollbackCreatedUser(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||||
clearCookies()
|
clearCookies()
|
||||||
writeOAuthTokenPairResponse(c, tokenPair)
|
writeOAuthTokenPairResponse(c, tokenPair)
|
||||||
|
|||||||
@@ -903,6 +903,63 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
|
|||||||
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
existingUser, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetUsername("owner-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("existing-email-send-code-session-token").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-existing-send-code-123").
|
||||||
|
SetBrowserSessionKey("existing-email-send-code-browser-session-key").
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
oauthCompletionResponseKey: map[string]any{
|
||||||
|
"step": "email_required",
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/dashboard").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.SendPendingOAuthVerifyCode(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||||
|
require.Equal(t, "pending_session", payload["auth_result"])
|
||||||
|
require.Equal(t, "bind_login_required", payload["step"])
|
||||||
|
require.Equal(t, "owner@example.com", payload["email"])
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent)
|
||||||
|
require.NotNil(t, storedSession.TargetUserID)
|
||||||
|
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
|
||||||
|
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
|
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
|
||||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
emailVerifyEnabled: true,
|
emailVerifyEnabled: true,
|
||||||
@@ -1032,6 +1089,78 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
|
emailVerifyEnabled: true,
|
||||||
|
emailCache: &oauthPendingFlowEmailCacheStub{
|
||||||
|
verificationCodes: map[string]*service.VerificationCodeData{
|
||||||
|
"fresh@example.com": {
|
||||||
|
Code: "246810",
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
userRepoOptions: oauthPendingFlowUserRepoOptions{
|
||||||
|
rejectDeleteWhileAuthIdentityExists: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("create-account-finalize-failure-session-token").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("oidc-finalize-failure-123").
|
||||||
|
SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"username": "oidc_user",
|
||||||
|
}).
|
||||||
|
SetRedirectTo("/profile").
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
|
||||||
|
return errors.New("forced post-bind failure")
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
pendingOAuthCreateAccountPreCommitHook = nil
|
||||||
|
})
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.CreateOIDCOAuthAccount(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||||
|
|
||||||
|
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, userCount)
|
||||||
|
|
||||||
|
identityCount, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("oidc"),
|
||||||
|
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||||
|
authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
|
||||||
|
).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Zero(t, identityCount)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
|
||||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -1618,7 +1747,6 @@ type oauthPendingFlowTestHandlerOptions struct {
|
|||||||
defaultSubAssigner service.DefaultSubscriptionAssigner
|
defaultSubAssigner service.DefaultSubscriptionAssigner
|
||||||
totpCache service.TotpCache
|
totpCache service.TotpCache
|
||||||
totpEncryptor service.SecretEncryptor
|
totpEncryptor service.SecretEncryptor
|
||||||
redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository
|
|
||||||
userRepoOptions oauthPendingFlowUserRepoOptions
|
userRepoOptions oauthPendingFlowUserRepoOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1685,13 +1813,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
|||||||
client: client,
|
client: client,
|
||||||
options: options.userRepoOptions,
|
options: options.userRepoOptions,
|
||||||
}
|
}
|
||||||
redeemRepo := service.RedeemCodeRepository(nil)
|
redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
|
||||||
if options.redeemRepoFactory != nil {
|
|
||||||
redeemRepo = options.redeemRepoFactory(client)
|
|
||||||
}
|
|
||||||
if redeemRepo == nil {
|
|
||||||
redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client}
|
|
||||||
}
|
|
||||||
var emailService *service.EmailService
|
var emailService *service.EmailService
|
||||||
if options.emailCache != nil {
|
if options.emailCache != nil {
|
||||||
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
|
||||||
@@ -2011,14 +2133,6 @@ func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Contex
|
|||||||
panic("unexpected SumPositiveBalanceByUser call")
|
panic("unexpected SumPositiveBalanceByUser call")
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthPendingFlowFailingUseRedeemRepo struct {
|
|
||||||
*oauthPendingFlowRedeemCodeRepo
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error {
|
|
||||||
return errors.New("forced invitation use failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ import (
|
|||||||
"net/mail"
|
"net/mail"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||||
)
|
)
|
||||||
|
|
||||||
func normalizeOAuthSignupSource(signupSource string) string {
|
func normalizeOAuthSignupSource(signupSource string) string {
|
||||||
@@ -50,7 +53,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
|
|||||||
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if s.redeemRepo == nil {
|
if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
|
||||||
return nil, ErrServiceUnavailable
|
return nil, ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +62,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
|
|||||||
return nil, ErrInvitationCodeRequired
|
return nil, ErrInvitationCodeRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvitationCodeInvalid
|
return nil, ErrInvitationCodeInvalid
|
||||||
}
|
}
|
||||||
@@ -181,12 +184,12 @@ func (s *AuthService) FinalizeOAuthEmailAccount(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if invitationRedeemCode != nil {
|
if invitationRedeemCode != nil {
|
||||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||||
return ErrInvitationCodeInvalid
|
return ErrInvitationCodeInvalid
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.postAuthUserBootstrap(ctx, user, signupSource, false)
|
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
|
||||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||||
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
|
||||||
return nil
|
return nil
|
||||||
@@ -211,7 +214,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
|
|||||||
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if s.redeemRepo == nil {
|
if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
|
||||||
return ErrServiceUnavailable
|
return ErrServiceUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +223,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode)
|
redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrRedeemCodeNotFound) {
|
if errors.Is(err, ErrRedeemCodeNotFound) {
|
||||||
return nil
|
return nil
|
||||||
@@ -234,12 +237,115 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in
|
|||||||
redeemCode.Status = StatusUnused
|
redeemCode.Status = StatusUnused
|
||||||
redeemCode.UsedBy = nil
|
redeemCode.UsedBy = nil
|
||||||
redeemCode.UsedAt = nil
|
redeemCode.UsedAt = nil
|
||||||
if err := s.redeemRepo.Update(ctx, redeemCode); err != nil {
|
if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
|
||||||
return fmt.Errorf("restore invitation code: %w", err)
|
return fmt.Errorf("restore invitation code: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
|
||||||
|
if s == nil || s.entClient == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||||
|
return tx.Client()
|
||||||
|
}
|
||||||
|
return s.entClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
|
||||||
|
if client := s.oauthEmailFlowClient(ctx); client != nil {
|
||||||
|
entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if dbent.IsNotFound(err) {
|
||||||
|
return nil, ErrRedeemCodeNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &RedeemCode{
|
||||||
|
ID: entity.ID,
|
||||||
|
Code: entity.Code,
|
||||||
|
Type: entity.Type,
|
||||||
|
Value: entity.Value,
|
||||||
|
Status: entity.Status,
|
||||||
|
UsedBy: entity.UsedBy,
|
||||||
|
UsedAt: entity.UsedAt,
|
||||||
|
Notes: oauthEmailFlowStringValue(entity.Notes),
|
||||||
|
CreatedAt: entity.CreatedAt,
|
||||||
|
GroupID: entity.GroupID,
|
||||||
|
ValidityDays: entity.ValidityDays,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return s.redeemRepo.GetByCode(ctx, invitationCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
|
||||||
|
if client := s.oauthEmailFlowClient(ctx); client != nil {
|
||||||
|
affected, err := client.RedeemCode.Update().
|
||||||
|
Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
|
||||||
|
SetStatus(StatusUsed).
|
||||||
|
SetUsedBy(userID).
|
||||||
|
SetUsedAt(time.Now().UTC()).
|
||||||
|
Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if affected == 0 {
|
||||||
|
return ErrRedeemCodeUsed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.redeemRepo.Use(ctx, invitationID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
|
||||||
|
if code == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if client := s.oauthEmailFlowClient(ctx); client != nil {
|
||||||
|
update := client.RedeemCode.UpdateOneID(code.ID).
|
||||||
|
SetCode(code.Code).
|
||||||
|
SetType(code.Type).
|
||||||
|
SetValue(code.Value).
|
||||||
|
SetStatus(code.Status).
|
||||||
|
SetNotes(code.Notes).
|
||||||
|
SetValidityDays(code.ValidityDays)
|
||||||
|
if code.UsedBy != nil {
|
||||||
|
update = update.SetUsedBy(*code.UsedBy)
|
||||||
|
} else {
|
||||||
|
update = update.ClearUsedBy()
|
||||||
|
}
|
||||||
|
if code.UsedAt != nil {
|
||||||
|
update = update.SetUsedAt(*code.UsedAt)
|
||||||
|
} else {
|
||||||
|
update = update.ClearUsedAt()
|
||||||
|
}
|
||||||
|
if code.GroupID != nil {
|
||||||
|
update = update.SetGroupID(*code.GroupID)
|
||||||
|
} else {
|
||||||
|
update = update.ClearGroupID()
|
||||||
|
}
|
||||||
|
_, err := update.Save(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.redeemRepo.Update(ctx, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
|
||||||
|
client := s.oauthEmailFlowClient(ctx)
|
||||||
|
if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func oauthEmailFlowStringValue(value *string) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *value
|
||||||
|
}
|
||||||
|
|
||||||
// ValidatePasswordCredentials checks the local password without completing the
|
// ValidatePasswordCredentials checks the local password without completing the
|
||||||
// login flow. This is used by pending third-party account adoption flows before
|
// login flow. This is used by pending third-party account adoption flows before
|
||||||
// the external identity has been bound.
|
// the external identity has been bound.
|
||||||
@@ -269,7 +375,7 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa
|
|||||||
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
|
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
|
||||||
if s != nil && s.userRepo != nil && userID > 0 {
|
if s != nil && s.userRepo != nil && userID > 0 {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err == nil {
|
if err == nil && user != nil && !isReservedEmail(user.Email) {
|
||||||
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
|
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
|
|||||||
}
|
}
|
||||||
|
|
||||||
return UserIdentitySummarySet{
|
return UserIdentitySummarySet{
|
||||||
Email: s.buildEmailIdentitySummary(user),
|
Email: s.buildEmailIdentitySummary(user, records),
|
||||||
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
|
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
|
||||||
OIDC: s.buildProviderIdentitySummary("oidc", records),
|
OIDC: s.buildProviderIdentitySummary("oidc", records),
|
||||||
WeChat: s.buildProviderIdentitySummary("wechat", records),
|
WeChat: s.buildProviderIdentitySummary("wechat", records),
|
||||||
@@ -497,7 +497,7 @@ func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
|
|||||||
return nil, "", ErrAvatarTooLarge
|
return nil, "", ErrAvatarTooLarge
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
|
func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
|
||||||
summary := UserIdentitySummary{
|
summary := UserIdentitySummary{
|
||||||
Provider: "email",
|
Provider: "email",
|
||||||
CanBind: false,
|
CanBind: false,
|
||||||
@@ -508,11 +508,34 @@ func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary
|
|||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filtered := filterUserAuthIdentities(records, "email")
|
||||||
|
if len(filtered) > 0 {
|
||||||
|
primary := selectPrimaryUserAuthIdentity(filtered)
|
||||||
|
email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email"))
|
||||||
|
if email == "" {
|
||||||
|
email = strings.TrimSpace(primary.ProviderSubject)
|
||||||
|
}
|
||||||
|
if email == "" || isReservedEmail(email) {
|
||||||
|
email = strings.TrimSpace(user.Email)
|
||||||
|
}
|
||||||
|
if email == "" || isReservedEmail(email) {
|
||||||
|
email = strings.TrimSpace(primary.ProviderKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary.Bound = true
|
||||||
|
summary.BoundCount = len(filtered)
|
||||||
|
summary.DisplayName = email
|
||||||
|
summary.SubjectHint = maskEmailIdentity(email)
|
||||||
|
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
|
||||||
|
summary.VerifiedAt = primary.VerifiedAt
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatibility fallback for legacy normal-email users that predate auth_identities backfill.
|
||||||
email := strings.TrimSpace(user.Email)
|
email := strings.TrimSpace(user.Email)
|
||||||
if email == "" || isReservedEmail(email) {
|
if email == "" || isReservedEmail(email) {
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
summary.Bound = true
|
summary.Bound = true
|
||||||
summary.BoundCount = 1
|
summary.BoundCount = 1
|
||||||
summary.DisplayName = email
|
summary.DisplayName = email
|
||||||
|
|||||||
@@ -208,6 +208,12 @@ export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
|
|||||||
|
|
||||||
export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {}
|
export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {}
|
||||||
|
|
||||||
|
export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
|
||||||
|
auth_result?: string
|
||||||
|
provider?: string
|
||||||
|
redirect?: string
|
||||||
|
}
|
||||||
|
|
||||||
export type OAuthCompletionKind = 'login' | 'bind'
|
export type OAuthCompletionKind = 'login' | 'bind'
|
||||||
|
|
||||||
export interface OAuthAdoptionDecision {
|
export interface OAuthAdoptionDecision {
|
||||||
@@ -451,8 +457,8 @@ export async function sendVerifyCode(
|
|||||||
|
|
||||||
export async function sendPendingOAuthVerifyCode(
|
export async function sendPendingOAuthVerifyCode(
|
||||||
request: SendVerifyCodeRequest
|
request: SendVerifyCodeRequest
|
||||||
): Promise<SendVerifyCodeResponse> {
|
): Promise<PendingOAuthSendVerifyCodeResponse> {
|
||||||
const { data } = await apiClient.post<SendVerifyCodeResponse>(
|
const { data } = await apiClient.post<PendingOAuthSendVerifyCodeResponse>(
|
||||||
'/auth/oauth/pending/send-verify-code',
|
'/auth/oauth/pending/send-verify-code',
|
||||||
request
|
request
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -209,7 +209,12 @@ function getBindingStatus(provider: UserAuthProvider): boolean {
|
|||||||
|
|
||||||
function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean {
|
function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean {
|
||||||
if (provider === 'email') {
|
if (provider === 'email') {
|
||||||
return typeof user?.email_bound === 'boolean' ? user.email_bound : Boolean(user?.email)
|
if (typeof user?.email_bound === 'boolean') {
|
||||||
|
return user.email_bound
|
||||||
|
}
|
||||||
|
const nested = user?.auth_bindings?.email ?? user?.identity_bindings?.email
|
||||||
|
const normalized = normalizeBindingStatus(nested)
|
||||||
|
return normalized ?? false
|
||||||
}
|
}
|
||||||
|
|
||||||
const directFlag = user?.[`${provider}_bound` as keyof User]
|
const directFlag = user?.[`${provider}_bound` as keyof User]
|
||||||
|
|||||||
@@ -301,4 +301,27 @@ describe('ProfileIdentityBindingsSection', () => {
|
|||||||
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
|
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound')
|
||||||
expect(authStore.user?.email).toBe('bound@example.com')
|
expect(authStore.user?.email).toBe('bound@example.com')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('keeps the email binding form visible when the user still lacks an email identity', () => {
|
||||||
|
const wrapper = mount(ProfileIdentityBindingsSection, {
|
||||||
|
global: {
|
||||||
|
plugins: [pinia],
|
||||||
|
},
|
||||||
|
props: {
|
||||||
|
user: createUser({
|
||||||
|
email: 'legacy@example.com',
|
||||||
|
email_bound: false,
|
||||||
|
auth_bindings: {
|
||||||
|
email: { bound: false },
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
linuxdoEnabled: false,
|
||||||
|
oidcEnabled: false,
|
||||||
|
wechatEnabled: false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound')
|
||||||
|
expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -179,6 +179,8 @@ import { useAuthStore, useAppStore } from '@/stores'
|
|||||||
import {
|
import {
|
||||||
persistOAuthTokenContext,
|
persistOAuthTokenContext,
|
||||||
getPublicSettings,
|
getPublicSettings,
|
||||||
|
isOAuthLoginCompletion,
|
||||||
|
type PendingOAuthSendVerifyCodeResponse,
|
||||||
sendPendingOAuthVerifyCode,
|
sendPendingOAuthVerifyCode,
|
||||||
sendVerifyCode,
|
sendVerifyCode,
|
||||||
} from '@/api/auth'
|
} from '@/api/auth'
|
||||||
@@ -216,10 +218,13 @@ type PendingAuthSessionSummary = {
|
|||||||
redirect?: string
|
redirect?: string
|
||||||
}
|
}
|
||||||
type PendingOAuthCreateAccountResponse = {
|
type PendingOAuthCreateAccountResponse = {
|
||||||
|
auth_result?: string
|
||||||
access_token: string
|
access_token: string
|
||||||
refresh_token?: string
|
refresh_token?: string
|
||||||
expires_in?: number
|
expires_in?: number
|
||||||
token_type?: string
|
token_type?: string
|
||||||
|
provider?: string
|
||||||
|
redirect?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
const email = ref<string>('')
|
const email = ref<string>('')
|
||||||
@@ -353,6 +358,46 @@ function onTurnstileError(): void {
|
|||||||
errors.value.turnstile = t('auth.turnstileFailed')
|
errors.value.turnstile = t('auth.turnstileFailed')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function isPendingOAuthFlow(): boolean {
|
||||||
|
return Boolean(pendingProvider.value.trim())
|
||||||
|
}
|
||||||
|
|
||||||
|
function shouldBypassRegistrationEmailPolicy(): boolean {
|
||||||
|
return isPendingOAuthFlow() || Boolean(pendingAuthToken.value.trim())
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolvePendingOAuthCallbackRoute(provider: string): string {
|
||||||
|
switch (provider.trim().toLowerCase()) {
|
||||||
|
case 'linuxdo':
|
||||||
|
return '/auth/linuxdo/callback'
|
||||||
|
case 'oidc':
|
||||||
|
return '/auth/oidc/callback'
|
||||||
|
case 'wechat':
|
||||||
|
return '/auth/wechat/callback'
|
||||||
|
default:
|
||||||
|
return '/auth/callback'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function isPendingOAuthSessionResponse(data: PendingOAuthCreateAccountResponse): boolean {
|
||||||
|
return data.auth_result === 'pending_session'
|
||||||
|
}
|
||||||
|
|
||||||
|
function getPendingOAuthSendCodeSessionResponse(
|
||||||
|
data: PendingOAuthSendVerifyCodeResponse,
|
||||||
|
): PendingOAuthSendVerifyCodeResponse | null {
|
||||||
|
return data.auth_result === 'pending_session' ? data : null
|
||||||
|
}
|
||||||
|
|
||||||
|
function persistPendingOAuthSession(provider: string, redirect?: string): void {
|
||||||
|
authStore.setPendingAuthSession({
|
||||||
|
token: pendingAuthToken.value,
|
||||||
|
token_field: pendingAuthTokenField.value,
|
||||||
|
provider: provider.trim() || pendingProvider.value.trim(),
|
||||||
|
redirect: redirect || pendingRedirect.value || undefined,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== Send Code ====================
|
// ==================== Send Code ====================
|
||||||
|
|
||||||
async function sendCode(): Promise<void> {
|
async function sendCode(): Promise<void> {
|
||||||
@@ -360,7 +405,7 @@ async function sendCode(): Promise<void> {
|
|||||||
errorMessage.value = ''
|
errorMessage.value = ''
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
|
if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
|
||||||
errorMessage.value = buildEmailSuffixNotAllowedMessage()
|
errorMessage.value = buildEmailSuffixNotAllowedMessage()
|
||||||
appStore.showError(errorMessage.value)
|
appStore.showError(errorMessage.value)
|
||||||
return
|
return
|
||||||
@@ -372,10 +417,25 @@ async function sendCode(): Promise<void> {
|
|||||||
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
|
// 优先使用重发时新获取的 token(因为初始 token 可能已被使用)
|
||||||
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
|
turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined
|
||||||
} as Parameters<typeof sendVerifyCode>[0]
|
} as Parameters<typeof sendVerifyCode>[0]
|
||||||
const response = pendingAuthToken.value
|
const response = isPendingOAuthFlow()
|
||||||
? await sendPendingOAuthVerifyCode(requestPayload)
|
? await sendPendingOAuthVerifyCode(requestPayload)
|
||||||
: await sendVerifyCode(requestPayload)
|
: await sendVerifyCode(requestPayload)
|
||||||
|
|
||||||
|
const pendingSendCodeSession = isPendingOAuthFlow()
|
||||||
|
? getPendingOAuthSendCodeSessionResponse(response as PendingOAuthSendVerifyCodeResponse)
|
||||||
|
: null
|
||||||
|
if (pendingSendCodeSession) {
|
||||||
|
sessionStorage.removeItem('register_data')
|
||||||
|
persistPendingOAuthSession(
|
||||||
|
pendingSendCodeSession.provider || pendingProvider.value,
|
||||||
|
pendingSendCodeSession.redirect,
|
||||||
|
)
|
||||||
|
await router.push(
|
||||||
|
resolvePendingOAuthCallbackRoute(pendingSendCodeSession.provider || pendingProvider.value),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
codeSent.value = true
|
codeSent.value = true
|
||||||
startCountdown(response.countdown)
|
startCountdown(response.countdown)
|
||||||
|
|
||||||
@@ -438,13 +498,13 @@ async function handleVerify(): Promise<void> {
|
|||||||
isLoading.value = true
|
isLoading.value = true
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
|
if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) {
|
||||||
errorMessage.value = buildEmailSuffixNotAllowedMessage()
|
errorMessage.value = buildEmailSuffixNotAllowedMessage()
|
||||||
appStore.showError(errorMessage.value)
|
appStore.showError(errorMessage.value)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pendingProvider.value) {
|
if (isPendingOAuthFlow()) {
|
||||||
const { data } = await apiClient.post<PendingOAuthCreateAccountResponse>(
|
const { data } = await apiClient.post<PendingOAuthCreateAccountResponse>(
|
||||||
'/auth/oauth/pending/create-account',
|
'/auth/oauth/pending/create-account',
|
||||||
{
|
{
|
||||||
@@ -456,6 +516,16 @@ async function handleVerify(): Promise<void> {
|
|||||||
adopt_avatar: pendingAdoptionDecision.value?.adoptAvatar
|
adopt_avatar: pendingAdoptionDecision.value?.adoptAvatar
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if (isPendingOAuthSessionResponse(data)) {
|
||||||
|
sessionStorage.removeItem('register_data')
|
||||||
|
persistPendingOAuthSession(data.provider || pendingProvider.value, data.redirect)
|
||||||
|
await router.push(resolvePendingOAuthCallbackRoute(data.provider || pendingProvider.value))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (!isOAuthLoginCompletion(data)) {
|
||||||
|
throw new Error(t('auth.verifyFailed'))
|
||||||
|
}
|
||||||
|
|
||||||
persistOAuthTokenContext(data)
|
persistOAuthTokenContext(data)
|
||||||
await authStore.setToken(data.access_token)
|
await authStore.setToken(data.access_token)
|
||||||
authStore.clearPendingAuthSession?.()
|
authStore.clearPendingAuthSession?.()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ const {
|
|||||||
showErrorMock,
|
showErrorMock,
|
||||||
registerMock,
|
registerMock,
|
||||||
setTokenMock,
|
setTokenMock,
|
||||||
|
setPendingAuthSessionMock,
|
||||||
clearPendingAuthSessionMock,
|
clearPendingAuthSessionMock,
|
||||||
getPublicSettingsMock,
|
getPublicSettingsMock,
|
||||||
sendVerifyCodeMock,
|
sendVerifyCodeMock,
|
||||||
@@ -21,6 +22,7 @@ const {
|
|||||||
showErrorMock: vi.fn(),
|
showErrorMock: vi.fn(),
|
||||||
registerMock: vi.fn(),
|
registerMock: vi.fn(),
|
||||||
setTokenMock: vi.fn(),
|
setTokenMock: vi.fn(),
|
||||||
|
setPendingAuthSessionMock: vi.fn(),
|
||||||
clearPendingAuthSessionMock: vi.fn(),
|
clearPendingAuthSessionMock: vi.fn(),
|
||||||
getPublicSettingsMock: vi.fn(),
|
getPublicSettingsMock: vi.fn(),
|
||||||
sendVerifyCodeMock: vi.fn(),
|
sendVerifyCodeMock: vi.fn(),
|
||||||
@@ -68,6 +70,7 @@ vi.mock('@/stores', () => ({
|
|||||||
pendingAuthSession: authStoreState.pendingAuthSession,
|
pendingAuthSession: authStoreState.pendingAuthSession,
|
||||||
register: (...args: any[]) => registerMock(...args),
|
register: (...args: any[]) => registerMock(...args),
|
||||||
setToken: (...args: any[]) => setTokenMock(...args),
|
setToken: (...args: any[]) => setTokenMock(...args),
|
||||||
|
setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args),
|
||||||
clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args),
|
clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args),
|
||||||
}),
|
}),
|
||||||
useAppStore: () => ({
|
useAppStore: () => ({
|
||||||
@@ -100,6 +103,7 @@ describe('EmailVerifyView', () => {
|
|||||||
showErrorMock.mockReset()
|
showErrorMock.mockReset()
|
||||||
registerMock.mockReset()
|
registerMock.mockReset()
|
||||||
setTokenMock.mockReset()
|
setTokenMock.mockReset()
|
||||||
|
setPendingAuthSessionMock.mockReset()
|
||||||
clearPendingAuthSessionMock.mockReset()
|
clearPendingAuthSessionMock.mockReset()
|
||||||
getPublicSettingsMock.mockReset()
|
getPublicSettingsMock.mockReset()
|
||||||
sendVerifyCodeMock.mockReset()
|
sendVerifyCodeMock.mockReset()
|
||||||
@@ -196,6 +200,97 @@ describe('EmailVerifyView', () => {
|
|||||||
expect(showErrorMock).not.toHaveBeenCalled()
|
expect(showErrorMock).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => {
|
||||||
|
authStoreState.pendingAuthSession = {
|
||||||
|
token: '',
|
||||||
|
token_field: 'pending_oauth_token',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile',
|
||||||
|
}
|
||||||
|
getPublicSettingsMock.mockResolvedValue({
|
||||||
|
turnstile_enabled: false,
|
||||||
|
turnstile_site_key: '',
|
||||||
|
site_name: 'Sub2API',
|
||||||
|
registration_email_suffix_whitelist: ['allowed.com'],
|
||||||
|
})
|
||||||
|
sessionStorage.setItem(
|
||||||
|
'register_data',
|
||||||
|
JSON.stringify({
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
password: 'secret-123',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
mount(EmailVerifyView, {
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
AuthLayout: { template: '<div><slot /><slot name="footer" /></div>' },
|
||||||
|
Icon: true,
|
||||||
|
TurnstileWidget: true,
|
||||||
|
transition: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
pending_oauth_token: undefined,
|
||||||
|
})
|
||||||
|
expect(sendVerifyCodeMock).not.toHaveBeenCalled()
|
||||||
|
expect(showErrorMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => {
|
||||||
|
authStoreState.pendingAuthSession = {
|
||||||
|
token: '',
|
||||||
|
token_field: 'pending_oauth_token',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
}
|
||||||
|
getPublicSettingsMock.mockResolvedValue({
|
||||||
|
turnstile_enabled: false,
|
||||||
|
turnstile_site_key: '',
|
||||||
|
site_name: 'Sub2API',
|
||||||
|
registration_email_suffix_whitelist: ['allowed.com'],
|
||||||
|
})
|
||||||
|
sendPendingOAuthVerifyCodeMock.mockResolvedValue({
|
||||||
|
auth_result: 'pending_session',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
})
|
||||||
|
sessionStorage.setItem(
|
||||||
|
'register_data',
|
||||||
|
JSON.stringify({
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
password: 'secret-123',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
mount(EmailVerifyView, {
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
AuthLayout: { template: '<div><slot /><slot name="footer" /></div>' },
|
||||||
|
Icon: true,
|
||||||
|
TurnstileWidget: true,
|
||||||
|
transition: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
|
||||||
|
token: '',
|
||||||
|
token_field: 'pending_oauth_token',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
})
|
||||||
|
expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
|
||||||
|
expect(showErrorMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
|
it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => {
|
||||||
authStoreState.pendingAuthSession = {
|
authStoreState.pendingAuthSession = {
|
||||||
token: 'pending-token-1',
|
token: 'pending-token-1',
|
||||||
@@ -252,6 +347,70 @@ describe('EmailVerifyView', () => {
|
|||||||
expect(registerMock).not.toHaveBeenCalled()
|
expect(registerMock).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => {
|
||||||
|
authStoreState.pendingAuthSession = {
|
||||||
|
token: '',
|
||||||
|
token_field: 'pending_oauth_token',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
}
|
||||||
|
getPublicSettingsMock.mockResolvedValue({
|
||||||
|
turnstile_enabled: false,
|
||||||
|
turnstile_site_key: '',
|
||||||
|
site_name: 'Sub2API',
|
||||||
|
registration_email_suffix_whitelist: ['allowed.com'],
|
||||||
|
})
|
||||||
|
sessionStorage.setItem(
|
||||||
|
'register_data',
|
||||||
|
JSON.stringify({
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
password: 'secret-123',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
apiClientPostMock.mockResolvedValue({
|
||||||
|
data: {
|
||||||
|
auth_result: 'pending_session',
|
||||||
|
provider: 'oidc',
|
||||||
|
step: 'bind_login_required',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
const wrapper = mount(EmailVerifyView, {
|
||||||
|
global: {
|
||||||
|
stubs: {
|
||||||
|
AuthLayout: { template: '<div><slot /><slot name="footer" /></div>' },
|
||||||
|
Icon: true,
|
||||||
|
TurnstileWidget: true,
|
||||||
|
transition: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
await flushPromises()
|
||||||
|
await wrapper.get('#code').setValue('123456')
|
||||||
|
await wrapper.get('form').trigger('submit.prevent')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', {
|
||||||
|
email: 'fresh@example.com',
|
||||||
|
password: 'secret-123',
|
||||||
|
verify_code: '123456',
|
||||||
|
})
|
||||||
|
expect(setPendingAuthSessionMock).toHaveBeenCalledWith({
|
||||||
|
token: '',
|
||||||
|
token_field: 'pending_oauth_token',
|
||||||
|
provider: 'oidc',
|
||||||
|
redirect: '/profile/security',
|
||||||
|
})
|
||||||
|
expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback')
|
||||||
|
expect(setTokenMock).not.toHaveBeenCalled()
|
||||||
|
expect(persistOAuthTokenContextMock).not.toHaveBeenCalled()
|
||||||
|
expect(clearPendingAuthSessionMock).not.toHaveBeenCalled()
|
||||||
|
expect(showSuccessMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
it('keeps the normal email registration flow unchanged', async () => {
|
it('keeps the normal email registration flow unchanged', async () => {
|
||||||
sessionStorage.setItem(
|
sessionStorage.setItem(
|
||||||
'register_data',
|
'register_data',
|
||||||
|
|||||||
Reference in New Issue
Block a user