diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index c7cd6103..658a5f52 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1290,6 +1290,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any { normalized := clonePendingMap(payload) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(normalized, key) + } step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) switch step { case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index a212eb91..c0413d4d 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -851,6 +851,22 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl require.Nil(t, storedSession.ConsumedAt) } +func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) { + payload := normalizePendingOAuthCompletionResponse(map[string]any{ + "access_token": "legacy-access-token", + "refresh_token": "legacy-refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }) + + require.NotContains(t, payload, "access_token") + require.NotContains(t, payload, "refresh_token") + require.NotContains(t, payload, "expires_in") + require.NotContains(t, payload, "token_type") + require.Equal(t, "/dashboard", payload["redirect"]) +} + func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, true) ctx := context.Background() diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index 26732c77..cc0522ab 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -236,6 +236,7 @@ func (s *AuthPendingIdentityService) consumeSession( return nil, err } + sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState) now := time.Now().UTC() update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). Where( @@ -247,6 +248,7 @@ func (s *AuthPendingIdentityService) consumeSession( ), ). SetConsumedAt(now). + SetLocalFlowState(sanitizedLocalFlowState). SetCompletionCodeHash(""). ClearCompletionCodeExpiresAt() if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { @@ -273,6 +275,29 @@ func (s *AuthPendingIdentityService) consumeSession( return nil, consumedErr } +func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any { + sanitized := copyPendingMap(localFlowState) + if len(sanitized) == 0 { + return sanitized + } + + rawCompletion, ok := sanitized["completion_response"] + if !ok { + return sanitized + } + completion, ok := rawCompletion.(map[string]any) + if !ok { + return sanitized + } + + cleanedCompletion := copyPendingMap(completion) + for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { + delete(cleanedCompletion, key) + } + sanitized["completion_response"] = cleanedCompletion + return sanitized +} + func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { if session == nil { return ErrPendingAuthSessionNotFound diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql new file mode 100644 index 00000000..e6341142 --- /dev/null +++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql @@ -0,0 +1,15 @@ +UPDATE pending_auth_sessions +SET + local_flow_state = jsonb_set( + local_flow_state, + '{completion_response}', + ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'), + true + ) +WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object' + AND ( + (local_flow_state -> 'completion_response') ? 'access_token' + OR (local_flow_state -> 'completion_response') ? 'refresh_token' + OR (local_flow_state -> 'completion_response') ? 'expires_in' + OR (local_flow_state -> 'completion_response') ? 'token_type' + ); diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql new file mode 100644 index 00000000..f6053ef0 --- /dev/null +++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql @@ -0,0 +1,39 @@ +WITH migration_110 AS ( + SELECT applied_at + FROM schema_migrations + WHERE filename = '110_pending_auth_and_provider_default_grants.sql' +), +legacy_provider_defaults AS ( + SELECT provider_type + FROM ( + VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat') + ) AS providers(provider_type) + CROSS JOIN migration_110 + JOIN settings balance + ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance' + JOIN settings concurrency + ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency' + JOIN settings subscriptions + ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions' + JOIN settings grant_on_signup + ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup' + JOIN settings grant_on_first_bind + ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind' + WHERE balance.value = '0' + AND concurrency.value = '5' + AND subscriptions.value = '[]' + AND grant_on_signup.value = 'true' + AND grant_on_first_bind.value = 'false' + AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' + AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute' +) +UPDATE settings +SET + value = 'false', + updated_at = NOW() +FROM legacy_provider_defaults +WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup' + AND settings.value = 'true'; diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go index 988876a9..48cc427b 100644 --- a/backend/migrations/auth_identity_payment_migrations_regression_test.go +++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go @@ -59,3 +59,28 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) { require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") } + +func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { + content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "UPDATE pending_auth_sessions") + require.Contains(t, sql, "completion_response") + require.Contains(t, sql, "access_token") + require.Contains(t, sql, "refresh_token") + require.Contains(t, sql, "expires_in") + require.Contains(t, sql, "token_type") +} + +func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) { + content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql") + require.NoError(t, err) + + sql := string(content) + require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql") + require.Contains(t, sql, "schema_migrations") + require.Contains(t, sql, "updated_at") + require.Contains(t, sql, "'_grant_on_signup'") + require.Contains(t, sql, "value = 'false'") +}