fix: harden auth identity legacy migrations

This commit is contained in:
IanShaw027
2026-04-21 01:30:37 +08:00
parent a70f7aca07
commit 0a461d8248
5 changed files with 329 additions and 4 deletions

View File

@@ -205,6 +205,199 @@ FROM auth_identity_migration_reports
require.Equal(t, beforeCount, afterCount)
}
func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migration115SQL, err := os.ReadFile(migration115Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY;
`)
require.NoError(t, err)
var linuxDoMalformedUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoMalformedUserID))
var linuxDoArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoArrayUserID))
var wechatUnionArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatUnionArrayUserID))
var wechatOpenIDArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatOpenIDArrayUserID))
var linuxDoMalformedLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
RETURNING id
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
var linuxDoArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
RETURNING id
`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
var wechatUnionArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
RETURNING id
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
var wechatOpenIDArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
RETURNING id
`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
_, err = tx.ExecContext(ctx, string(migration115SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var linuxDoMalformedMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-malformed'
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
require.Equal(t, "object", linuxDoMalformedMetadataType)
var linuxDoArrayMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-array'
`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
require.Equal(t, "object", linuxDoArrayMetadataType)
var wechatUnionArrayMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-array'
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
require.Equal(t, "object", wechatUnionArrayMetadataType)
var invalidJSONReportDetailsType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(details)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
require.Equal(t, "object", invalidJSONReportDetailsType)
var openIDOnlyReportDetailsType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(details)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
require.Equal(t, "object", openIDOnlyReportDetailsType)
var preservedArrayMetadataCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE id IN (
SELECT id
FROM auth_identities
WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
OR (user_id = $2 AND provider_subject = 'union-array')
)
AND metadata ? '_legacy_metadata_raw_json'
`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
require.Equal(t, 2, preservedArrayMetadataCount)
require.NotZero(t, linuxDoArrayLegacyID)
require.NotZero(t, wechatUnionArrayLegacyID)
}
func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
tx := testTx(t)
ctx := context.Background()

View File

@@ -26,6 +26,10 @@ var (
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
"AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
"auth identity channel provider must match canonical identity",
)
)
type ProviderGrantReason string
@@ -133,6 +137,10 @@ func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
@@ -240,6 +248,10 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
@@ -531,6 +543,23 @@ func copyMetadata(in map[string]any) map[string]any {
return out
}
func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
if channel == nil {
return nil
}
canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
channelProviderType := strings.TrimSpace(channel.ProviderType)
channelProviderKey := strings.TrimSpace(channel.ProviderKey)
if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
return ErrAuthIdentityChannelProviderMismatch
}
return nil
}
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {

View File

@@ -187,6 +187,48 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-create")
_, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-create-mismatch",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "app-mismatch",
ChannelSubject: "openid-create-mismatch",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-bind")
_, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-bind-mismatch",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-legacy",
Channel: "oa",
ChannelAppID: "wx-app-bind-mismatch",
ChannelSubject: "openid-bind-mismatch",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
user := s.mustCreateUser("tx-rollback")
expectedErr := errors.New("rollback")

View File

@@ -1,3 +1,29 @@
CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT)
RETURNS JSONB
LANGUAGE plpgsql
AS $$
DECLARE
parsed JSONB;
BEGIN
IF input_text IS NULL OR BTRIM(input_text) = '' THEN
RETURN '{}'::jsonb;
END IF;
BEGIN
parsed := input_text::jsonb;
EXCEPTION
WHEN OTHERS THEN
RETURN '{}'::jsonb;
END;
IF jsonb_typeof(parsed) = 'object' THEN
RETURN parsed;
END IF;
RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
END;
$$;
DO $$
BEGIN
IF to_regclass('public.user_external_identities') IS NULL THEN
@@ -33,7 +59,7 @@ FROM (
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
@@ -78,7 +104,7 @@ FROM (
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
@@ -123,7 +149,7 @@ FROM (
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
SELECT COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
@@ -157,7 +183,7 @@ FROM (
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
@@ -185,3 +211,5 @@ WHERE ai.provider_type = 'wechat'
AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
ON CONFLICT (report_type, report_key) DO NOTHING;
DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT);

View File

@@ -332,5 +332,38 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
END $$;
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'auth_identities_metadata_is_object_check'
) THEN
ALTER TABLE auth_identities
ADD CONSTRAINT auth_identities_metadata_is_object_check
CHECK (jsonb_typeof(metadata) = 'object');
END IF;
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'auth_identity_channels_metadata_is_object_check'
) THEN
ALTER TABLE auth_identity_channels
ADD CONSTRAINT auth_identity_channels_metadata_is_object_check
CHECK (jsonb_typeof(metadata) = 'object');
END IF;
IF NOT EXISTS (
SELECT 1
FROM pg_constraint
WHERE conname = 'auth_identity_migration_reports_details_is_object_check'
) THEN
ALTER TABLE auth_identity_migration_reports
ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check
CHECK (jsonb_typeof(details) = 'object');
END IF;
END $$;
DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);