fix: harden auth identity legacy migrations

This commit is contained in:
IanShaw027
2026-04-21 01:30:37 +08:00
parent a70f7aca07
commit 0a461d8248
5 changed files with 329 additions and 4 deletions

View File

@@ -26,6 +26,10 @@ var (
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
"AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
"auth identity channel provider must match canonical identity",
)
)
type ProviderGrantReason string
@@ -133,6 +137,10 @@ func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
@@ -240,6 +248,10 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
@@ -531,6 +543,23 @@ func copyMetadata(in map[string]any) map[string]any {
return out
}
func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
if channel == nil {
return nil
}
canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
channelProviderType := strings.TrimSpace(channel.ProviderType)
channelProviderKey := strings.TrimSpace(channel.ProviderKey)
if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
return ErrAuthIdentityChannelProviderMismatch
}
return nil
}
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {