fix: harden auth identity legacy migrations
This commit is contained in:
@@ -205,6 +205,199 @@ FROM auth_identity_migration_reports
|
||||
require.Equal(t, beforeCount, afterCount)
|
||||
}
|
||||
|
||||
func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
ctx := context.Background()
|
||||
|
||||
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
|
||||
migration115SQL, err := os.ReadFile(migration115Path)
|
||||
require.NoError(t, err)
|
||||
|
||||
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
|
||||
migration116SQL, err := os.ReadFile(migration116Path)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_external_identities (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
provider_user_id TEXT NOT NULL,
|
||||
provider_union_id TEXT NULL,
|
||||
provider_username TEXT NOT NULL DEFAULT '',
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
profile_url TEXT NOT NULL DEFAULT '',
|
||||
avatar_url TEXT NOT NULL DEFAULT '',
|
||||
metadata TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
TRUNCATE TABLE
|
||||
auth_identity_channels,
|
||||
auth_identities,
|
||||
auth_identity_migration_reports,
|
||||
user_external_identities,
|
||||
users
|
||||
RESTART IDENTITY;
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
var linuxDoMalformedUserID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||
VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
|
||||
RETURNING id`).Scan(&linuxDoMalformedUserID))
|
||||
|
||||
var linuxDoArrayUserID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||
VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
|
||||
RETURNING id`).Scan(&linuxDoArrayUserID))
|
||||
|
||||
var wechatUnionArrayUserID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||
VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
|
||||
RETURNING id`).Scan(&wechatUnionArrayUserID))
|
||||
|
||||
var wechatOpenIDArrayUserID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
|
||||
VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
|
||||
RETURNING id`).Scan(&wechatOpenIDArrayUserID))
|
||||
|
||||
var linuxDoMalformedLegacyID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO user_external_identities (
|
||||
user_id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
provider_union_id,
|
||||
provider_username,
|
||||
display_name,
|
||||
metadata
|
||||
) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
|
||||
RETURNING id
|
||||
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
|
||||
|
||||
var linuxDoArrayLegacyID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO user_external_identities (
|
||||
user_id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
provider_union_id,
|
||||
provider_username,
|
||||
display_name,
|
||||
metadata
|
||||
) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
|
||||
RETURNING id
|
||||
`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
|
||||
|
||||
var wechatUnionArrayLegacyID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO user_external_identities (
|
||||
user_id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
provider_union_id,
|
||||
provider_username,
|
||||
display_name,
|
||||
metadata
|
||||
) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
|
||||
RETURNING id
|
||||
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
|
||||
|
||||
var wechatOpenIDArrayLegacyID int64
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
INSERT INTO user_external_identities (
|
||||
user_id,
|
||||
provider,
|
||||
provider_user_id,
|
||||
provider_union_id,
|
||||
provider_username,
|
||||
display_name,
|
||||
metadata
|
||||
) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
|
||||
RETURNING id
|
||||
`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
|
||||
|
||||
_, err = tx.ExecContext(ctx, string(migration115SQL))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tx.ExecContext(ctx, string(migration116SQL))
|
||||
require.NoError(t, err)
|
||||
|
||||
var linuxDoMalformedMetadataType string
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT jsonb_typeof(metadata)
|
||||
FROM auth_identities
|
||||
WHERE user_id = $1
|
||||
AND provider_type = 'linuxdo'
|
||||
AND provider_key = 'linuxdo'
|
||||
AND provider_subject = 'linuxdo-malformed'
|
||||
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
|
||||
require.Equal(t, "object", linuxDoMalformedMetadataType)
|
||||
|
||||
var linuxDoArrayMetadataType string
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT jsonb_typeof(metadata)
|
||||
FROM auth_identities
|
||||
WHERE user_id = $1
|
||||
AND provider_type = 'linuxdo'
|
||||
AND provider_key = 'linuxdo'
|
||||
AND provider_subject = 'linuxdo-array'
|
||||
`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
|
||||
require.Equal(t, "object", linuxDoArrayMetadataType)
|
||||
|
||||
var wechatUnionArrayMetadataType string
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT jsonb_typeof(metadata)
|
||||
FROM auth_identities
|
||||
WHERE user_id = $1
|
||||
AND provider_type = 'wechat'
|
||||
AND provider_key = 'wechat-main'
|
||||
AND provider_subject = 'union-array'
|
||||
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
|
||||
require.Equal(t, "object", wechatUnionArrayMetadataType)
|
||||
|
||||
var invalidJSONReportDetailsType string
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT jsonb_typeof(details)
|
||||
FROM auth_identity_migration_reports
|
||||
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
|
||||
AND report_key = $1
|
||||
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
|
||||
require.Equal(t, "object", invalidJSONReportDetailsType)
|
||||
|
||||
var openIDOnlyReportDetailsType string
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT jsonb_typeof(details)
|
||||
FROM auth_identity_migration_reports
|
||||
WHERE report_type = 'wechat_openid_only_requires_remediation'
|
||||
AND report_key = $1
|
||||
`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
|
||||
require.Equal(t, "object", openIDOnlyReportDetailsType)
|
||||
|
||||
var preservedArrayMetadataCount int
|
||||
require.NoError(t, tx.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM auth_identities
|
||||
WHERE id IN (
|
||||
SELECT id
|
||||
FROM auth_identities
|
||||
WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
|
||||
OR (user_id = $2 AND provider_subject = 'union-array')
|
||||
)
|
||||
AND metadata ? '_legacy_metadata_raw_json'
|
||||
`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
|
||||
require.Equal(t, 2, preservedArrayMetadataCount)
|
||||
|
||||
require.NotZero(t, linuxDoArrayLegacyID)
|
||||
require.NotZero(t, wechatUnionArrayLegacyID)
|
||||
}
|
||||
|
||||
func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -26,6 +26,10 @@ var (
|
||||
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
|
||||
"auth identity channel already belongs to another user",
|
||||
)
|
||||
ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
|
||||
"AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
|
||||
"auth identity channel provider must match canonical identity",
|
||||
)
|
||||
)
|
||||
|
||||
type ProviderGrantReason string
|
||||
@@ -133,6 +137,10 @@ func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(
|
||||
}
|
||||
|
||||
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
||||
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
create := client.AuthIdentity.Create().
|
||||
@@ -240,6 +248,10 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6
|
||||
}
|
||||
|
||||
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
||||
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result *CreateAuthIdentityResult
|
||||
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
@@ -531,6 +543,23 @@ func copyMetadata(in map[string]any) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
|
||||
if channel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
|
||||
canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
|
||||
channelProviderType := strings.TrimSpace(channel.ProviderType)
|
||||
channelProviderKey := strings.TrimSpace(channel.ProviderKey)
|
||||
|
||||
if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
|
||||
return ErrAuthIdentityChannelProviderMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
|
||||
|
||||
@@ -187,6 +187,48 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
|
||||
user := s.mustCreateUser("provider-mismatch-create")
|
||||
|
||||
_, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "union-create-mismatch",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
Channel: "oauth",
|
||||
ChannelAppID: "app-mismatch",
|
||||
ChannelSubject: "openid-create-mismatch",
|
||||
},
|
||||
})
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
|
||||
user := s.mustCreateUser("provider-mismatch-bind")
|
||||
|
||||
_, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "union-bind-mismatch",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-legacy",
|
||||
Channel: "oa",
|
||||
ChannelAppID: "wx-app-bind-mismatch",
|
||||
ChannelSubject: "openid-bind-mismatch",
|
||||
},
|
||||
})
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
|
||||
user := s.mustCreateUser("tx-rollback")
|
||||
expectedErr := errors.New("rollback")
|
||||
|
||||
Reference in New Issue
Block a user