fix(auth): harden oauth identity upgrade paths
This commit is contained in:
@@ -5,10 +5,15 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
|
||||
|
||||
type authPendingIdentityScopedKeyLockRegistry struct {
|
||||
mu sync.Mutex
|
||||
locks map[string]*authPendingIdentityScopedKeyLockEntry
|
||||
}
|
||||
|
||||
type authPendingIdentityScopedKeyLockEntry struct {
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
|
||||
return &authPendingIdentityScopedKeyLockRegistry{
|
||||
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
|
||||
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||
if len(normalized) == 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
|
||||
r.mu.Lock()
|
||||
for _, key := range normalized {
|
||||
entry := r.locks[key]
|
||||
if entry == nil {
|
||||
entry = &authPendingIdentityScopedKeyLockEntry{}
|
||||
r.locks[key] = entry
|
||||
}
|
||||
entry.refs++
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
entry.mu.Lock()
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
entries[i].mu.Unlock()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for idx, key := range normalized {
|
||||
entry := entries[idx]
|
||||
entry.refs--
|
||||
if entry.refs == 0 {
|
||||
delete(r.locks, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
deduped := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
deduped[trimmed] = struct{}{}
|
||||
}
|
||||
if len(deduped) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(deduped))
|
||||
for key := range deduped {
|
||||
normalized = append(normalized, key)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return normalized
|
||||
}
|
||||
|
||||
func authPendingIdentityAdvisoryLockHash(key string) int64 {
|
||||
hasher := fnv.New64a()
|
||||
_, _ = hasher.Write([]byte(key))
|
||||
return int64(hasher.Sum64())
|
||||
}
|
||||
|
||||
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
|
||||
release := authPendingIdentityScopedKeyLocks.lock(keys...)
|
||||
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
|
||||
return release, nil
|
||||
}
|
||||
|
||||
for _, key := range normalized {
|
||||
var rows entsql.Rows
|
||||
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
|
||||
release()
|
||||
return nil, err
|
||||
}
|
||||
_ = rows.Close()
|
||||
}
|
||||
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
|
||||
if identityID != nil && *identityID > 0 {
|
||||
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
|
||||
return &AuthPendingIdentityService{entClient: entClient}
|
||||
}
|
||||
@@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s.entClient
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
client = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
client = existingTx.Client()
|
||||
}
|
||||
|
||||
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer releaseLocks()
|
||||
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := s.entClient.IdentityAdoptionDecision.Update().
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
@@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
Save(txCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
existing, err := s.entClient.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if existing == nil {
|
||||
create := s.entClient.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
create := client.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
|
||||
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
decisionID, err := create.
|
||||
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||
UpdateNewValues().
|
||||
ID(txCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return update.Save(ctx)
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func copyPendingMap(in map[string]any) map[string]any {
|
||||
|
||||
Reference in New Issue
Block a user