From fb6204ea8b1bb3e6a6a55ce94896a287538d594a Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 19:53:22 +0800 Subject: [PATCH] feat: apply oauth first-bind defaults and pending bind 2fa --- backend/internal/handler/auth_handler.go | 57 ++ .../internal/handler/auth_linuxdo_oauth.go | 2 +- .../handler/auth_oauth_pending_flow.go | 54 +- .../handler/auth_oauth_pending_flow_test.go | 560 +++++++++++++++++- backend/internal/handler/auth_oidc_oauth.go | 2 +- backend/internal/handler/auth_wechat_oauth.go | 2 +- .../internal/service/auth_oauth_first_bind.go | 106 ++++ backend/internal/service/totp_service.go | 43 +- 8 files changed, 778 insertions(+), 48 deletions(-) create mode 100644 backend/internal/service/auth_oauth_first_bind.go diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index f4ddf890..e4697609 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -6,6 +6,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -269,6 +270,62 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } + if session.PendingOAuthBind != nil { + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + + pendingSession, err := pendingSvc.GetBrowserSession( + c.Request.Context(), + session.PendingOAuthBind.PendingSessionToken, + session.PendingOAuthBind.BrowserSessionKey, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{}) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding( + c.Request.Context(), + h.entClient(), + h.authService, + pendingSession, + decision, + &user.ID, + true, + true, + ); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession( + c.Request.Context(), + pendingSession.SessionToken, + pendingSession.BrowserSessionKey, + ); err != nil { + response.ErrorFrom(c, err) + return + } + + secureCookie := isRequestHTTPS(c) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + + user, err = h.userService.GetByID(c.Request.Context(), session.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + } + // Delete the login session (only after all checks pass) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index c4ecb8fa..c3ec3804 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -436,7 +436,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 99b9b406..6d6564e8 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -601,10 +601,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, + authService *service.AuthService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, forceBind bool, + applyFirstBindDefaults bool, ) error { if client == nil || session == nil { return nil @@ -638,16 +640,17 @@ func applyPendingOAuthBinding( return err } defer func() { _ = tx.Rollback() }() + txCtx := dbent.NewTxContext(ctx, tx) if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if err := tx.Client().User.UpdateOneID(targetUserID). SetUsername(adoptedDisplayName). - Exec(ctx); err != nil { + Exec(txCtx); err != nil { return err } } - identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID) if err != nil { return err } @@ -667,14 +670,20 @@ func applyPendingOAuthBinding( if issuer := oauthIdentityIssuer(session); issuer != nil { updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) } - if _, err := updateIdentity.Save(ctx); err != nil { + if _, err := updateIdentity.Save(txCtx); err != nil { return err } if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). SetIdentityID(identity.ID). - Save(ctx); err != nil { + Save(txCtx); err != nil { + return err + } + } + + if applyFirstBindDefaults && authService != nil { + if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil { return err } } @@ -685,11 +694,21 @@ func applyPendingOAuthBinding( func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, + authService *service.AuthService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, ) error { - return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false) + return applyPendingOAuthBinding( + ctx, + client, + authService, + session, + decision, + overrideUserID, + false, + strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"), + ) } func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { @@ -804,7 +823,26 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil { + if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { + tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession( + c.Request.Context(), + user.ID, + user.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + if err != nil { + response.InternalError(c, "Failed to create 2FA session") + return + } + response.Success(c, TotpLoginResponse{ + Requires2FA: true, + TempToken: tempToken, + UserEmailMasked: service.MaskEmail(user.Email), + }) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -900,7 +938,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } - if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil { + if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -990,7 +1028,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 80338b8a..ae506e52 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" "github.com/stretchr/testify/require" "entgo.io/ent/dialect" @@ -773,6 +774,316 @@ func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *test require.Nil(t, storedSession.ConsumedAt) } +func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) { + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3", + service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(5). + SetConcurrency(2). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + firstSession, err := client.PendingAuthSession.Create(). + SetSessionToken("first-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-first-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("first-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + firstRecorder := httptest.NewRecorder() + firstGinCtx, _ := gin.CreateTestContext(firstRecorder) + firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody) + firstReq.Header.Set("Content-Type", "application/json") + firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)}) + firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")}) + firstGinCtx.Request = firstReq + + handler.BindOIDCOAuthLogin(firstGinCtx) + + require.Equal(t, http.StatusOK, firstRecorder.Code) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID) + require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID) + require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + + secondSession, err := client.PendingAuthSession.Create(). + SetSessionToken("second-bind-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-second-456"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("second-bind-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Second OIDC User", + "suggested_avatar_url": "https://cdn.example/second.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + secondRecorder := httptest.NewRecorder() + secondGinCtx, _ := gin.CreateTestContext(secondRecorder) + secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody) + secondReq.Header.Set("Content-Type", "application/json") + secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)}) + secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")}) + secondGinCtx.Request = secondReq + + handler.BindOIDCOAuthLogin(secondGinCtx) + + require.Equal(t, http.StatusOK, secondRecorder.Code) + + storedUser, err = client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 17.5, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.Zero(t, storedUser.TotalRecharged) + require.Len(t, defaultSubAssigner.calls, 1) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) +} + +func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + }, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-2fa-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + data := decodeJSONResponseData(t, recorder) + require.Equal(t, true, data["requires_2fa"]) + require.Equal(t, "o***r@example.com", data["user_email_masked"]) + tempToken, ok := data["temp_token"].(string) + require.True(t, ok) + require.NotEmpty(t, tempToken) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.NotNil(t, loginSession) + require.NotNil(t, loginSession.PendingOAuthBind) + require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken) + require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + +func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) { + totpCache := &oauthPendingFlowTotpCacheStub{} + defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{} + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyTotpEnabled: "true", + service.SettingKeyAuthSourceDefaultOIDCBalance: "8", + service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2", + service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true", + }, + defaultSubAssigner: defaultSubAssigner, + totpCache: totpCache, + totpEncryptor: oauthPendingFlowTotpEncryptorStub{}, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + totpEnabledAt := time.Now().UTC().Add(-time.Hour) + secret := "JBSWY3DPEHPK3PXP" + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetBalance(1.5). + SetConcurrency(4). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetTotpEnabled(true). + SetTotpSecretEncrypted(secret). + SetTotpEnabledAt(totpEnabledAt). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("login-2fa-pending-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-login-2fa-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("login-2fa-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(session.ID). + SetAdoptDisplayName(false). + SetAdoptAvatar(false). + Save(ctx) + require.NoError(t, err) + + tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession( + ctx, + existingUser.ID, + existingUser.Email, + session.SessionToken, + session.BrowserSessionKey, + ) + require.NoError(t, err) + + code, err := totp.GenerateCode(secret, time.Now().UTC()) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)}) + ginCtx.Request = req + + handler.Login2FA(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + payload := decodeJSONResponseData(t, recorder) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-login-2fa-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) + + loginSession, err := totpCache.GetLoginSession(ctx, tempToken) + require.NoError(t, err) + require.Nil(t, loginSession) + + storedUser, err := client.User.Get(ctx, existingUser.ID) + require.NoError(t, err) + require.Equal(t, 9.5, storedUser.Balance) + require.Equal(t, 6, storedUser.Concurrency) + require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) + require.Empty(t, defaultSubAssigner.calls) +} + func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() @@ -805,6 +1116,27 @@ func newOAuthPendingFlowTestHandlerWithOptions( invitationEnabled bool, emailVerifyEnabled bool, emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + invitationEnabled: invitationEnabled, + emailVerifyEnabled: emailVerifyEnabled, + emailCache: emailCache, + }) +} + +type oauthPendingFlowTestHandlerOptions struct { + invitationEnabled bool + emailVerifyEnabled bool + emailCache service.EmailCache + settingValues map[string]string + defaultSubAssigner service.DefaultSubscriptionAssigner + totpCache service.TotpCache + totpEncryptor service.SecretEncryptor +} + +func newOAuthPendingFlowTestHandlerWithDependencies( + t *testing.T, + options oauthPendingFlowTestHandlerOptions, ) (*AuthHandler, *dbent.Client) { t.Helper() @@ -814,6 +1146,16 @@ func newOAuthPendingFlowTestHandlerWithOptions( _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) @@ -830,21 +1172,23 @@ func newOAuthPendingFlowTestHandlerWithOptions( UserConcurrency: 1, }, } - settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), - }, - }, cfg) + settingValues := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), + } + for key, value := range options.settingValues { + settingValues[key] = value + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) userRepo := &oauthPendingFlowUserRepo{client: client} var emailService *service.EmailService - if emailCache != nil { + if options.emailCache != nil { emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ values: map[string]string{ - service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled), }, - }, emailCache) + }, options.emailCache) } authSvc := service.NewAuthService( client, @@ -857,14 +1201,27 @@ func newOAuthPendingFlowTestHandlerWithOptions( nil, nil, nil, - nil, + options.defaultSubAssigner, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) + var totpSvc *service.TotpService + if options.totpCache != nil || options.totpEncryptor != nil { + totpCache := options.totpCache + if totpCache == nil { + totpCache = &oauthPendingFlowTotpCacheStub{} + } + totpEncryptor := options.totpEncryptor + if totpEncryptor == nil { + totpEncryptor = oauthPendingFlowTotpEncryptorStub{} + } + totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil) + } return &AuthHandler{ authService: authSvc, userService: userSvc, settingSvc: settingSvc, + totpService: totpSvc, }, client } @@ -1049,6 +1406,32 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin return payload } +func countProviderGrantRecords( + t *testing.T, + client *dbent.Client, + userID int64, + providerType string, + grantReason string, +) int { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`, + []any{userID, providerType, grantReason}, + &rows, + ) + require.NoError(t, err) + defer rows.Close() + + require.True(t, rows.Next()) + var count int + require.NoError(t, rows.Scan(&count)) + require.False(t, rows.Next()) + return count +} + type oauthPendingFlowUserRepo struct { client *dbent.Client } @@ -1063,6 +1446,10 @@ func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.Use SetBalance(user.Balance). SetConcurrency(user.Concurrency). SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). SetSignupSource(user.SignupSource). SetNillableLastLoginAt(user.LastLoginAt). SetNillableLastActiveAt(user.LastActiveAt). @@ -1112,6 +1499,10 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use SetBalance(user.Balance). SetConcurrency(user.Concurrency). SetStatus(user.Status). + SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted). + SetTotpEnabled(user.TotpEnabled). + SetNillableTotpEnabledAt(user.TotpEnabledAt). + SetTotalRecharged(user.TotalRecharged). SetSignupSource(user.SignupSource). SetNillableLastLoginAt(user.LastLoginAt). SetNillableLastActiveAt(user.LastActiveAt). @@ -1203,16 +1594,29 @@ func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, u return records, nil } -func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { - panic("unexpected UpdateTotpSecret call") +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + update := r.client.User.UpdateOneID(userID) + if encryptedSecret == nil { + update = update.ClearTotpSecretEncrypted() + } else { + update = update.SetTotpSecretEncrypted(*encryptedSecret) + } + return update.Exec(ctx) } -func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error { - panic("unexpected EnableTotp call") +func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(true). + SetTotpEnabledAt(time.Now().UTC()). + Exec(ctx) } -func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error { - panic("unexpected DisableTotp call") +func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error { + return r.client.User.UpdateOneID(userID). + SetTotpEnabled(false). + ClearTotpSecretEncrypted(). + ClearTotpEnabledAt(). + Exec(ctx) } func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { @@ -1220,19 +1624,113 @@ func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { return nil } return &service.User{ - ID: entity.ID, - Email: entity.Email, - Username: entity.Username, - Notes: entity.Notes, - PasswordHash: entity.PasswordHash, - Role: entity.Role, - Balance: entity.Balance, - Concurrency: entity.Concurrency, - Status: entity.Status, - SignupSource: entity.SignupSource, - LastLoginAt: entity.LastLoginAt, - LastActiveAt: entity.LastActiveAt, - CreatedAt: entity.CreatedAt, - UpdatedAt: entity.UpdatedAt, + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + TotpSecretEncrypted: entity.TotpSecretEncrypted, + TotpEnabled: entity.TotpEnabled, + TotpEnabledAt: entity.TotpEnabledAt, + TotalRecharged: entity.TotalRecharged, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, } } + +type oauthPendingFlowDefaultSubAssignerStub struct { + calls []service.AssignSubscriptionInput +} + +func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + return nil, false, nil +} + +type oauthPendingFlowTotpCacheStub struct { + setupSessions map[int64]*service.TotpSetupSession + loginSessions map[string]*service.TotpLoginSession + verifyAttempts map[int64]int +} + +func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) { + if s == nil || s.setupSessions == nil { + return nil, nil + } + return s.setupSessions[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error { + if s.setupSessions == nil { + s.setupSessions = map[int64]*service.TotpSetupSession{} + } + s.setupSessions[userID] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error { + delete(s.setupSessions, userID) + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) { + if s == nil || s.loginSessions == nil { + return nil, nil + } + return s.loginSessions[tempToken], nil +} + +func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error { + if s.loginSessions == nil { + s.loginSessions = map[string]*service.TotpLoginSession{} + } + s.loginSessions[tempToken] = session + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error { + delete(s.loginSessions, tempToken) + return nil +} + +func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s.verifyAttempts == nil { + s.verifyAttempts = map[int64]int{} + } + s.verifyAttempts[userID]++ + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) { + if s == nil || s.verifyAttempts == nil { + return 0, nil + } + return s.verifyAttempts[userID], nil +} + +func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error { + delete(s.verifyAttempts, userID) + return nil +} + +type oauthPendingFlowTotpEncryptorStub struct{} + +func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) { + return plaintext, nil +} + +func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) { + return ciphertext, nil +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 909d6379..70424ec5 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 6d37c799..816f60fd 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -346,7 +346,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go new file mode 100644 index 00000000..422a2a88 --- /dev/null +++ b/backend/internal/service/auth_oauth_first_bind.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + + entsql "entgo.io/ent/dialect/sql" +) + +// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap +// settings the first time a user binds a third-party identity. The grant is +// idempotent per user/provider pair. +func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 { + return nil + } + + if dbent.TxFromContext(ctx) != nil { + return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin first bind defaults transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil { + return err + } + return tx.Commit() +} + +func (s *AuthService) applyProviderDefaultSettingsOnFirstBind( + ctx context.Context, + userID int64, + providerType string, +) error { + defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + if err != nil { + return fmt.Errorf("load auth source defaults: %w", err) + } + + providerDefaults, ok := authSourceSignupSettings(defaults, providerType) + if !ok || !providerDefaults.GrantOnFirstBind { + return nil + } + + client := s.entClient + if tx := dbent.TxFromContext(ctx); tx != nil { + client = tx.Client() + } + + var result entsql.Result + if err := client.Driver().Exec( + ctx, + `INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES (?, ?, ?) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + []any{userID, strings.TrimSpace(providerType), "first_bind"}, + &result, + ); err != nil { + return fmt.Errorf("record first bind provider grant: %w", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("read first bind provider grant result: %w", err) + } + if affected == 0 { + return nil + } + + if providerDefaults.Balance != 0 { + if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind balance default: %w", err) + } + } + if providerDefaults.Concurrency != 0 { + if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil { + return fmt.Errorf("apply first bind concurrency default: %w", err) + } + } + if s.defaultSubAssigner != nil { + for _, item := range providerDefaults.Subscriptions { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by first bind defaults", + }); err != nil { + return fmt.Errorf("apply first bind subscription default: %w", err) + } + } + } + + return nil +} diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go index 5192fe3d..052739ed 100644 --- a/backend/internal/service/totp_service.go +++ b/backend/internal/service/totp_service.go @@ -58,9 +58,15 @@ type TotpSetupSession struct { // TotpLoginSession represents a pending 2FA login session type TotpLoginSession struct { - UserID int64 - Email string - TokenExpiry time.Time + UserID int64 + Email string + TokenExpiry time.Time + PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"` +} + +type PendingOAuthBindLoginSession struct { + PendingSessionToken string `json:"pending_session_token,omitempty"` + BrowserSessionKey string `json:"browser_session_key,omitempty"` } // TotpStatus represents the TOTP status for a user @@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) // CreateLoginSession creates a temporary login session for 2FA func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) { + return s.createLoginSession(ctx, userID, email, nil) +} + +// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will +// finalize a pending OAuth bind after the TOTP code is verified. +func (s *TotpService) CreatePendingOAuthBindLoginSession( + ctx context.Context, + userID int64, + email string, + pendingSessionToken string, + browserSessionKey string, +) (string, error) { + return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{ + PendingSessionToken: pendingSessionToken, + BrowserSessionKey: browserSessionKey, + }) +} + +func (s *TotpService) createLoginSession( + ctx context.Context, + userID int64, + email string, + pendingOAuthBind *PendingOAuthBindLoginSession, +) (string, error) { // Generate a random temp token tempToken, err := generateRandomToken(32) if err != nil { @@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai } session := &TotpLoginSession{ - UserID: userID, - Email: email, - TokenExpiry: time.Now().Add(totpLoginTTL), + UserID: userID, + Email: email, + TokenExpiry: time.Now().Add(totpLoginTTL), + PendingOAuthBind: pendingOAuthBind, } if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {