fix(auth): scrub legacy pending oauth tokens on upgrade

This commit is contained in:
IanShaw027
2026-04-22 11:29:05 +08:00
parent 9d5e9bbc18
commit be9df2bea7
6 changed files with 123 additions and 0 deletions

View File

@@ -1290,6 +1290,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any { func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
normalized := clonePendingMap(payload) normalized := clonePendingMap(payload)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
switch step { switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":

View File

@@ -851,6 +851,22 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
payload := normalizePendingOAuthCompletionResponse(map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
})
require.NotContains(t, payload, "access_token")
require.NotContains(t, payload, "refresh_token")
require.NotContains(t, payload, "expires_in")
require.NotContains(t, payload, "token_type")
require.Equal(t, "/dashboard", payload["redirect"])
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true) handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background() ctx := context.Background()

View File

@@ -236,6 +236,7 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, err return nil, err
} }
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
now := time.Now().UTC() now := time.Now().UTC()
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
Where( Where(
@@ -247,6 +248,7 @@ func (s *AuthPendingIdentityService) consumeSession(
), ),
). ).
SetConsumedAt(now). SetConsumedAt(now).
SetLocalFlowState(sanitizedLocalFlowState).
SetCompletionCodeHash(""). SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt() ClearCompletionCodeExpiresAt()
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
@@ -273,6 +275,29 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, consumedErr return nil, consumedErr
} }
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
sanitized := copyPendingMap(localFlowState)
if len(sanitized) == 0 {
return sanitized
}
rawCompletion, ok := sanitized["completion_response"]
if !ok {
return sanitized
}
completion, ok := rawCompletion.(map[string]any)
if !ok {
return sanitized
}
cleanedCompletion := copyPendingMap(completion)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(cleanedCompletion, key)
}
sanitized["completion_response"] = cleanedCompletion
return sanitized
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil { if session == nil {
return ErrPendingAuthSessionNotFound return ErrPendingAuthSessionNotFound

View File

@@ -0,0 +1,15 @@
UPDATE pending_auth_sessions
SET
local_flow_state = jsonb_set(
local_flow_state,
'{completion_response}',
((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
true
)
WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
AND (
(local_flow_state -> 'completion_response') ? 'access_token'
OR (local_flow_state -> 'completion_response') ? 'refresh_token'
OR (local_flow_state -> 'completion_response') ? 'expires_in'
OR (local_flow_state -> 'completion_response') ? 'token_type'
);

View File

@@ -0,0 +1,39 @@
WITH migration_110 AS (
SELECT applied_at
FROM schema_migrations
WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
),
legacy_provider_defaults AS (
SELECT provider_type
FROM (
VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
) AS providers(provider_type)
CROSS JOIN migration_110
JOIN settings balance
ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
JOIN settings concurrency
ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
JOIN settings subscriptions
ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
JOIN settings grant_on_signup
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
JOIN settings grant_on_first_bind
ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
WHERE balance.value = '0'
AND concurrency.value = '5'
AND subscriptions.value = '[]'
AND grant_on_signup.value = 'true'
AND grant_on_first_bind.value = 'false'
AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
)
UPDATE settings
SET
value = 'false',
updated_at = NOW()
FROM legacy_provider_defaults
WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
AND settings.value = 'true';

View File

@@ -59,3 +59,28 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
} }
func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {
content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql")
require.NoError(t, err)
sql := string(content)
require.Contains(t, sql, "UPDATE pending_auth_sessions")
require.Contains(t, sql, "completion_response")
require.Contains(t, sql, "access_token")
require.Contains(t, sql, "refresh_token")
require.Contains(t, sql, "expires_in")
require.Contains(t, sql, "token_type")
}
func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) {
content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql")
require.NoError(t, err)
sql := string(content)
require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql")
require.Contains(t, sql, "schema_migrations")
require.Contains(t, sql, "updated_at")
require.Contains(t, sql, "'_grant_on_signup'")
require.Contains(t, sql, "value = 'false'")
}