diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 1d3b113f..8a3006f3 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "strings" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" @@ -35,6 +36,8 @@ const ( oauthCompletionResponseKey = "completion_response" ) +var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error + type oauthPendingSessionPayload struct { Intent string Identity service.PendingAuthIdentityKey @@ -481,6 +484,26 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { return } + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { + session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, oauthAdoptionDecisionRequest{}) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } else if err != nil && !errors.Is(err, service.ErrUserNotFound) { + response.ErrorFrom(c, err) + return + } + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) if err != nil { response.ErrorFrom(c, err) @@ -946,11 +969,46 @@ func applyPendingOAuthBinding( return nil } + if tx := dbent.TxFromContext(ctx); tx != nil { + return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults) + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil { + return err + } + return tx.Commit() +} + +func applyPendingOAuthBindingTx( + ctx context.Context, + tx *dbent.Tx, + authService *service.AuthService, + userService *service.UserService, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, + forceBind bool, + applyFirstBindDefaults bool, +) error { + if tx == nil || session == nil { + return nil + } + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { + return nil + } + targetUserID := int64(0) if overrideUserID != nil && *overrideUserID > 0 { targetUserID = *overrideUserID } else { - resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session) if err != nil { return err } @@ -974,22 +1032,15 @@ func applyPendingOAuthBinding( } } - tx, err := client.Tx(ctx) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - txCtx := dbent.NewTxContext(ctx, tx) - if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if err := tx.Client().User.UpdateOneID(targetUserID). SetUsername(adoptedDisplayName). - Exec(txCtx); err != nil { + Exec(ctx); err != nil { return err } } - identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID) + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) if err != nil { return err } @@ -1009,31 +1060,71 @@ func applyPendingOAuthBinding( if issuer := oauthIdentityIssuer(session); issuer != nil { updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) } - if _, err := updateIdentity.Save(txCtx); err != nil { + if _, err := updateIdentity.Save(ctx); err != nil { return err } if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). SetIdentityID(identity.ID). - Save(txCtx); err != nil { + Save(ctx); err != nil { return err } } if applyFirstBindDefaults && authService != nil { - if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil { + if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil { return err } } if shouldAdoptAvatar && userService != nil { - if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil { + if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil { return err } } - return tx.Commit() + return nil +} + +func consumePendingOAuthBrowserSessionTx( + ctx context.Context, + tx *dbent.Tx, + session *dbent.PendingAuthSession, +) error { + if tx == nil || session == nil { + return service.ErrPendingAuthSessionNotFound + } + + storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID) + if err != nil { + if dbent.IsNotFound(err) { + return service.ErrPendingAuthSessionNotFound + } + return err + } + + now := time.Now().UTC() + if storedSession.ConsumedAt != nil { + return service.ErrPendingAuthSessionConsumed + } + if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) { + return service.ErrPendingAuthSessionExpired + } + if strings.TrimSpace(storedSession.BrowserSessionKey) != "" && + strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return service.ErrPendingAuthBrowserMismatch + } + + if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx); err != nil { + return err + } + + return nil } func applyPendingOAuthAdoption( @@ -1256,7 +1347,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) return } - pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) if err != nil { response.ErrorFrom(c, err) return @@ -1341,7 +1432,20 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } - if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { + + tx, err := client.Tx(c.Request.Context()) + if err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(c.Request.Context(), tx) + + if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { + _ = tx.Rollback() if rollbackCreatedUser(err) { return } @@ -1350,11 +1454,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } if err := h.authService.FinalizeOAuthEmailAccount( - c.Request.Context(), + txCtx, user, strings.TrimSpace(req.InvitationCode), strings.TrimSpace(session.ProviderType), ); err != nil { + _ = tx.Rollback() if rollbackCreatedUser(err) { return } @@ -1362,7 +1467,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) return } - if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil { + _ = tx.Rollback() if rollbackCreatedUser(err) { return } @@ -1371,6 +1477,25 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) return } + if pendingOAuthCreateAccountPreCommitHook != nil { + if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil { + _ = tx.Rollback() + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + } + + if err := tx.Commit(); err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearCookies() writeOAuthTokenPairResponse(c, tokenPair) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 1013a082..008c9da2 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -903,6 +903,63 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } +func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-send-code-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-send-code-123"). + SetBrowserSessionKey("existing-email-send-code-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "email_required", + }, + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")}) + ginCtx.Request = req + + handler.SendPendingOAuthVerifyCode(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, "bind_login_required", payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) { handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ emailVerifyEnabled: true, @@ -1032,6 +1089,78 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T require.Nil(t, storedSession.ConsumedAt) } +func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + userRepoOptions: oauthPendingFlowUserRepoOptions{ + rejectDeleteWhileAuthIdentityExists: true, + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-finalize-failure-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-finalize-failure-123"). + SetBrowserSessionKey("create-account-finalize-failure-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error { + return errors.New("forced post-bind failure") + } + t.Cleanup(func() { + pendingOAuthCreateAccountPreCommitHook = nil + }) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -1618,7 +1747,6 @@ type oauthPendingFlowTestHandlerOptions struct { defaultSubAssigner service.DefaultSubscriptionAssigner totpCache service.TotpCache totpEncryptor service.SecretEncryptor - redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository userRepoOptions oauthPendingFlowUserRepoOptions } @@ -1685,13 +1813,7 @@ CREATE TABLE IF NOT EXISTS user_avatars ( client: client, options: options.userRepoOptions, } - redeemRepo := service.RedeemCodeRepository(nil) - if options.redeemRepoFactory != nil { - redeemRepo = options.redeemRepoFactory(client) - } - if redeemRepo == nil { - redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client} - } + redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client} var emailService *service.EmailService if options.emailCache != nil { emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ @@ -2011,14 +2133,6 @@ func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Contex panic("unexpected SumPositiveBalanceByUser call") } -type oauthPendingFlowFailingUseRedeemRepo struct { - *oauthPendingFlowRedeemCodeRepo -} - -func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error { - return errors.New("forced invitation use failure") -} - func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { t.Helper() @@ -2093,7 +2207,7 @@ func countProviderGrantRecords( } type oauthPendingFlowUserRepo struct { - client *dbent.Client + client *dbent.Client options oauthPendingFlowUserRepoOptions } diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index ce25222c..ea558ae2 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -7,6 +7,9 @@ import ( "net/mail" "strings" "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" ) func normalizeOAuthSignupSource(signupSource string) string { @@ -50,7 +53,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { return nil, nil } - if s.redeemRepo == nil { + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { return nil, ErrServiceUnavailable } @@ -59,7 +62,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i return nil, ErrInvitationCodeRequired } - redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) if err != nil { return nil, ErrInvitationCodeInvalid } @@ -181,12 +184,12 @@ func (s *AuthService) FinalizeOAuthEmailAccount( return err } if invitationRedeemCode != nil { - if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil { return ErrInvitationCodeInvalid } } - s.postAuthUserBootstrap(ctx, user, signupSource, false) + s.updateOAuthSignupSource(ctx, user.ID, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") return nil @@ -211,7 +214,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { return nil } - if s.redeemRepo == nil { + if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil { return ErrServiceUnavailable } @@ -220,7 +223,7 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in return nil } - redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode) if err != nil { if errors.Is(err, ErrRedeemCodeNotFound) { return nil @@ -234,12 +237,115 @@ func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, in redeemCode.Status = StatusUnused redeemCode.UsedBy = nil redeemCode.UsedAt = nil - if err := s.redeemRepo.Update(ctx, redeemCode); err != nil { + if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil { return fmt.Errorf("restore invitation code: %w", err) } return nil } +func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client { + if s == nil || s.entClient == nil { + return nil + } + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return s.entClient +} + +func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if client := s.oauthEmailFlowClient(ctx); client != nil { + entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrRedeemCodeNotFound + } + return nil, err + } + return &RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: oauthEmailFlowStringValue(entity.Notes), + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil + } + return s.redeemRepo.GetByCode(ctx, invitationCode) +} + +func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error { + if client := s.oauthEmailFlowClient(ctx); client != nil { + affected, err := client.RedeemCode.Update(). + Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)). + SetStatus(StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return ErrRedeemCodeUsed + } + return nil + } + return s.redeemRepo.Use(ctx, invitationID, userID) +} + +func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + if client := s.oauthEmailFlowClient(ctx); client != nil { + update := client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). + SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err + } + return s.redeemRepo.Update(ctx, code) +} + +func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) { + client := s.oauthEmailFlowClient(ctx) + if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" { + return + } + _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx) +} + +func oauthEmailFlowStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} + // ValidatePasswordCredentials checks the local password without completing the // login flow. This is used by pending third-party account adoption flows before // the external identity has been bound. @@ -269,7 +375,7 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { if s != nil && s.userRepo != nil && userID > 0 { user, err := s.userRepo.GetByID(ctx, userID) - if err == nil { + if err == nil && user != nil && !isReservedEmail(user.Email) { s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) } } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 9c7d4747..e6053984 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -240,7 +240,7 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in } return UserIdentitySummarySet{ - Email: s.buildEmailIdentitySummary(user), + Email: s.buildEmailIdentitySummary(user, records), LinuxDo: s.buildProviderIdentitySummary("linuxdo", records), OIDC: s.buildProviderIdentitySummary("oidc", records), WeChat: s.buildProviderIdentitySummary("wechat", records), @@ -497,7 +497,7 @@ func compressInlineAvatar(decoded []byte) ([]byte, string, error) { return nil, "", ErrAvatarTooLarge } -func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary { +func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary { summary := UserIdentitySummary{ Provider: "email", CanBind: false, @@ -508,11 +508,34 @@ func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary return summary } + filtered := filterUserAuthIdentities(records, "email") + if len(filtered) > 0 { + primary := selectPrimaryUserAuthIdentity(filtered) + email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email")) + if email == "" { + email = strings.TrimSpace(primary.ProviderSubject) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(user.Email) + } + if email == "" || isReservedEmail(email) { + email = strings.TrimSpace(primary.ProviderKey) + } + + summary.Bound = true + summary.BoundCount = len(filtered) + summary.DisplayName = email + summary.SubjectHint = maskEmailIdentity(email) + summary.ProviderKey = strings.TrimSpace(primary.ProviderKey) + summary.VerifiedAt = primary.VerifiedAt + return summary + } + + // Compatibility fallback for legacy normal-email users that predate auth_identities backfill. email := strings.TrimSpace(user.Email) if email == "" || isReservedEmail(email) { return summary } - summary.Bound = true summary.BoundCount = 1 summary.DisplayName = email diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 0f768018..89964c3c 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -208,6 +208,12 @@ export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {} +export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse { + auth_result?: string + provider?: string + redirect?: string +} + export type OAuthCompletionKind = 'login' | 'bind' export interface OAuthAdoptionDecision { @@ -451,8 +457,8 @@ export async function sendVerifyCode( export async function sendPendingOAuthVerifyCode( request: SendVerifyCodeRequest -): Promise { - const { data } = await apiClient.post( +): Promise { + const { data } = await apiClient.post( '/auth/oauth/pending/send-verify-code', request ) diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue index ccc1cbd0..653b4e33 100644 --- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue +++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue @@ -209,7 +209,12 @@ function getBindingStatus(provider: UserAuthProvider): boolean { function getBindingStatusForUser(user: User | null | undefined, provider: UserAuthProvider): boolean { if (provider === 'email') { - return typeof user?.email_bound === 'boolean' ? user.email_bound : Boolean(user?.email) + if (typeof user?.email_bound === 'boolean') { + return user.email_bound + } + const nested = user?.auth_bindings?.email ?? user?.identity_bindings?.email + const normalized = normalizeBindingStatus(nested) + return normalized ?? false } const directFlag = user?.[`${provider}_bound` as keyof User] diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts index ec4aed5d..c07acf18 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -301,4 +301,27 @@ describe('ProfileIdentityBindingsSection', () => { expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound') expect(authStore.user?.email).toBe('bound@example.com') }) + + it('keeps the email binding form visible when the user still lacks an email identity', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: createUser({ + email: 'legacy@example.com', + email_bound: false, + auth_bindings: { + email: { bound: false }, + }, + }), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: false, + }, + }) + + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Not bound') + expect(wrapper.get('[data-testid="profile-binding-email-input"]').exists()).toBe(true) + }) }) diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue index d7bf6b7a..01829765 100644 --- a/frontend/src/views/auth/EmailVerifyView.vue +++ b/frontend/src/views/auth/EmailVerifyView.vue @@ -179,6 +179,8 @@ import { useAuthStore, useAppStore } from '@/stores' import { persistOAuthTokenContext, getPublicSettings, + isOAuthLoginCompletion, + type PendingOAuthSendVerifyCodeResponse, sendPendingOAuthVerifyCode, sendVerifyCode, } from '@/api/auth' @@ -216,10 +218,13 @@ type PendingAuthSessionSummary = { redirect?: string } type PendingOAuthCreateAccountResponse = { + auth_result?: string access_token: string refresh_token?: string expires_in?: number token_type?: string + provider?: string + redirect?: string } const email = ref('') @@ -353,6 +358,46 @@ function onTurnstileError(): void { errors.value.turnstile = t('auth.turnstileFailed') } +function isPendingOAuthFlow(): boolean { + return Boolean(pendingProvider.value.trim()) +} + +function shouldBypassRegistrationEmailPolicy(): boolean { + return isPendingOAuthFlow() || Boolean(pendingAuthToken.value.trim()) +} + +function resolvePendingOAuthCallbackRoute(provider: string): string { + switch (provider.trim().toLowerCase()) { + case 'linuxdo': + return '/auth/linuxdo/callback' + case 'oidc': + return '/auth/oidc/callback' + case 'wechat': + return '/auth/wechat/callback' + default: + return '/auth/callback' + } +} + +function isPendingOAuthSessionResponse(data: PendingOAuthCreateAccountResponse): boolean { + return data.auth_result === 'pending_session' +} + +function getPendingOAuthSendCodeSessionResponse( + data: PendingOAuthSendVerifyCodeResponse, +): PendingOAuthSendVerifyCodeResponse | null { + return data.auth_result === 'pending_session' ? data : null +} + +function persistPendingOAuthSession(provider: string, redirect?: string): void { + authStore.setPendingAuthSession({ + token: pendingAuthToken.value, + token_field: pendingAuthTokenField.value, + provider: provider.trim() || pendingProvider.value.trim(), + redirect: redirect || pendingRedirect.value || undefined, + }) +} + // ==================== Send Code ==================== async function sendCode(): Promise { @@ -360,7 +405,7 @@ async function sendCode(): Promise { errorMessage.value = '' try { - if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { + if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { errorMessage.value = buildEmailSuffixNotAllowedMessage() appStore.showError(errorMessage.value) return @@ -372,10 +417,25 @@ async function sendCode(): Promise { // 优先使用重发时新获取的 token(因为初始 token 可能已被使用) turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined } as Parameters[0] - const response = pendingAuthToken.value + const response = isPendingOAuthFlow() ? await sendPendingOAuthVerifyCode(requestPayload) : await sendVerifyCode(requestPayload) + const pendingSendCodeSession = isPendingOAuthFlow() + ? getPendingOAuthSendCodeSessionResponse(response as PendingOAuthSendVerifyCodeResponse) + : null + if (pendingSendCodeSession) { + sessionStorage.removeItem('register_data') + persistPendingOAuthSession( + pendingSendCodeSession.provider || pendingProvider.value, + pendingSendCodeSession.redirect, + ) + await router.push( + resolvePendingOAuthCallbackRoute(pendingSendCodeSession.provider || pendingProvider.value), + ) + return + } + codeSent.value = true startCountdown(response.countdown) @@ -438,13 +498,13 @@ async function handleVerify(): Promise { isLoading.value = true try { - if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { + if (!shouldBypassRegistrationEmailPolicy() && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { errorMessage.value = buildEmailSuffixNotAllowedMessage() appStore.showError(errorMessage.value) return } - if (pendingProvider.value) { + if (isPendingOAuthFlow()) { const { data } = await apiClient.post( '/auth/oauth/pending/create-account', { @@ -456,6 +516,16 @@ async function handleVerify(): Promise { adopt_avatar: pendingAdoptionDecision.value?.adoptAvatar } ) + if (isPendingOAuthSessionResponse(data)) { + sessionStorage.removeItem('register_data') + persistPendingOAuthSession(data.provider || pendingProvider.value, data.redirect) + await router.push(resolvePendingOAuthCallbackRoute(data.provider || pendingProvider.value)) + return + } + if (!isOAuthLoginCompletion(data)) { + throw new Error(t('auth.verifyFailed')) + } + persistOAuthTokenContext(data) await authStore.setToken(data.access_token) authStore.clearPendingAuthSession?.() diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts index c231d6e7..9f67a994 100644 --- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts +++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts @@ -8,6 +8,7 @@ const { showErrorMock, registerMock, setTokenMock, + setPendingAuthSessionMock, clearPendingAuthSessionMock, getPublicSettingsMock, sendVerifyCodeMock, @@ -21,6 +22,7 @@ const { showErrorMock: vi.fn(), registerMock: vi.fn(), setTokenMock: vi.fn(), + setPendingAuthSessionMock: vi.fn(), clearPendingAuthSessionMock: vi.fn(), getPublicSettingsMock: vi.fn(), sendVerifyCodeMock: vi.fn(), @@ -68,6 +70,7 @@ vi.mock('@/stores', () => ({ pendingAuthSession: authStoreState.pendingAuthSession, register: (...args: any[]) => registerMock(...args), setToken: (...args: any[]) => setTokenMock(...args), + setPendingAuthSession: (...args: any[]) => setPendingAuthSessionMock(...args), clearPendingAuthSession: (...args: any[]) => clearPendingAuthSessionMock(...args), }), useAppStore: () => ({ @@ -100,6 +103,7 @@ describe('EmailVerifyView', () => { showErrorMock.mockReset() registerMock.mockReset() setTokenMock.mockReset() + setPendingAuthSessionMock.mockReset() clearPendingAuthSessionMock.mockReset() getPublicSettingsMock.mockReset() sendVerifyCodeMock.mockReset() @@ -196,6 +200,97 @@ describe('EmailVerifyView', () => { expect(showErrorMock).not.toHaveBeenCalled() }) + it('uses the pending oauth verify-code endpoint when auth store only carries the pending provider', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_oauth_token: undefined, + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + expect(showErrorMock).not.toHaveBeenCalled() + }) + + it('returns to the oauth callback flow when pending send-code detects an existing account email', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ + auth_result: 'pending_session', + provider: 'oidc', + redirect: '/profile/security', + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(showErrorMock).not.toHaveBeenCalled() + }) + it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => { authStoreState.pendingAuthSession = { token: 'pending-token-1', @@ -252,6 +347,70 @@ describe('EmailVerifyView', () => { expect(registerMock).not.toHaveBeenCalled() }) + it('returns to the oauth callback flow when pending account creation becomes bind-login', async () => { + authStoreState.pendingAuthSession = { + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + apiClientPostMock.mockResolvedValue({ + data: { + auth_result: 'pending_session', + provider: 'oidc', + step: 'bind_login_required', + redirect: '/profile/security', + email: 'fresh@example.com', + }, + }) + + const wrapper = mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('#code').setValue('123456') + await wrapper.get('form').trigger('submit.prevent') + await flushPromises() + + expect(apiClientPostMock).toHaveBeenCalledWith('/auth/oauth/pending/create-account', { + email: 'fresh@example.com', + password: 'secret-123', + verify_code: '123456', + }) + expect(setPendingAuthSessionMock).toHaveBeenCalledWith({ + token: '', + token_field: 'pending_oauth_token', + provider: 'oidc', + redirect: '/profile/security', + }) + expect(pushMock).toHaveBeenCalledWith('/auth/oidc/callback') + expect(setTokenMock).not.toHaveBeenCalled() + expect(persistOAuthTokenContextMock).not.toHaveBeenCalled() + expect(clearPendingAuthSessionMock).not.toHaveBeenCalled() + expect(showSuccessMock).not.toHaveBeenCalled() + }) + it('keeps the normal email registration flow unchanged', async () => { sessionStorage.setItem( 'register_data',