From ca1f30a9113f363ce864bfb102ddd7286210fa2d Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Wed, 22 Apr 2026 11:17:38 +0800 Subject: [PATCH] fix(auth): harden pending oauth session consumption --- .../handler/auth_oauth_pending_flow.go | 73 ++++++++++++------- .../handler/auth_oauth_pending_flow_test.go | 10 +-- .../service/auth_pending_identity_service.go | 35 +++++++-- .../auth_pending_identity_service_test.go | 66 +++++++++++++++++ 4 files changed, 146 insertions(+), 38 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 7d7b50f4..c7cd6103 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -277,6 +277,22 @@ func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool { return false } +func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool { + if session == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { + return false + } + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + return false + } + if pendingSessionWantsInvitation(payload) { + return false + } + return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == "" +} + func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error { if session == nil { return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") @@ -1212,13 +1228,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt( if session == nil || len(payload) == 0 { return false, nil } - if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { - return false, nil - } - if !pendingOAuthCompletionIncludesTokenPayload(payload) { - return false, nil - } - if session.TargetUserID == nil || *session.TargetUserID <= 0 { + if !pendingOAuthCompletionCanIssueTokenPair(session, payload) { return false, nil } if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" && @@ -1649,6 +1659,22 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { } } applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload) + var loginUser *service.User + if canIssueTokenPair { + loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + } skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload) if err != nil { clearCookies() @@ -1658,25 +1684,6 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { if skipAdoptionPrompt { delete(payload, "adoption_required") } - if pendingOAuthCompletionIncludesTokenPayload(payload) { - if session.TargetUserID == nil || *session.TargetUserID <= 0 { - clearCookies() - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) - return - } - user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID) - if err != nil { - clearCookies() - response.ErrorFrom(c, err) - return - } - if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { - clearCookies() - response.ErrorFrom(c, err) - return - } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) - } if pendingSessionWantsInvitation(payload) { if adoptionDecision.hasDecision() { @@ -1724,6 +1731,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { return } + if canIssueTokenPair { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "") + if err != nil { + clearCookies() + response.InternalError(c, "Failed to generate token pair") + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID) + payload["access_token"] = tokenPair.AccessToken + payload["refresh_token"] = tokenPair.RefreshToken + payload["expires_in"] = tokenPair.ExpiresIn + payload["token_type"] = "Bearer" + } + clearCookies() response.Success(c, payload) } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 8940e37d..6f457206 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -746,11 +746,7 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo }). SetLocalFlowState(map[string]any{ oauthCompletionResponseKey: map[string]any{ - "access_token": "access-token", - "refresh_token": "refresh-token", - "expires_in": float64(3600), - "token_type": "Bearer", - "redirect": "/dashboard", + "redirect": "/dashboard", }, }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). @@ -769,8 +765,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo require.Equal(t, http.StatusOK, recorder.Code) payload := decodeJSONResponseData(t, recorder) - require.Equal(t, "access-token", payload["access_token"]) - require.Equal(t, "refresh-token", payload["refresh_token"]) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) require.Equal(t, "/dashboard", payload["redirect"]) require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index 7001ee18..26732c77 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -237,15 +237,40 @@ func (s *AuthPendingIdentityService) consumeSession( } now := time.Now().UTC() - updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + Where( + pendingauthsession.ConsumedAtIsNil(), + pendingauthsession.ExpiresAtGTE(now), + pendingauthsession.Or( + pendingauthsession.CompletionCodeExpiresAtIsNil(), + pendingauthsession.CompletionCodeExpiresAtGTE(now), + ), + ). SetConsumedAt(now). SetCompletionCodeHash(""). - ClearCompletionCodeExpiresAt(). - Save(ctx) - if err != nil { + ClearCompletionCodeExpiresAt() + if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { + update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey)) + } + updated, err := update.Save(ctx) + if err == nil { + return updated, nil + } + if !dbent.IsNotFound(err) { return nil, err } - return updated, nil + + current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID) + if currentErr != nil { + if dbent.IsNotFound(currentErr) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, currentErr + } + if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + return nil, consumedErr } func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go index de0b18d2..deeeeb06 100644 --- a/backend/internal/service/auth_pending_identity_service_test.go +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -356,3 +356,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) } + +func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "stale-replay-subject", + }, + BrowserSessionKey: "browser-session", + }) + require.NoError(t, err) + + loaded, err := svc.getBrowserSession(ctx, session.SessionToken) + require.NoError(t, err) + + consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "legacy-token-subject", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }, + }) + require.NoError(t, err) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + stored, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + + completion, ok := stored.LocalFlowState["completion_response"].(map[string]any) + require.True(t, ok) + require.NotContains(t, completion, "access_token") + require.NotContains(t, completion, "refresh_token") + require.NotContains(t, completion, "expires_in") + require.NotContains(t, completion, "token_type") + require.Equal(t, "/dashboard", completion["redirect"]) +}