fix(auth): harden oauth identity upgrade paths
This commit is contained in:
@@ -4,11 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
|
||||
|
||||
type scopedKeyLockRegistry struct {
|
||||
mu sync.Mutex
|
||||
locks map[string]*scopedKeyLockEntry
|
||||
}
|
||||
|
||||
type scopedKeyLockEntry struct {
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
|
||||
return &scopedKeyLockRegistry{
|
||||
locks: make(map[string]*scopedKeyLockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
|
||||
normalized := normalizeLockKeys(keys...)
|
||||
if len(normalized) == 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
entries := make([]*scopedKeyLockEntry, 0, len(normalized))
|
||||
r.mu.Lock()
|
||||
for _, key := range normalized {
|
||||
entry := r.locks[key]
|
||||
if entry == nil {
|
||||
entry = &scopedKeyLockEntry{}
|
||||
r.locks[key] = entry
|
||||
}
|
||||
entry.refs++
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
entry.mu.Lock()
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
entries[i].mu.Unlock()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for idx, key := range normalized {
|
||||
entry := entries[idx]
|
||||
entry.refs--
|
||||
if entry.refs == 0 {
|
||||
delete(r.locks, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLockKeys(keys ...string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
deduped := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
deduped[trimmed] = struct{}{}
|
||||
}
|
||||
if len(deduped) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(deduped))
|
||||
for key := range deduped {
|
||||
normalized = append(normalized, key)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return normalized
|
||||
}
|
||||
|
||||
func advisoryLockHash(key string) int64 {
|
||||
hasher := fnv.New64a()
|
||||
_, _ = hasher.Write([]byte(key))
|
||||
return int64(hasher.Sum64())
|
||||
}
|
||||
|
||||
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
|
||||
release := repositoryScopedKeyLocks.lock(keys...)
|
||||
normalized := normalizeLockKeys(keys...)
|
||||
if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
|
||||
return release, nil
|
||||
}
|
||||
|
||||
for _, key := range normalized {
|
||||
rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
|
||||
if err != nil {
|
||||
release()
|
||||
return nil, err
|
||||
}
|
||||
_ = rows.Close()
|
||||
}
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||
if dbent.TxFromContext(ctx) != nil {
|
||||
return fn(ctx)
|
||||
@@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
|
||||
update := client.AuthIdentity.UpdateOneID(identity.ID)
|
||||
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
|
||||
update = update.SetProviderKey(targetProviderKey)
|
||||
}
|
||||
if input.Metadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.Metadata))
|
||||
}
|
||||
@@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
|
||||
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||
SetIdentityID(identity.ID)
|
||||
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
|
||||
update = update.SetProviderKey(targetProviderKey)
|
||||
}
|
||||
if input.ChannelMetadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
|
||||
}
|
||||
@@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
existingKey = strings.TrimSpace(existingKey)
|
||||
requestedKey = strings.TrimSpace(requestedKey)
|
||||
if providerType != "wechat" {
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
|
||||
return "wechat-main"
|
||||
}
|
||||
if requestedKey != "" {
|
||||
return requestedKey
|
||||
}
|
||||
return existingKey
|
||||
}
|
||||
|
||||
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
|
||||
providerType = strings.TrimSpace(strings.ToLower(providerType))
|
||||
providerKey = strings.TrimSpace(providerKey)
|
||||
if providerType != "wechat" {
|
||||
return 0
|
||||
}
|
||||
switch {
|
||||
case strings.EqualFold(providerKey, "wechat-main"):
|
||||
return 0
|
||||
case strings.EqualFold(providerKey, "wechat"):
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
|
||||
var selected *dbent.AuthIdentity
|
||||
for _, record := range records {
|
||||
if record.UserID == userID {
|
||||
return record
|
||||
if record.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
|
||||
@@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64)
|
||||
}
|
||||
|
||||
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
|
||||
var selected *dbent.AuthIdentityChannel
|
||||
for _, record := range records {
|
||||
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID {
|
||||
return record
|
||||
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
|
||||
continue
|
||||
}
|
||||
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
|
||||
selected = record
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return selected
|
||||
}
|
||||
|
||||
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
|
||||
@@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
|
||||
}
|
||||
|
||||
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||
))
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
var result *dbent.IdentityAdoptionDecision
|
||||
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
releaseLocks, err := lockRepositoryScopedKeys(
|
||||
txCtx,
|
||||
client,
|
||||
txAwareSQLExecutor(txCtx, r.sql, r.client),
|
||||
identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseLocks()
|
||||
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.NEQ(col, input.PendingAuthSessionID),
|
||||
))
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(txCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
current, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if current == nil {
|
||||
create := client.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(now)
|
||||
if input.IdentityID != nil {
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
decisionID, err := create.
|
||||
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||
UpdateNewValues().
|
||||
ID(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return update.Save(ctx)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||
keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
|
||||
if identityID != nil && *identityID > 0 {
|
||||
keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
|
||||
|
||||
Reference in New Issue
Block a user