package service import ( "context" "crypto/rand" "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" "sort" "strings" "sync" "time" "entgo.io/ent/dialect" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" entsql "entgo.io/ent/dialect/sql" ) var ( ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") ) const ( defaultPendingAuthTTL = 15 * time.Minute defaultPendingAuthCompletionTTL = 5 * time.Minute ) type PendingAuthIdentityKey struct { ProviderType string ProviderKey string ProviderSubject string } type CreatePendingAuthSessionInput struct { SessionToken string Intent string Identity PendingAuthIdentityKey TargetUserID *int64 RedirectTo string ResolvedEmail string RegistrationPasswordHash string BrowserSessionKey string UpstreamIdentityClaims map[string]any LocalFlowState map[string]any ExpiresAt time.Time } type IssuePendingAuthCompletionCodeInput struct { PendingAuthSessionID int64 BrowserSessionKey string TTL time.Duration } type IssuePendingAuthCompletionCodeResult struct { Code string ExpiresAt time.Time } type PendingIdentityAdoptionDecisionInput struct { PendingAuthSessionID int64 IdentityID *int64 AdoptDisplayName bool AdoptAvatar bool } type AuthPendingIdentityService struct { entClient *dbent.Client } var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry() type authPendingIdentityScopedKeyLockRegistry struct { mu sync.Mutex locks map[string]*authPendingIdentityScopedKeyLockEntry } type authPendingIdentityScopedKeyLockEntry struct { mu sync.Mutex refs int } func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry { return &authPendingIdentityScopedKeyLockRegistry{ locks: make(map[string]*authPendingIdentityScopedKeyLockEntry), } } func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() { normalized := normalizeAuthPendingIdentityLockKeys(keys...) if len(normalized) == 0 { return func() {} } entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized)) r.mu.Lock() for _, key := range normalized { entry := r.locks[key] if entry == nil { entry = &authPendingIdentityScopedKeyLockEntry{} r.locks[key] = entry } entry.refs++ entries = append(entries, entry) } r.mu.Unlock() for _, entry := range entries { entry.mu.Lock() } return func() { for i := len(entries) - 1; i >= 0; i-- { entries[i].mu.Unlock() } r.mu.Lock() defer r.mu.Unlock() for idx, key := range normalized { entry := entries[idx] entry.refs-- if entry.refs == 0 { delete(r.locks, key) } } } } func normalizeAuthPendingIdentityLockKeys(keys ...string) []string { if len(keys) == 0 { return nil } deduped := make(map[string]struct{}, len(keys)) for _, key := range keys { trimmed := strings.TrimSpace(key) if trimmed == "" { continue } deduped[trimmed] = struct{}{} } if len(deduped) == 0 { return nil } normalized := make([]string, 0, len(deduped)) for key := range deduped { normalized = append(normalized, key) } sort.Strings(normalized) return normalized } func authPendingIdentityAdvisoryLockHash(key string) int64 { hasher := fnv.New64a() _, _ = hasher.Write([]byte(key)) return int64(hasher.Sum64()) } func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) { release := authPendingIdentityScopedKeyLocks.lock(keys...) normalized := normalizeAuthPendingIdentityLockKeys(keys...) if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres { return release, nil } for _, key := range normalized { var rows entsql.Rows if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil { release() return nil, err } _ = rows.Close() } return release, nil } func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string { keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)} if identityID != nil && *identityID > 0 { keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID)) } return keys } func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { return &AuthPendingIdentityService{entClient: entClient} } func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } sessionToken := strings.TrimSpace(input.SessionToken) if sessionToken == "" { var err error sessionToken, err = randomOpaqueToken(24) if err != nil { return nil, err } } expiresAt := input.ExpiresAt.UTC() if expiresAt.IsZero() { expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) } create := s.entClient.PendingAuthSession.Create(). SetSessionToken(sessionToken). SetIntent(strings.TrimSpace(input.Intent)). SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). SetRedirectTo(strings.TrimSpace(input.RedirectTo)). SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). SetLocalFlowState(copyPendingMap(input.LocalFlowState)). SetExpiresAt(expiresAt) if input.TargetUserID != nil { create = create.SetTargetUserID(*input.TargetUserID) } return create.Save(ctx) } func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) if err != nil { if dbent.IsNotFound(err) { return nil, ErrPendingAuthSessionNotFound } return nil, err } code, err := randomOpaqueToken(24) if err != nil { return nil, err } ttl := input.TTL if ttl <= 0 { ttl = defaultPendingAuthCompletionTTL } expiresAt := time.Now().UTC().Add(ttl) update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). SetCompletionCodeHash(hashPendingAuthCode(code)). SetCompletionCodeExpiresAt(expiresAt) if strings.TrimSpace(input.BrowserSessionKey) != "" { update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) } if _, err := update.Save(ctx); err != nil { return nil, err } return &IssuePendingAuthCompletionCodeResult{ Code: code, ExpiresAt: expiresAt, }, nil } func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) session, err := s.entClient.PendingAuthSession.Query(). Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). Only(ctx) if err != nil { if dbent.IsNotFound(err) { return nil, ErrPendingAuthCodeInvalid } return nil, err } return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) } func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } session, err := s.getBrowserSession(ctx, sessionToken) if err != nil { return nil, err } return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) } func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } session, err := s.getBrowserSession(ctx, sessionToken) if err != nil { return nil, err } if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { return nil, err } return session, nil } func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } sessionToken = strings.TrimSpace(sessionToken) if sessionToken == "" { return nil, ErrPendingAuthSessionNotFound } session, err := s.entClient.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(sessionToken)). Only(ctx) if err != nil { if dbent.IsNotFound(err) { return nil, ErrPendingAuthSessionNotFound } return nil, err } return session, nil } func (s *AuthPendingIdentityService) consumeSession( ctx context.Context, session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error, ) (*dbent.PendingAuthSession, error) { if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { return nil, err } sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState) now := time.Now().UTC() update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). Where( pendingauthsession.ConsumedAtIsNil(), pendingauthsession.ExpiresAtGTE(now), pendingauthsession.Or( pendingauthsession.CompletionCodeExpiresAtIsNil(), pendingauthsession.CompletionCodeExpiresAtGTE(now), ), ). SetConsumedAt(now). SetLocalFlowState(sanitizedLocalFlowState). SetCompletionCodeHash(""). ClearCompletionCodeExpiresAt() if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" { update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey)) } updated, err := update.Save(ctx) if err == nil { return updated, nil } if !dbent.IsNotFound(err) { return nil, err } current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID) if currentErr != nil { if dbent.IsNotFound(currentErr) { return nil, ErrPendingAuthSessionNotFound } return nil, currentErr } if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil { return nil, err } return nil, consumedErr } func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any { sanitized := copyPendingMap(localFlowState) if len(sanitized) == 0 { return sanitized } rawCompletion, ok := sanitized["completion_response"] if !ok { return sanitized } completion, ok := rawCompletion.(map[string]any) if !ok { return sanitized } cleanedCompletion := copyPendingMap(completion) for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} { delete(cleanedCompletion, key) } sanitized["completion_response"] = cleanedCompletion return sanitized } func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { if session == nil { return ErrPendingAuthSessionNotFound } now := time.Now().UTC() if session.ConsumedAt != nil { return consumedErr } if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { return expiredErr } if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { return expiredErr } if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { return ErrPendingAuthBrowserMismatch } return nil } func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { if s == nil || s.entClient == nil { return nil, fmt.Errorf("pending auth ent client is not configured") } tx, err := s.entClient.Tx(ctx) if err != nil && !errors.Is(err, dbent.ErrTxStarted) { return nil, err } client := s.entClient txCtx := ctx if err == nil { defer func() { _ = tx.Rollback() }() client = tx.Client() txCtx = dbent.NewTxContext(ctx, tx) } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil { client = existingTx.Client() } releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...) if err != nil { return nil, err } defer releaseLocks() if input.IdentityID != nil && *input.IdentityID > 0 { if _, err := client.IdentityAdoptionDecision.Update(). Where( identityadoptiondecision.IdentityIDEQ(*input.IdentityID), dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) s.Where(entsql.Or( entsql.IsNull(col), entsql.NEQ(col, input.PendingAuthSessionID), )) }), ). ClearIdentityID(). Save(txCtx); err != nil { return nil, err } } create := client.IdentityAdoptionDecision.Create(). SetPendingAuthSessionID(input.PendingAuthSessionID). SetAdoptDisplayName(input.AdoptDisplayName). SetAdoptAvatar(input.AdoptAvatar). SetDecidedAt(time.Now().UTC()) if input.IdentityID != nil && *input.IdentityID > 0 { create = create.SetIdentityID(*input.IdentityID) } decisionID, err := create. OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID). UpdateNewValues(). ID(txCtx) if err != nil { return nil, err } decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID) if err != nil { return nil, err } if tx != nil { if err := tx.Commit(); err != nil { return nil, err } } return decision, nil } func copyPendingMap(in map[string]any) map[string]any { if len(in) == 0 { return map[string]any{} } out := make(map[string]any, len(in)) for k, v := range in { out[k] = v } return out } func randomOpaqueToken(byteLen int) (string, error) { if byteLen <= 0 { byteLen = 16 } buf := make([]byte, byteLen) if _, err := rand.Read(buf); err != nil { return "", err } return hex.EncodeToString(buf), nil } func hashPendingAuthCode(code string) string { sum := sha256.Sum256([]byte(code)) return hex.EncodeToString(sum[:]) }