fix: harden oidc compat email and email bind tx

This commit is contained in:
IanShaw027
2026-04-21 11:00:08 +08:00
parent 7e89bca5e6
commit f398650166
4 changed files with 424 additions and 6 deletions

View File

@@ -19,6 +19,7 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -323,18 +324,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if emailVerified == nil {
emailVerified = idClaims.EmailVerified
}
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
return
}
}
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
@@ -357,6 +353,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
upstreamClaims["compat_email"] = compatEmail
}
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
if err != nil {
@@ -416,6 +415,40 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
if err != nil {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if compatEmailUser != nil {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: "adopt_existing_user_by_email",
Identity: identityRef,
TargetUserID: &compatEmailUser.ID,
ResolvedEmail: compatEmailUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
"step": "bind_login_required",
"email": compatEmailUser.Email,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
return
}
}
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
@@ -473,6 +506,30 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectToFrontendCallback(c, frontendCallback)
}
func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
email = strings.TrimSpace(strings.ToLower(email))
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
return nil, nil
}
userEntity, err := findUserByNormalizedEmail(ctx, client, email)
if err != nil {
if errors.Is(err, service.ErrUserNotFound) {
return nil, nil
}
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
}
return userEntity, nil
}
type completeOIDCOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`

View File

@@ -245,6 +245,127 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T
require.Nil(t, completion["error"])
}
func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-compat",
PreferredUsername: "oidc_compat",
DisplayName: "OIDC Compat Display",
AvatarURL: "https://cdn.example/oidc-compat.png",
Email: "legacy@example.com",
EmailVerified: true,
})
defer cleanup()
handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("legacy@example.com").
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "adopt_existing_user_by_email", session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, "bind_login_required", completion["step"])
require.Equal(t, existingUser.Email, completion["email"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
}
func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-unverified-compat",
PreferredUsername: "oidc_unverified",
DisplayName: "OIDC Unverified Compat Display",
AvatarURL: "https://cdn.example/oidc-unverified.png",
Email: "owner@example.com",
EmailVerified: false,
})
defer cleanup()
cfg.RequireEmailVerified = true
handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "adopt_existing_user_by_email", session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.Equal(t, "owner@example.com", session.UpstreamIdentityClaims["compat_email"])
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.Equal(t, "/settings/connections", completion["redirect"])
require.Equal(t, "bind_login_required", completion["step"])
require.Equal(t, existingUser.Email, completion["email"])
}
func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-invite",

View File

@@ -6,7 +6,10 @@ import (
"fmt"
"net/mail"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -55,6 +58,13 @@ func (s *AuthService) BindEmailIdentity(
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
if firstRealEmailBind && s.entClient != nil {
if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil {
return nil, err
}
return currentUser, nil
}
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
@@ -126,3 +136,162 @@ func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}
func (s *AuthService) bindEmailIdentityWithDefaultsTx(
ctx context.Context,
currentUser *User,
email string,
hashedPassword string,
) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword)
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
return ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return ErrServiceUnavailable
}
return nil
}
func (s *AuthService) bindEmailIdentityWithDefaults(
ctx context.Context,
client *dbent.Client,
currentUser *User,
email string,
hashedPassword string,
) error {
if client == nil || currentUser == nil || currentUser.ID <= 0 {
return ErrServiceUnavailable
}
oldEmail := currentUser.Email
if _, err := client.User.UpdateOneID(currentUser.ID).
SetEmail(email).
SetPasswordHash(hashedPassword).
Save(ctx); err != nil {
if dbent.IsConstraintError(err) {
return ErrEmailExists
}
return ErrServiceUnavailable
}
if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
if errors.Is(err, ErrEmailExists) {
return ErrEmailExists
}
return ErrServiceUnavailable
}
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
return fmt.Errorf("apply email first bind defaults: %w", err)
}
updatedUser, err := client.User.Get(ctx, currentUser.ID)
if err != nil {
return ErrServiceUnavailable
}
currentUser.Email = updatedUser.Email
currentUser.PasswordHash = updatedUser.PasswordHash
currentUser.Balance = updatedUser.Balance
currentUser.Concurrency = updatedUser.Concurrency
currentUser.UpdatedAt = updatedUser.UpdatedAt
return nil
}
func replaceBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,
userID int64,
oldEmail string,
newEmail string,
source string,
) error {
newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
return err
}
oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
if oldSubject == "" || oldSubject == newSubject {
return nil
}
_, err := client.AuthIdentity.Delete().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(oldSubject),
).
Exec(ctx)
return err
}
func ensureBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,
userID int64,
subject string,
source string,
) error {
if client == nil || userID <= 0 || subject == "" {
return nil
}
if strings.TrimSpace(source) == "" {
source = "auth_service_email_bind"
}
if err := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(subject).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
return err
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(subject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity.UserID != userID {
return ErrEmailExists
}
return nil
}
func normalizeBoundEmailAuthIdentitySubject(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" || isReservedEmail(normalized) {
return ""
}
return normalized
}

View File

@@ -5,6 +5,7 @@ package service_test
import (
"context"
"database/sql"
"errors"
"testing"
"time"
@@ -34,6 +35,20 @@ func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
type flakyEmailBindDefaultSubAssignerStub struct {
err error
calls []*service.AssignSubscriptionInput
}
func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
_ context.Context,
input *service.AssignSubscriptionInput,
) (*service.UserSubscription, bool, error) {
cloned := *input
s.calls = append(s.calls, &cloned)
return nil, false, s.err
}
func newAuthServiceForEmailBind(
t *testing.T,
settings map[string]string,
@@ -187,6 +202,62 @@ func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testi
require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
}, cache, assigner)
ctx := context.Background()
originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
user, err := client.User.Create().
SetEmail(originalEmail).
SetUsername("legacy-rollback").
SetPasswordHash("old-hash").
SetBalance(2.5).
SetConcurrency(1).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
require.ErrorContains(t, err, "apply email first bind defaults")
require.ErrorContains(t, err, "temporary assign failure")
require.Nil(t, updatedUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, originalEmail, storedUser.Email)
require.Equal(t, "old-hash", storedUser.PasswordHash)
require.Equal(t, 2.5, storedUser.Balance)
require.Equal(t, 1, storedUser.Concurrency)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("rollback@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, identityCount)
require.Len(t, assigner.calls, 1)
require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
}
func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{