fix(auth): scrub legacy pending oauth tokens on upgrade
This commit is contained in:
@@ -1290,6 +1290,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
|
|||||||
|
|
||||||
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
|
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
|
||||||
normalized := clonePendingMap(payload)
|
normalized := clonePendingMap(payload)
|
||||||
|
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
|
||||||
|
delete(normalized, key)
|
||||||
|
}
|
||||||
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
|
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
|
||||||
switch step {
|
switch step {
|
||||||
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
|
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
|
||||||
|
|||||||
@@ -851,6 +851,22 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
|
|||||||
require.Nil(t, storedSession.ConsumedAt)
|
require.Nil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
|
||||||
|
payload := normalizePendingOAuthCompletionResponse(map[string]any{
|
||||||
|
"access_token": "legacy-access-token",
|
||||||
|
"refresh_token": "legacy-refresh-token",
|
||||||
|
"expires_in": float64(3600),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"redirect": "/dashboard",
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NotContains(t, payload, "access_token")
|
||||||
|
require.NotContains(t, payload, "refresh_token")
|
||||||
|
require.NotContains(t, payload, "expires_in")
|
||||||
|
require.NotContains(t, payload, "token_type")
|
||||||
|
require.Equal(t, "/dashboard", payload["redirect"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
|
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
|
||||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -236,6 +236,7 @@ func (s *AuthPendingIdentityService) consumeSession(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
||||||
Where(
|
Where(
|
||||||
@@ -247,6 +248,7 @@ func (s *AuthPendingIdentityService) consumeSession(
|
|||||||
),
|
),
|
||||||
).
|
).
|
||||||
SetConsumedAt(now).
|
SetConsumedAt(now).
|
||||||
|
SetLocalFlowState(sanitizedLocalFlowState).
|
||||||
SetCompletionCodeHash("").
|
SetCompletionCodeHash("").
|
||||||
ClearCompletionCodeExpiresAt()
|
ClearCompletionCodeExpiresAt()
|
||||||
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
|
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
|
||||||
@@ -273,6 +275,29 @@ func (s *AuthPendingIdentityService) consumeSession(
|
|||||||
return nil, consumedErr
|
return nil, consumedErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
|
||||||
|
sanitized := copyPendingMap(localFlowState)
|
||||||
|
if len(sanitized) == 0 {
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
rawCompletion, ok := sanitized["completion_response"]
|
||||||
|
if !ok {
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
completion, ok := rawCompletion.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanedCompletion := copyPendingMap(completion)
|
||||||
|
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
|
||||||
|
delete(cleanedCompletion, key)
|
||||||
|
}
|
||||||
|
sanitized["completion_response"] = cleanedCompletion
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
|
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return ErrPendingAuthSessionNotFound
|
return ErrPendingAuthSessionNotFound
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
UPDATE pending_auth_sessions
|
||||||
|
SET
|
||||||
|
local_flow_state = jsonb_set(
|
||||||
|
local_flow_state,
|
||||||
|
'{completion_response}',
|
||||||
|
((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
|
||||||
|
true
|
||||||
|
)
|
||||||
|
WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
|
||||||
|
AND (
|
||||||
|
(local_flow_state -> 'completion_response') ? 'access_token'
|
||||||
|
OR (local_flow_state -> 'completion_response') ? 'refresh_token'
|
||||||
|
OR (local_flow_state -> 'completion_response') ? 'expires_in'
|
||||||
|
OR (local_flow_state -> 'completion_response') ? 'token_type'
|
||||||
|
);
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
WITH migration_110 AS (
|
||||||
|
SELECT applied_at
|
||||||
|
FROM schema_migrations
|
||||||
|
WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
|
||||||
|
),
|
||||||
|
legacy_provider_defaults AS (
|
||||||
|
SELECT provider_type
|
||||||
|
FROM (
|
||||||
|
VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
|
||||||
|
) AS providers(provider_type)
|
||||||
|
CROSS JOIN migration_110
|
||||||
|
JOIN settings balance
|
||||||
|
ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
|
||||||
|
JOIN settings concurrency
|
||||||
|
ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
|
||||||
|
JOIN settings subscriptions
|
||||||
|
ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
|
||||||
|
JOIN settings grant_on_signup
|
||||||
|
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
|
||||||
|
JOIN settings grant_on_first_bind
|
||||||
|
ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
|
||||||
|
WHERE balance.value = '0'
|
||||||
|
AND concurrency.value = '5'
|
||||||
|
AND subscriptions.value = '[]'
|
||||||
|
AND grant_on_signup.value = 'true'
|
||||||
|
AND grant_on_first_bind.value = 'false'
|
||||||
|
AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
|
||||||
|
)
|
||||||
|
UPDATE settings
|
||||||
|
SET
|
||||||
|
value = 'false',
|
||||||
|
updated_at = NOW()
|
||||||
|
FROM legacy_provider_defaults
|
||||||
|
WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
|
||||||
|
AND settings.value = 'true';
|
||||||
@@ -59,3 +59,28 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
|
|||||||
require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
|
require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
|
||||||
require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
|
require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {
|
||||||
|
content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sql := string(content)
|
||||||
|
require.Contains(t, sql, "UPDATE pending_auth_sessions")
|
||||||
|
require.Contains(t, sql, "completion_response")
|
||||||
|
require.Contains(t, sql, "access_token")
|
||||||
|
require.Contains(t, sql, "refresh_token")
|
||||||
|
require.Contains(t, sql, "expires_in")
|
||||||
|
require.Contains(t, sql, "token_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) {
|
||||||
|
content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sql := string(content)
|
||||||
|
require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql")
|
||||||
|
require.Contains(t, sql, "schema_migrations")
|
||||||
|
require.Contains(t, sql, "updated_at")
|
||||||
|
require.Contains(t, sql, "'_grant_on_signup'")
|
||||||
|
require.Contains(t, sql, "value = 'false'")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user