diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go index 7f2f363f..29654417 100644 --- a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go +++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go @@ -205,6 +205,199 @@ FROM auth_identity_migration_reports require.Equal(t, beforeCount, afterCount) } +func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) { + tx := testTx(t) + ctx := context.Background() + + migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql") + migration115SQL, err := os.ReadFile(migration115Path) + require.NoError(t, err) + + migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql") + migration116SQL, err := os.ReadFile(migration116Path) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS user_external_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + provider TEXT NOT NULL, + provider_user_id TEXT NOT NULL, + provider_union_id TEXT NULL, + provider_username TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + profile_url TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + metadata TEXT NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +TRUNCATE TABLE + auth_identity_channels, + auth_identities, + auth_identity_migration_reports, + user_external_identities, + users +RESTART IDENTITY; +`) + require.NoError(t, err) + + var linuxDoMalformedUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoMalformedUserID)) + + var linuxDoArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&linuxDoArrayUserID)) + + var wechatUnionArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatUnionArrayUserID)) + + var wechatOpenIDArrayUserID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO users (email, password_hash, role, status, balance, concurrency) +VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1) +RETURNING id`).Scan(&wechatOpenIDArrayUserID)) + + var linuxDoMalformedLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid') +RETURNING id +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID)) + + var linuxDoArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]') +RETURNING id +`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID)) + + var wechatUnionArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]') +RETURNING id +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID)) + + var wechatOpenIDArrayLegacyID int64 + require.NoError(t, tx.QueryRowContext(ctx, ` +INSERT INTO user_external_identities ( + user_id, + provider, + provider_user_id, + provider_union_id, + provider_username, + display_name, + metadata +) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]') +RETURNING id +`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID)) + + _, err = tx.ExecContext(ctx, string(migration115SQL)) + require.NoError(t, err) + + _, err = tx.ExecContext(ctx, string(migration116SQL)) + require.NoError(t, err) + + var linuxDoMalformedMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-malformed' +`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType)) + require.Equal(t, "object", linuxDoMalformedMetadataType) + + var linuxDoArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'linuxdo' + AND provider_key = 'linuxdo' + AND provider_subject = 'linuxdo-array' +`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType)) + require.Equal(t, "object", linuxDoArrayMetadataType) + + var wechatUnionArrayMetadataType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(metadata) +FROM auth_identities +WHERE user_id = $1 + AND provider_type = 'wechat' + AND provider_key = 'wechat-main' + AND provider_subject = 'union-array' +`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType)) + require.Equal(t, "object", wechatUnionArrayMetadataType) + + var invalidJSONReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 'legacy_external_identity_invalid_metadata_json' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType)) + require.Equal(t, "object", invalidJSONReportDetailsType) + + var openIDOnlyReportDetailsType string + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT jsonb_typeof(details) +FROM auth_identity_migration_reports +WHERE report_type = 'wechat_openid_only_requires_remediation' + AND report_key = $1 +`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType)) + require.Equal(t, "object", openIDOnlyReportDetailsType) + + var preservedArrayMetadataCount int + require.NoError(t, tx.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identities +WHERE id IN ( + SELECT id + FROM auth_identities + WHERE (user_id = $1 AND provider_subject = 'linuxdo-array') + OR (user_id = $2 AND provider_subject = 'union-array') +) + AND metadata ? '_legacy_metadata_raw_json' +`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount)) + require.Equal(t, 2, preservedArrayMetadataCount) + + require.NotZero(t, linuxDoArrayLegacyID) + require.NotZero(t, wechatUnionArrayLegacyID) +} + func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) { tx := testTx(t) ctx := context.Background() diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index c1b4b6bf..dbba364d 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -26,6 +26,10 @@ var ( "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user", ) + ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest( + "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH", + "auth identity channel provider must match canonical identity", + ) ) type ProviderGrantReason string @@ -133,6 +137,10 @@ func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func( } func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + client := clientFromContext(ctx, r.client) create := client.AuthIdentity.Create(). @@ -240,6 +248,10 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6 } func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil { + return nil, err + } + var result *CreateAuthIdentityResult err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { client := clientFromContext(txCtx, r.client) @@ -531,6 +543,23 @@ func copyMetadata(in map[string]any) map[string]any { return out } +func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error { + if channel == nil { + return nil + } + + canonicalProviderType := strings.TrimSpace(canonical.ProviderType) + canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey) + channelProviderType := strings.TrimSpace(channel.ProviderType) + channelProviderKey := strings.TrimSpace(channel.ProviderKey) + + if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey { + return ErrAuthIdentityChannelProviderMismatch + } + + return nil +} + func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { if tx := dbent.TxFromContext(ctx); tx != nil { if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go index 19022ec1..a02af62b 100644 --- a/backend/internal/repository/user_profile_identity_repo_contract_test.go +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -187,6 +187,48 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) } +func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-create") + + _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-create-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "app-mismatch", + ChannelSubject: "openid-create-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() { + user := s.mustCreateUser("provider-mismatch-bind") + + _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-bind-mismatch", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-legacy", + Channel: "oa", + ChannelAppID: "wx-app-bind-mismatch", + ChannelSubject: "openid-bind-mismatch", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch) +} + func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { user := s.mustCreateUser("tx-rollback") expectedErr := errors.New("rollback") diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql index f4a13c36..7a20f8eb 100644 --- a/backend/migrations/115_auth_identity_legacy_external_backfill.sql +++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql @@ -1,3 +1,29 @@ +CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT) +RETURNS JSONB +LANGUAGE plpgsql +AS $$ +DECLARE + parsed JSONB; +BEGIN + IF input_text IS NULL OR BTRIM(input_text) = '' THEN + RETURN '{}'::jsonb; + END IF; + + BEGIN + parsed := input_text::jsonb; + EXCEPTION + WHEN OTHERS THEN + RETURN '{}'::jsonb; + END; + + IF jsonb_typeof(parsed) = 'object' THEN + RETURN parsed; + END IF; + + RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed); +END; +$$; + DO $$ BEGIN IF to_regclass('public.user_external_identities') IS NULL THEN @@ -33,7 +59,7 @@ FROM ( BTRIM(uei.provider_user_id) AS provider_user_id, BTRIM(uei.provider_username) AS provider_username, BTRIM(uei.display_name) AS display_name, - COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, uei.created_at, uei.updated_at FROM user_external_identities AS uei @@ -78,7 +104,7 @@ FROM ( BTRIM(uei.provider_union_id) AS provider_union_id, BTRIM(uei.provider_username) AS provider_username, BTRIM(uei.display_name) AS display_name, - COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json, + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, uei.created_at, uei.updated_at FROM user_external_identities AS uei @@ -123,7 +149,7 @@ FROM ( FROM user_external_identities AS uei JOIN users AS u ON u.id = uei.user_id CROSS JOIN LATERAL ( - SELECT COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json + SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json ) AS meta WHERE u.deleted_at IS NULL AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' @@ -157,7 +183,7 @@ FROM ( uei.id, uei.user_id, BTRIM(uei.provider_user_id) AS provider_user_id, - COALESCE(NULLIF(BTRIM(COALESCE(uei.metadata, '')), '')::jsonb, '{}'::jsonb) AS metadata_json + public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json FROM user_external_identities AS uei JOIN users AS u ON u.id = uei.user_id WHERE u.deleted_at IS NULL @@ -185,3 +211,5 @@ WHERE ai.provider_type = 'wechat' AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email' AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = '' ON CONFLICT (report_type, report_key) DO NOTHING; + +DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT); diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql index 994f3f37..3983bb1a 100644 --- a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql +++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql @@ -332,5 +332,38 @@ ON CONFLICT (report_type, report_key) DO NOTHING; $sql$; END $$; +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identities_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identities + ADD CONSTRAINT auth_identities_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_channels_metadata_is_object_check' + ) THEN + ALTER TABLE auth_identity_channels + ADD CONSTRAINT auth_identity_channels_metadata_is_object_check + CHECK (jsonb_typeof(metadata) = 'object'); + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'auth_identity_migration_reports_details_is_object_check' + ) THEN + ALTER TABLE auth_identity_migration_reports + ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check + CHECK (jsonb_typeof(details) = 'object'); + END IF; +END $$; + DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT); DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);