fix(auth): harden pending oauth session consumption
This commit is contained in:
@@ -277,6 +277,22 @@ func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
|
||||
if session == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
|
||||
return false
|
||||
}
|
||||
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
||||
return false
|
||||
}
|
||||
if pendingSessionWantsInvitation(payload) {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
|
||||
}
|
||||
|
||||
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
|
||||
if session == nil {
|
||||
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
|
||||
@@ -1212,13 +1228,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
|
||||
if session == nil || len(payload) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
|
||||
return false, nil
|
||||
}
|
||||
if !pendingOAuthCompletionIncludesTokenPayload(payload) {
|
||||
return false, nil
|
||||
}
|
||||
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
||||
if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
|
||||
return false, nil
|
||||
}
|
||||
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
|
||||
@@ -1649,6 +1659,22 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
|
||||
|
||||
canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
|
||||
var loginUser *service.User
|
||||
if canIssueTokenPair {
|
||||
loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
@@ -1658,25 +1684,6 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
if skipAdoptionPrompt {
|
||||
delete(payload, "adoption_required")
|
||||
}
|
||||
if pendingOAuthCompletionIncludesTokenPayload(payload) {
|
||||
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
|
||||
return
|
||||
}
|
||||
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
|
||||
clearCookies()
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
}
|
||||
|
||||
if pendingSessionWantsInvitation(payload) {
|
||||
if adoptionDecision.hasDecision() {
|
||||
@@ -1724,6 +1731,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if canIssueTokenPair {
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
|
||||
if err != nil {
|
||||
clearCookies()
|
||||
response.InternalError(c, "Failed to generate token pair")
|
||||
return
|
||||
}
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
|
||||
payload["access_token"] = tokenPair.AccessToken
|
||||
payload["refresh_token"] = tokenPair.RefreshToken
|
||||
payload["expires_in"] = tokenPair.ExpiresIn
|
||||
payload["token_type"] = "Bearer"
|
||||
}
|
||||
|
||||
clearCookies()
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
@@ -746,11 +746,7 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_in": float64(3600),
|
||||
"token_type": "Bearer",
|
||||
"redirect": "/dashboard",
|
||||
"redirect": "/dashboard",
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
@@ -769,8 +765,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
payload := decodeJSONResponseData(t, recorder)
|
||||
require.Equal(t, "access-token", payload["access_token"])
|
||||
require.Equal(t, "refresh-token", payload["refresh_token"])
|
||||
require.NotEmpty(t, payload["access_token"])
|
||||
require.NotEmpty(t, payload["refresh_token"])
|
||||
require.Equal(t, "/dashboard", payload["redirect"])
|
||||
require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
|
||||
require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
|
||||
|
||||
@@ -237,15 +237,40 @@ func (s *AuthPendingIdentityService) consumeSession(
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
||||
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
||||
Where(
|
||||
pendingauthsession.ConsumedAtIsNil(),
|
||||
pendingauthsession.ExpiresAtGTE(now),
|
||||
pendingauthsession.Or(
|
||||
pendingauthsession.CompletionCodeExpiresAtIsNil(),
|
||||
pendingauthsession.CompletionCodeExpiresAtGTE(now),
|
||||
),
|
||||
).
|
||||
SetConsumedAt(now).
|
||||
SetCompletionCodeHash("").
|
||||
ClearCompletionCodeExpiresAt().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
ClearCompletionCodeExpiresAt()
|
||||
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
|
||||
update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
|
||||
}
|
||||
updated, err := update.Save(ctx)
|
||||
if err == nil {
|
||||
return updated, nil
|
||||
}
|
||||
if !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
return updated, nil
|
||||
|
||||
current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
|
||||
if currentErr != nil {
|
||||
if dbent.IsNotFound(currentErr) {
|
||||
return nil, ErrPendingAuthSessionNotFound
|
||||
}
|
||||
return nil, currentErr
|
||||
}
|
||||
if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, consumedErr
|
||||
}
|
||||
|
||||
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
|
||||
|
||||
@@ -356,3 +356,69 @@ func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
|
||||
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
|
||||
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
|
||||
svc, _ := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "login",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "stale-replay-subject",
|
||||
},
|
||||
BrowserSessionKey: "browser-session",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
|
||||
_, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
|
||||
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "login",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "legacy-token-subject",
|
||||
},
|
||||
BrowserSessionKey: "browser-session",
|
||||
LocalFlowState: map[string]any{
|
||||
"completion_response": map[string]any{
|
||||
"access_token": "legacy-access-token",
|
||||
"refresh_token": "legacy-refresh-token",
|
||||
"expires_in": float64(3600),
|
||||
"token_type": "Bearer",
|
||||
"redirect": "/dashboard",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
|
||||
stored, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.NotContains(t, completion, "access_token")
|
||||
require.NotContains(t, completion, "refresh_token")
|
||||
require.NotContains(t, completion, "expires_in")
|
||||
require.NotContains(t, completion, "token_type")
|
||||
require.Equal(t, "/dashboard", completion["redirect"])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user