From f3986501663c864594f1444440f5e36199595983 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 11:00:08 +0800 Subject: [PATCH] fix: harden oidc compat email and email bind tx --- backend/internal/handler/auth_oidc_oauth.go | 69 ++++++- .../internal/handler/auth_oidc_oauth_test.go | 121 +++++++++++++ .../internal/service/auth_email_binding.go | 169 ++++++++++++++++++ .../service/auth_service_email_bind_test.go | 71 ++++++++ 4 files changed, 424 insertions(+), 6 deletions(-) diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 5901a953..6d19e9d6 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -19,6 +19,7 @@ import ( "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" @@ -323,18 +324,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { if emailVerified == nil { emailVerified = idClaims.EmailVerified } - if cfg.RequireEmailVerified { - if emailVerified == nil || !*emailVerified { - redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") - return - } - } if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) { redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "") return } identityKey := oidcIdentityKey(issuer, subject) + compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email)) email := oidcSyntheticEmailFromIdentityKey(identityKey) username := firstNonEmpty( userInfoClaims.Username, @@ -357,6 +353,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), "suggested_avatar_url": userInfoClaims.AvatarURL, } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) if err != nil { @@ -416,6 +415,40 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } + compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if compatEmailUser != nil { + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "adopt_existing_user_by_email", + Identity: identityRef, + TargetUserID: &compatEmailUser.ID, + ResolvedEmail: compatEmailUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "bind_login_required", + "email": compatEmailUser.Email, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if cfg.RequireEmailVerified { + if emailVerified == nil || !*emailVerified { + redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") + return + } + } + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") @@ -473,6 +506,30 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { redirectToFrontendCallback(c, frontendCallback) } +func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := findUserByNormalizedEmail(ctx, client, email) + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + return userEntity, nil +} + type completeOIDCOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index ba736db2..5cd8e0ea 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -245,6 +245,127 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T require.Nil(t, completion["error"]) } +func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-compat", + PreferredUsername: "oidc_compat", + DisplayName: "OIDC Compat Display", + AvatarURL: "https://cdn.example/oidc-compat.png", + Email: "legacy@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + +func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-subject-unverified-compat", + PreferredUsername: "oidc_unverified", + DisplayName: "OIDC Unverified Compat Display", + AvatarURL: "https://cdn.example/oidc-unverified.png", + Email: "owner@example.com", + EmailVerified: false, + }) + defer cleanup() + cfg.RequireEmailVerified = true + + handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "owner@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/settings/connections", completion["redirect"]) + require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) +} + func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-invite", diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go index b999660b..58f8e647 100644 --- a/backend/internal/service/auth_email_binding.go +++ b/backend/internal/service/auth_email_binding.go @@ -6,7 +6,10 @@ import ( "fmt" "net/mail" "strings" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity( } firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) + if firstRealEmailBind && s.entClient != nil { + if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil { + return nil, err + } + return currentUser, nil + } + currentUser.Email = normalizedEmail currentUser.PasswordHash = hashedPassword if err := s.userRepo.Update(ctx, currentUser); err != nil { @@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return normalized != "" && !isReservedEmail(normalized) } + +func (s *AuthService) bindEmailIdentityWithDefaultsTx( + ctx context.Context, + currentUser *User, + email string, + hashedPassword string, +) error { + if tx := dbent.TxFromContext(ctx); tx != nil { + return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return ErrServiceUnavailable + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return ErrServiceUnavailable + } + return nil +} + +func (s *AuthService) bindEmailIdentityWithDefaults( + ctx context.Context, + client *dbent.Client, + currentUser *User, + email string, + hashedPassword string, +) error { + if client == nil || currentUser == nil || currentUser.ID <= 0 { + return ErrServiceUnavailable + } + + oldEmail := currentUser.Email + if _, err := client.User.UpdateOneID(currentUser.ID). + SetEmail(email). + SetPasswordHash(hashedPassword). + Save(ctx); err != nil { + if dbent.IsConstraintError(err) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil { + if errors.Is(err, ErrEmailExists) { + return ErrEmailExists + } + return ErrServiceUnavailable + } + + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil { + return fmt.Errorf("apply email first bind defaults: %w", err) + } + + updatedUser, err := client.User.Get(ctx, currentUser.ID) + if err != nil { + return ErrServiceUnavailable + } + currentUser.Email = updatedUser.Email + currentUser.PasswordHash = updatedUser.PasswordHash + currentUser.Balance = updatedUser.Balance + currentUser.Concurrency = updatedUser.Concurrency + currentUser.UpdatedAt = updatedUser.UpdatedAt + return nil +} + +func replaceBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + oldEmail string, + newEmail string, + source string, +) error { + newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail) + if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil { + return err + } + + oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail) + if oldSubject == "" || oldSubject == newSubject { + return nil + } + + _, err := client.AuthIdentity.Delete(). + Where( + authidentity.UserIDEQ(userID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(oldSubject), + ). + Exec(ctx) + return err +} + +func ensureBoundEmailAuthIdentityWithClient( + ctx context.Context, + client *dbent.Client, + userID int64, + subject string, + source string, +) error { + if client == nil || userID <= 0 || subject == "" { + return nil + } + + if strings.TrimSpace(source) == "" { + source = "auth_service_email_bind" + } + + if err := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(subject). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": strings.TrimSpace(source)}). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + return err + } + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ(subject), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil + } + return err + } + if identity.UserID != userID { + return ErrEmailExists + } + return nil +} + +func normalizeBoundEmailAuthIdentitySubject(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || isReservedEmail(normalized) { + return "" + } + return normalized +} diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index 899a736d..fd5f499b 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -5,6 +5,7 @@ package service_test import ( "context" "database/sql" + "errors" "testing" "time" @@ -34,6 +35,20 @@ func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription( return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +type flakyEmailBindDefaultSubAssignerStub struct { + err error + calls []*service.AssignSubscriptionInput +} + +func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return nil, false, s.err +} + func newAuthServiceForEmailBind( t *testing.T, settings map[string]string, @@ -187,6 +202,62 @@ func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testi require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind")) } +func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) { + assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain + user, err := client.User.Create(). + SetEmail(originalEmail). + SetUsername("legacy-rollback"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password") + require.ErrorContains(t, err, "apply email first bind defaults") + require.ErrorContains(t, err, "temporary assign failure") + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, originalEmail, storedUser.Email) + require.Equal(t, "old-hash", storedUser.PasswordHash) + require.Equal(t, 2.5, storedUser.Balance) + require.Equal(t, 1, storedUser.Concurrency) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("rollback@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) { cache := &emailBindCacheStub{ data: &service.VerificationCodeData{