From 767f2f2dfe4e08490ca28f073c0156b2b6b7d912 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 12:30:00 +0800 Subject: [PATCH] fix(auth): harden pending oauth and backend mode flows --- backend/internal/handler/auth_handler.go | 2 + .../internal/handler/auth_linuxdo_oauth.go | 9 ++ .../handler/auth_linuxdo_oauth_test.go | 55 ++++++++ .../handler/auth_oauth_logout_test.go | 68 +++++++++ .../handler/auth_oauth_pending_flow.go | 129 ++++++++++++++++++ .../handler/auth_oauth_pending_flow_test.go | 20 +++ backend/internal/handler/auth_oidc_oauth.go | 23 +++- .../internal/handler/auth_oidc_oauth_test.go | 56 ++++++++ backend/internal/handler/auth_wechat_oauth.go | 9 ++ .../handler/auth_wechat_oauth_test.go | 104 ++++++++++---- .../server/middleware/backend_mode_guard.go | 47 +++++-- .../middleware/backend_mode_guard_test.go | 90 ++++++++++++ 12 files changed, 568 insertions(+), 44 deletions(-) create mode 100644 backend/internal/handler/auth_oauth_logout_test.go diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 9801b3b3..acd43e9f 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -678,6 +678,8 @@ func (h *AuthHandler) Logout(c *gin.Context) { // 不影响登出流程 } } + h.consumePendingOAuthSessionOnLogout(c) + clearOAuthLogoutCookies(c) response.Success(c, LogoutResponse{ Message: "Logged out successfully", diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index ef9a5bca..157be066 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -469,6 +469,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 841dc442..a9a5e3e6 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -757,6 +757,61 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-choice-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-choice-subject-1"). + SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go new file mode 100644 index 00000000..0d4f94b1 --- /dev/null +++ b/backend/internal/handler/auth_oauth_logout_test.go @@ -0,0 +1,68 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("logout-pending-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("logout-subject-123"). + SetBrowserSessionKey("logout-browser-session-key"). + SetResolvedEmail("logout@example.com"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"}) + req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")}) + req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")}) + req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")}) + req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + cookies := recorder.Result().Cookies() + for _, name := range []string{ + oauthPendingSessionCookieName, + oauthPendingBrowserCookieName, + oauthBindAccessTokenCookieName, + linuxDoOAuthStateCookieName, + oidcOAuthStateCookieName, + wechatOAuthStateCookieName, + wechatPaymentOAuthStateName, + } { + cookie := findCookie(cookies, name) + require.NotNil(t, cookie, name) + require.Equal(t, -1, cookie.MaxAge, name) + require.True(t, cookie.HttpOnly, name) + } + + storedSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 658a5f52..c5df4db1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -310,6 +310,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes return nil } +func buildLegacyCompleteRegistrationPendingResponse( + session *dbent.PendingAuthSession, + forceEmailOnSignup bool, + emailVerificationRequired bool, +) map[string]any { + completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "create_account_allowed": true, + "force_email_on_signup": forceEmailOnSignup, + })) + + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + if _, exists := completionResponse["email"]; !exists { + completionResponse["email"] = email + } + if _, exists := completionResponse["resolved_email"]; !exists { + completionResponse["resolved_email"] = email + } + } + if _, exists := completionResponse["choice_reason"]; !exists { + switch { + case forceEmailOnSignup: + completionResponse["choice_reason"] = "force_email_on_signup" + case emailVerificationRequired: + completionResponse["choice_reason"] = "email_verification_required" + default: + completionResponse["choice_reason"] = "third_party_signup" + } + } + return completionResponse +} + +func (h *AuthHandler) legacyCompleteRegistrationSessionStatus( + c *gin.Context, + session *dbent.PendingAuthSession, +) (*dbent.PendingAuthSession, bool, error) { + if session == nil { + return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + + payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) + if step := pendingSessionStringValue(payload, "step"); step != "" { + return session, true, nil + } + + emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context()) + forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context()) + if !emailVerificationRequired && !forceEmailOnSignup { + return session, false, nil + } + + client := h.entClient() + if client == nil { + return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + updatedSession, err := updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + strings.TrimSpace(session.Intent), + strings.TrimSpace(session.ResolvedEmail), + nil, + buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired), + ) + if err != nil { + return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + return updatedSession, true, nil +} + func (r oauthAdoptionDecisionRequest) hasDecision() bool { return r.AdoptDisplayName != nil || r.AdoptAvatar != nil } @@ -1272,6 +1344,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au return svc, session, clearCookies, nil } +func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) { + if c == nil || c.Request == nil { + return + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + return + } + _, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) +} + +func clearOAuthLogoutCookies(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + clearOAuthBindAccessTokenCookie(c, secureCookie) + + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie) + clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie) + + oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) + oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) + oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie) + + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie) + + wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie) + wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie) +} + func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) payload := gin.H{ @@ -1451,6 +1576,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { response.BadRequest(c, "Pending oauth session provider mismatch") return diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index c0413d4d..b3b8dfe1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1228,6 +1228,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) require.Nil(t, storedSession.ConsumedAt) } +func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) { + handler, _ := newOAuthPendingFlowTestHandler(t, false) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")}) + req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"}) + ginCtx.Request = req + + handler.Logout(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge) + require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge) +} + func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") ctx := context.Background() diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 7fe4b8d9..6345938b 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -374,19 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ProviderSubject: subject, } upstreamClaims := map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string { if idClaims != nil { return idClaims.Name } return "" }(), username), - "suggested_avatar_url": userInfoClaims.AvatarURL, + "suggested_avatar_url": userInfoClaims.AvatarURL, } if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { upstreamClaims["compat_email"] = compatEmail @@ -622,6 +622,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a600fd56..63008344 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -692,6 +692,62 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-choice-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-choice-subject-1"). + SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 39703ce7..3ed20a7d 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -525,6 +525,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 99006701..349e7dd2 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -19,7 +19,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" - dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/repository" @@ -700,7 +699,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes require.Zero(t, count) } -func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -773,27 +772,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.Equal(t, http.StatusOK, completeRecorder.Code) responseData := decodeJSONBody(t, completeRecorder) - require.NotEmpty(t, responseData["access_token"]) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["adoption_required"]) + require.Empty(t, responseData["access_token"]) - userEntity, err := client.User.Query(). - Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). Only(ctx) require.NoError(t, err) - require.Equal(t, "WeChat Display", userEntity.Username) + require.Nil(t, consumed.ConsumedAt) - identity, err := client.AuthIdentity.Query(). + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + identityCount, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ("wechat-main"), authidentity.ProviderSubjectEQ("union-456"), ). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.Equal(t, userEntity.ID, identity.UserID) - require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) - require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + require.Zero(t, identityCount) - channel, err := client.AuthIdentityChannel.Query(). + channelCount, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderKeyEQ("wechat-main"), @@ -801,25 +805,15 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelSubjectEQ("openid-123"), ). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.Equal(t, identity.ID, channel.IdentityID) - require.Equal(t, "union-456", channel.Metadata["unionid"]) + require.Zero(t, channelCount) - decision, err := client.IdentityAdoptionDecision.Query(). + decisionCount, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). - Only(ctx) + Count(ctx) require.NoError(t, err) - require.NotNil(t, decision.IdentityID) - require.Equal(t, identity.ID, *decision.IdentityID) - require.True(t, decision.AdoptDisplayName) - require.True(t, decision.AdoptAvatar) - - consumed, err := client.PendingAuthSession.Query(). - Where(pendingauthsession.IDEQ(pendingSession.ID)). - Only(ctx) - require.NoError(t, err) - require.NotNil(t, consumed.ConsumedAt) + require.Zero(t, decisionCount) } func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { @@ -981,6 +975,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi require.Nil(t, storedSession.ConsumedAt) } +func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-choice-session"). + SetIntent("login"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("wechat-choice-subject-1"). + SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid"). + SetBrowserSessionKey("wechat-choice-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": oauthPendingChoiceStep, + "redirect": "/dashboard", + "email": "fresh@example.com", + "resolved_email": "fresh@example.com", + "force_email_on_signup": true, + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.Equal(t, "pending_session", responseData["auth_result"]) + require.Equal(t, oauthPendingChoiceStep, responseData["step"]) + require.Equal(t, true, responseData["force_email_on_signup"]) + require.Empty(t, responseData["access_token"]) + + userCount, err := client.User.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go index 46482af3..ae53037e 100644 --- a/backend/internal/server/middleware/backend_mode_guard.go +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun } } +func backendModeAllowsAuthPath(path string) bool { + path = strings.ToLower(strings.TrimSpace(path)) + for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} { + if strings.HasSuffix(path, suffix) { + return true + } + } + + for _, suffix := range []string{ + "/auth/oauth/linuxdo/callback", + "/auth/oauth/wechat/callback", + "/auth/oauth/wechat/payment/callback", + "/auth/oauth/oidc/callback", + "/auth/oauth/linuxdo/complete-registration", + "/auth/oauth/wechat/complete-registration", + "/auth/oauth/oidc/complete-registration", + "/auth/oauth/linuxdo/create-account", + "/auth/oauth/wechat/create-account", + "/auth/oauth/oidc/create-account", + "/auth/oauth/linuxdo/bind-login", + "/auth/oauth/wechat/bind-login", + "/auth/oauth/oidc/bind-login", + } { + if strings.HasSuffix(path, suffix) { + return true + } + } + + return strings.Contains(path, "/auth/oauth/pending/") +} + // BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. -// Allows: login, login/2fa, logout, refresh (admin needs these). -// Blocks: register, forgot-password, reset-password, OAuth, etc. +// Allows the minimal auth surface admins still need in backend mode, including +// OAuth callbacks and pending continuations. Handler-level backend mode checks +// still enforce admin-only login and forbid self-service registration. func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { return func(c *gin.Context) { if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { c.Next() return } - path := c.Request.URL.Path - // Allow login, 2FA, logout, refresh, public settings - allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} - for _, suffix := range allowedSuffixes { - if strings.HasSuffix(path, suffix) { - c.Next() - return - } + if backendModeAllowsAuthPath(c.Request.URL.Path) { + c.Next() + return } response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.") c.Abort() diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go index 8878ebc9..bd77677b 100644 --- a/backend/internal/server/middleware/backend_mode_guard_test.go +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) { path: "/api/v1/auth/refresh", wantStatus: http.StatusOK, }, + { + name: "enabled_blocks_linuxdo_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_linuxdo_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_wechat_payment_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_wechat_payment_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/payment/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_blocks_oidc_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_oidc_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_exchange", + enabled: "true", + path: "/api/v1/auth/oauth/pending/exchange", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_send_verify_code", + enabled: "true", + path: "/api/v1/auth/oauth/pending/send-verify-code", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/pending/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_oauth_pending_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/pending/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/oidc/bind-login", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_provider_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/wechat/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_legacy_complete_registration", + enabled: "true", + path: "/api/v1/auth/oauth/linuxdo/complete-registration", + wantStatus: http.StatusOK, + }, { name: "enabled_blocks_register", enabled: "true",