fix(migrations): keep auth identity and payment upgrades safe

This commit is contained in:
IanShaw027
2026-04-22 12:29:52 +08:00
parent be9df2bea7
commit 1ffebbb568
5 changed files with 278 additions and 80 deletions

View File

@@ -4,6 +4,7 @@ package repository
import ( import (
"context" "context"
"database/sql"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
migrationSQL, err := os.ReadFile(migrationPath) migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err) require.NoError(t, err)
_, err = tx.ExecContext(ctx, ` prepareLegacyExternalIdentitiesTable(t, tx, ctx)
CREATE TABLE IF NOT EXISTS user_external_identities ( truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
var linuxDoUserID int64 var linuxDoUserID int64
require.NoError(t, tx.QueryRowContext(ctx, ` require.NoError(t, tx.QueryRowContext(ctx, `
@@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
migration116SQL, err := os.ReadFile(migration116Path) migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err) require.NoError(t, err)
_, err = tx.ExecContext(ctx, ` prepareLegacyExternalIdentitiesTable(t, tx, ctx)
CREATE TABLE IF NOT EXISTS user_external_identities ( truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
var linuxDoMalformedUserID int64 var linuxDoMalformedUserID int64
require.NoError(t, tx.QueryRowContext(ctx, ` require.NoError(t, tx.QueryRowContext(ctx, `
@@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
migrationSQL, err := os.ReadFile(migrationPath) migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err) require.NoError(t, err)
_, err = tx.ExecContext(ctx, ` prepareLegacyExternalIdentitiesTable(t, tx, ctx)
CREATE TABLE IF NOT EXISTS user_external_identities ( truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
userIDs := make([]int64, 0, 8) userIDs := make([]int64, 0, 8)
for _, email := range []string{ for _, email := range []string{
@@ -646,3 +575,133 @@ FROM auth_identity_migration_reports
`).Scan(&afterCount)) `).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount) require.Equal(t, beforeCount, afterCount)
} }
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
migration108aSQL, err := os.ReadFile(migration108aPath)
require.NoError(t, err)
migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
migration109SQL, err := os.ReadFile(migration109Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
_, err = tx.ExecContext(ctx, `
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(40);
`)
require.NoError(t, err)
var oidcSyntheticUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&oidcSyntheticUserID))
var linuxdoLegacyUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxdoLegacyUserID))
var invalidMetadataLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
RETURNING id
`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
_, err = tx.ExecContext(ctx, string(migration108aSQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration109SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var reportTypeWidth int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`).Scan(&reportTypeWidth))
require.Equal(t, 80, reportTypeWidth)
var oidcSyntheticRecoveryReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
var invalidMetadataReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
require.Equal(t, 1, invalidMetadataReportCount)
}
func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
require.NoError(t, err)
}
func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
pending_auth_sessions,
auth_identities,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
}

View File

@@ -89,6 +89,22 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
} }
func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
tx := testTx(t)
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper() t.Helper()
@@ -106,6 +122,79 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table) require.True(t, exists, "expected index %s on %s", index, table)
} }
func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.False(t, exists, "expected index %s on %s to be absent", index, table)
}
func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
t.Helper()
var (
unique bool
def string
)
err := tx.QueryRowContext(context.Background(), `
SELECT
i.indisunique,
pg_get_indexdef(i.indexrelid)
FROM pg_class idx
JOIN pg_index i ON i.indexrelid = idx.oid
JOIN pg_class tbl ON tbl.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND idx.relname = $2
`, table, index).Scan(&unique, &def)
require.NoError(t, err, "query index definition for %s.%s", table, index)
require.True(t, unique, "expected index %s on %s to be unique", index, table)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
}
}
func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
t.Helper()
var actual string
err := tx.QueryRowContext(context.Background(), `
SELECT CASE c.confdeltype
WHEN 'a' THEN 'NO ACTION'
WHEN 'r' THEN 'RESTRICT'
WHEN 'c' THEN 'CASCADE'
WHEN 'n' THEN 'SET NULL'
WHEN 'd' THEN 'SET DEFAULT'
END
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
WHERE ns.nspname = 'public'
AND c.contype = 'f'
AND tbl.relname = $1
AND attr.attname = $2
AND ref_tbl.relname = $3
LIMIT 1
`, table, column, refTable).Scan(&actual)
require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper() t.Helper()

View File

@@ -0,0 +1,14 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
AND COALESCE(character_maximum_length, 0) < 80
) THEN
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(80);
END IF;
END $$;

View File

@@ -0,0 +1,22 @@
DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = 'payment_orders'
AND indexname = 'paymentorder_out_trade_no_unique'
) THEN
IF EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = 'payment_orders'
AND indexname = 'paymentorder_out_trade_no'
) THEN
EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no';
END IF;
EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no';
END IF;
END $$;

View File

@@ -26,7 +26,14 @@ func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T)
require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING")) require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING"))
} }
func TestMigration109KeepsPublishedBackfillBodyAndDefersReportTypeWidening(t *testing.T) { func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) {
preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql")
require.NoError(t, err)
preflightSQL := string(preflightContent)
require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports")
require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql") content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql")
require.NoError(t, err) require.NoError(t, err)
@@ -58,6 +65,13 @@ func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique") require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique")
require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no") require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
require.Contains(t, followupSQL, "WHERE out_trade_no <> ''") require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql")
require.NoError(t, err)
alignmentSQL := string(alignmentContent)
require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique")
require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no")
} }
func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) { func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {