feat: rebuild auth identity foundation flow
This commit is contained in:
@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
user.FieldBalanceNotifyThreshold,
|
||||
user.FieldBalanceNotifyExtraEmails,
|
||||
user.FieldTotalRecharged,
|
||||
user.FieldSignupSource,
|
||||
user.FieldLastLoginAt,
|
||||
user.FieldLastActiveAt,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
SignupSource: u.SignupSource,
|
||||
LastLoginAt: u.LastLoginAt,
|
||||
LastActiveAt: u.LastActiveAt,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
|
||||
148
backend/internal/repository/auth_identity_migration_report.go
Normal file
148
backend/internal/repository/auth_identity_migration_report.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthIdentityMigrationReport struct {
|
||||
ID int64
|
||||
ReportType string
|
||||
ReportKey string
|
||||
Details map[string]any
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type AuthIdentityMigrationReportQuery struct {
|
||||
ReportType string
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
type AuthIdentityMigrationReportSummary struct {
|
||||
Total int64
|
||||
ByType map[string]int64
|
||||
}
|
||||
|
||||
func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
limit := query.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := exec.QueryContext(ctx, `
|
||||
SELECT id, report_type, report_key, details, created_at
|
||||
FROM auth_identity_migration_reports
|
||||
WHERE ($1 = '' OR report_type = $1)
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT $2 OFFSET $3`,
|
||||
strings.TrimSpace(query.ReportType),
|
||||
limit,
|
||||
query.Offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
reports := make([]AuthIdentityMigrationReport, 0)
|
||||
for rows.Next() {
|
||||
report, scanErr := scanAuthIdentityMigrationReport(rows)
|
||||
if scanErr != nil {
|
||||
return nil, scanErr
|
||||
}
|
||||
reports = append(reports, report)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
rows, err := exec.QueryContext(ctx, `
|
||||
SELECT id, report_type, report_key, details, created_at
|
||||
FROM auth_identity_migration_reports
|
||||
WHERE report_type = $1 AND report_key = $2
|
||||
LIMIT 1`,
|
||||
strings.TrimSpace(reportType),
|
||||
strings.TrimSpace(reportKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
report, err := scanAuthIdentityMigrationReport(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &report, rows.Err()
|
||||
}
|
||||
|
||||
func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
rows, err := exec.QueryContext(ctx, `
|
||||
SELECT report_type, COUNT(*)
|
||||
FROM auth_identity_migration_reports
|
||||
GROUP BY report_type
|
||||
ORDER BY report_type ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
summary := &AuthIdentityMigrationReportSummary{
|
||||
ByType: make(map[string]int64),
|
||||
}
|
||||
for rows.Next() {
|
||||
var reportType string
|
||||
var count int64
|
||||
if err := rows.Scan(&reportType, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summary.ByType[reportType] = count
|
||||
summary.Total += count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
|
||||
var (
|
||||
report AuthIdentityMigrationReport
|
||||
details []byte
|
||||
)
|
||||
if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
|
||||
return AuthIdentityMigrationReport{}, err
|
||||
}
|
||||
report.Details = map[string]any{}
|
||||
if len(details) > 0 {
|
||||
if err := json.Unmarshal(details, &report.Details); err != nil {
|
||||
return AuthIdentityMigrationReport{}, err
|
||||
}
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
544
backend/internal/repository/user_profile_identity_repo.go
Normal file
544
backend/internal/repository/user_profile_identity_repo.go
Normal file
@@ -0,0 +1,544 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
|
||||
"AUTH_IDENTITY_OWNERSHIP_CONFLICT",
|
||||
"auth identity already belongs to another user",
|
||||
)
|
||||
ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
|
||||
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
|
||||
"auth identity channel already belongs to another user",
|
||||
)
|
||||
)
|
||||
|
||||
type ProviderGrantReason string
|
||||
|
||||
const (
|
||||
ProviderGrantReasonSignup ProviderGrantReason = "signup"
|
||||
ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
|
||||
)
|
||||
|
||||
type AuthIdentityKey struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
ProviderSubject string
|
||||
}
|
||||
|
||||
type AuthIdentityChannelKey struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
Channel string
|
||||
ChannelAppID string
|
||||
ChannelSubject string
|
||||
}
|
||||
|
||||
type CreateAuthIdentityInput struct {
|
||||
UserID int64
|
||||
Canonical AuthIdentityKey
|
||||
Channel *AuthIdentityChannelKey
|
||||
Issuer *string
|
||||
VerifiedAt *time.Time
|
||||
Metadata map[string]any
|
||||
ChannelMetadata map[string]any
|
||||
}
|
||||
|
||||
type BindAuthIdentityInput = CreateAuthIdentityInput
|
||||
|
||||
type CreateAuthIdentityResult struct {
|
||||
Identity *dbent.AuthIdentity
|
||||
Channel *dbent.AuthIdentityChannel
|
||||
}
|
||||
|
||||
func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
|
||||
if r == nil || r.Identity == nil {
|
||||
return AuthIdentityKey{}
|
||||
}
|
||||
return AuthIdentityKey{
|
||||
ProviderType: r.Identity.ProviderType,
|
||||
ProviderKey: r.Identity.ProviderKey,
|
||||
ProviderSubject: r.Identity.ProviderSubject,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
|
||||
if r == nil || r.Channel == nil {
|
||||
return nil
|
||||
}
|
||||
return &AuthIdentityChannelKey{
|
||||
ProviderType: r.Channel.ProviderType,
|
||||
ProviderKey: r.Channel.ProviderKey,
|
||||
Channel: r.Channel.Channel,
|
||||
ChannelAppID: r.Channel.ChannelAppID,
|
||||
ChannelSubject: r.Channel.ChannelSubject,
|
||||
}
|
||||
}
|
||||
|
||||
type UserAuthIdentityLookup struct {
|
||||
User *dbent.User
|
||||
Identity *dbent.AuthIdentity
|
||||
Channel *dbent.AuthIdentityChannel
|
||||
}
|
||||
|
||||
type ProviderGrantRecordInput struct {
|
||||
UserID int64
|
||||
ProviderType string
|
||||
GrantReason ProviderGrantReason
|
||||
}
|
||||
|
||||
type IdentityAdoptionDecisionInput struct {
|
||||
PendingAuthSessionID int64
|
||||
IdentityID *int64
|
||||
AdoptDisplayName bool
|
||||
AdoptAvatar bool
|
||||
}
|
||||
|
||||
type sqlQueryExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||
if dbent.TxFromContext(ctx) != nil {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := fn(txCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
create := client.AuthIdentity.Create().
|
||||
SetUserID(input.UserID).
|
||||
SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
|
||||
SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
|
||||
SetMetadata(copyMetadata(input.Metadata)).
|
||||
SetNillableIssuer(input.Issuer).
|
||||
SetNillableVerifiedAt(input.VerifiedAt)
|
||||
|
||||
identity, err := create.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var channel *dbent.AuthIdentityChannel
|
||||
if input.Channel != nil {
|
||||
channel, err = client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
|
||||
SetChannel(strings.TrimSpace(input.Channel.Channel)).
|
||||
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
|
||||
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
|
||||
SetMetadata(copyMetadata(input.ChannelMetadata)).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
|
||||
identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
|
||||
).
|
||||
WithUser().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UserAuthIdentityLookup{
|
||||
User: identity.Edges.User,
|
||||
Identity: identity,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
|
||||
channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
|
||||
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
|
||||
authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
|
||||
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
|
||||
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
|
||||
).
|
||||
WithIdentity(func(q *dbent.AuthIdentityQuery) {
|
||||
q.WithUser()
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UserAuthIdentityLookup{
|
||||
User: channel.Edges.Identity.Edges.User,
|
||||
Identity: channel.Edges.Identity,
|
||||
Channel: channel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
||||
var result *CreateAuthIdentityResult
|
||||
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||
client := clientFromContext(txCtx, r.client)
|
||||
canonical := input.Canonical
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
|
||||
).
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if identity != nil && identity.UserID != input.UserID {
|
||||
return ErrAuthIdentityOwnershipConflict
|
||||
}
|
||||
if identity == nil {
|
||||
identity, err = client.AuthIdentity.Create().
|
||||
SetUserID(input.UserID).
|
||||
SetProviderType(strings.TrimSpace(canonical.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
|
||||
SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
|
||||
SetMetadata(copyMetadata(input.Metadata)).
|
||||
SetNillableIssuer(input.Issuer).
|
||||
SetNillableVerifiedAt(input.VerifiedAt).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
update := client.AuthIdentity.UpdateOneID(identity.ID)
|
||||
if input.Metadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.Metadata))
|
||||
}
|
||||
if input.Issuer != nil {
|
||||
update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
|
||||
}
|
||||
if input.VerifiedAt != nil {
|
||||
update = update.SetVerifiedAt(*input.VerifiedAt)
|
||||
}
|
||||
identity, err = update.Save(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var channel *dbent.AuthIdentityChannel
|
||||
if input.Channel != nil {
|
||||
channel, err = client.AuthIdentityChannel.Query().
|
||||
Where(
|
||||
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
|
||||
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
|
||||
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
|
||||
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
|
||||
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
|
||||
).
|
||||
WithIdentity().
|
||||
Only(txCtx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
|
||||
return ErrAuthIdentityChannelOwnershipConflict
|
||||
}
|
||||
if channel == nil {
|
||||
channel, err = client.AuthIdentityChannel.Create().
|
||||
SetIdentityID(identity.ID).
|
||||
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
|
||||
SetChannel(strings.TrimSpace(input.Channel.Channel)).
|
||||
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
|
||||
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
|
||||
SetMetadata(copyMetadata(input.ChannelMetadata)).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
|
||||
SetIdentityID(identity.ID)
|
||||
if input.ChannelMetadata != nil {
|
||||
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
|
||||
}
|
||||
channel, err = update.Save(txCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return false, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
result, err := exec.ExecContext(ctx, `
|
||||
INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
|
||||
input.UserID,
|
||||
strings.TrimSpace(input.ProviderType),
|
||||
string(input.GrantReason),
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
current, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if current == nil {
|
||||
create := client.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(now)
|
||||
if input.IdentityID != nil {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return update.Save(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
|
||||
return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
|
||||
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
|
||||
SetLastLoginAt(loginAt).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
|
||||
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
|
||||
SetLastActiveAt(activeAt).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
|
||||
exec, err := r.userProfileIdentitySQL(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := exec.QueryContext(ctx, `
|
||||
SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
|
||||
FROM user_avatars
|
||||
WHERE user_id = $1`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
var avatar service.UserAvatar
|
||||
if err := rows.Scan(
|
||||
&avatar.StorageProvider,
|
||||
&avatar.StorageKey,
|
||||
&avatar.URL,
|
||||
&avatar.ContentType,
|
||||
&avatar.ByteSize,
|
||||
&avatar.SHA256,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &avatar, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
|
||||
exec, err := r.userProfileIdentitySQL(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = exec.ExecContext(ctx, `
|
||||
INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
|
||||
ON CONFLICT (user_id) DO UPDATE SET
|
||||
storage_provider = EXCLUDED.storage_provider,
|
||||
storage_key = EXCLUDED.storage_key,
|
||||
url = EXCLUDED.url,
|
||||
content_type = EXCLUDED.content_type,
|
||||
byte_size = EXCLUDED.byte_size,
|
||||
sha256 = EXCLUDED.sha256,
|
||||
updated_at = NOW()`,
|
||||
userID,
|
||||
strings.TrimSpace(input.StorageProvider),
|
||||
strings.TrimSpace(input.StorageKey),
|
||||
strings.TrimSpace(input.URL),
|
||||
strings.TrimSpace(input.ContentType),
|
||||
input.ByteSize,
|
||||
strings.TrimSpace(input.SHA256),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.UserAvatar{
|
||||
StorageProvider: strings.TrimSpace(input.StorageProvider),
|
||||
StorageKey: strings.TrimSpace(input.StorageKey),
|
||||
URL: strings.TrimSpace(input.URL),
|
||||
ContentType: strings.TrimSpace(input.ContentType),
|
||||
ByteSize: input.ByteSize,
|
||||
SHA256: strings.TrimSpace(input.SHA256),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
|
||||
exec, err := r.userProfileIdentitySQL(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
avatar, err := r.GetUserAvatar(ctx, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if avatar == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
user.AvatarURL = avatar.URL
|
||||
user.AvatarSource = avatar.StorageProvider
|
||||
user.AvatarMIME = avatar.ContentType
|
||||
user.AvatarByteSize = avatar.ByteSize
|
||||
user.AvatarSHA256 = avatar.SHA256
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyMetadata(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
|
||||
return exec
|
||||
}
|
||||
}
|
||||
if fallback != nil {
|
||||
return fallback
|
||||
}
|
||||
return sqlExecutorFromEntClient(client)
|
||||
}
|
||||
|
||||
func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
|
||||
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
|
||||
if exec == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
return exec, nil
|
||||
}
|
||||
|
||||
func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clientValue := reflect.ValueOf(client).Elem()
|
||||
configValue := clientValue.FieldByName("config")
|
||||
driverValue := configValue.FieldByName("driver")
|
||||
if !driverValue.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
|
||||
exec, ok := driver.(sqlQueryExecutor)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return exec
|
||||
}
|
||||
@@ -0,0 +1,428 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type UserProfileIdentityRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *userRepository
|
||||
}
|
||||
|
||||
func TestUserProfileIdentityRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserProfileIdentityRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testEntClient(s.T())
|
||||
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
|
||||
|
||||
_, err := integrationDB.ExecContext(s.ctx, `
|
||||
TRUNCATE TABLE
|
||||
identity_adoption_decisions,
|
||||
auth_identity_channels,
|
||||
auth_identities,
|
||||
pending_auth_sessions,
|
||||
auth_identity_migration_reports,
|
||||
user_provider_default_grants,
|
||||
user_avatars
|
||||
RESTART IDENTITY`)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
|
||||
s.T().Helper()
|
||||
|
||||
user, err := s.client.User.Create().
|
||||
SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetRole("user").
|
||||
SetStatus("active").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
return user
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
|
||||
s.T().Helper()
|
||||
|
||||
session, err := s.client.PendingAuthSession.Create().
|
||||
SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
|
||||
SetIntent("bind_current_user").
|
||||
SetProviderType(key.ProviderType).
|
||||
SetProviderKey(key.ProviderKey).
|
||||
SetProviderSubject(key.ProviderSubject).
|
||||
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
|
||||
SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
|
||||
SetLocalFlowState(map[string]any{"step": "pending"}).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
|
||||
user := s.mustCreateUser("canonical-channel")
|
||||
|
||||
verifiedAt := time.Now().UTC().Truncate(time.Second)
|
||||
created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-123",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
Channel: "mp",
|
||||
ChannelAppID: "wx-app",
|
||||
ChannelSubject: "openid-123",
|
||||
},
|
||||
Issuer: stringPtr("https://issuer.example"),
|
||||
VerifiedAt: &verifiedAt,
|
||||
Metadata: map[string]any{"unionid": "union-123"},
|
||||
ChannelMetadata: map[string]any{"openid": "openid-123"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(created.Identity)
|
||||
s.Require().NotNil(created.Channel)
|
||||
|
||||
canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(user.ID, canonical.User.ID)
|
||||
s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
|
||||
s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
|
||||
|
||||
channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(user.ID, channel.User.ID)
|
||||
s.Require().Equal(created.Identity.ID, channel.Identity.ID)
|
||||
s.Require().Equal(created.Channel.ID, channel.Channel.ID)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
|
||||
owner := s.mustCreateUser("owner")
|
||||
other := s.mustCreateUser("other")
|
||||
|
||||
first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: owner.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
ProviderSubject: "subject-1",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
Channel: "oauth",
|
||||
ChannelAppID: "linuxdo-web",
|
||||
ChannelSubject: "subject-1",
|
||||
},
|
||||
Metadata: map[string]any{"username": "first"},
|
||||
ChannelMetadata: map[string]any{"scope": "read"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: owner.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
ProviderSubject: "subject-1",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
Channel: "oauth",
|
||||
ChannelAppID: "linuxdo-web",
|
||||
ChannelSubject: "subject-1",
|
||||
},
|
||||
Metadata: map[string]any{"username": "second"},
|
||||
ChannelMetadata: map[string]any{"scope": "write"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(first.Identity.ID, second.Identity.ID)
|
||||
s.Require().Equal(first.Channel.ID, second.Channel.ID)
|
||||
s.Require().Equal("second", second.Identity.Metadata["username"])
|
||||
s.Require().Equal("write", second.Channel.Metadata["scope"])
|
||||
|
||||
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: other.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
ProviderSubject: "subject-1",
|
||||
},
|
||||
})
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
|
||||
|
||||
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
|
||||
UserID: other.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
ProviderSubject: "subject-2",
|
||||
},
|
||||
Channel: &AuthIdentityChannelKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
Channel: "oauth",
|
||||
ChannelAppID: "linuxdo-web",
|
||||
ChannelSubject: "subject-1",
|
||||
},
|
||||
})
|
||||
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
|
||||
user := s.mustCreateUser("tx-rollback")
|
||||
expectedErr := errors.New("rollback")
|
||||
|
||||
err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
|
||||
_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example",
|
||||
ProviderSubject: "subject-rollback",
|
||||
},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
|
||||
UserID: user.ID,
|
||||
ProviderType: "oidc",
|
||||
GrantReason: ProviderGrantReasonFirstBind,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(inserted)
|
||||
return expectedErr
|
||||
})
|
||||
s.Require().ErrorIs(err, expectedErr)
|
||||
|
||||
_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example",
|
||||
ProviderSubject: "subject-rollback",
|
||||
})
|
||||
s.Require().True(dbent.IsNotFound(err))
|
||||
|
||||
var count int
|
||||
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_provider_default_grants
|
||||
WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
|
||||
user.ID,
|
||||
"oidc",
|
||||
string(ProviderGrantReasonFirstBind),
|
||||
).Scan(&count))
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
|
||||
user := s.mustCreateUser("grant")
|
||||
|
||||
inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
|
||||
UserID: user.ID,
|
||||
ProviderType: "wechat",
|
||||
GrantReason: ProviderGrantReasonFirstBind,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(inserted)
|
||||
|
||||
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
|
||||
UserID: user.ID,
|
||||
ProviderType: "wechat",
|
||||
GrantReason: ProviderGrantReasonFirstBind,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(inserted)
|
||||
|
||||
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
|
||||
UserID: user.ID,
|
||||
ProviderType: "wechat",
|
||||
GrantReason: ProviderGrantReasonSignup,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(inserted)
|
||||
|
||||
var count int
|
||||
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM user_provider_default_grants
|
||||
WHERE user_id = $1 AND provider_type = $2`,
|
||||
user.ID,
|
||||
"wechat",
|
||||
).Scan(&count))
|
||||
s.Require().Equal(2, count)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
|
||||
user := s.mustCreateUser("adoption")
|
||||
identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
|
||||
UserID: user.ID,
|
||||
Canonical: AuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-adoption",
|
||||
},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
session := s.mustCreatePendingAuthSession(identity.IdentityRef())
|
||||
|
||||
first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: false,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(first.AdoptDisplayName)
|
||||
s.Require().False(first.AdoptAvatar)
|
||||
s.Require().Nil(first.IdentityID)
|
||||
|
||||
second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
IdentityID: &identity.Identity.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(first.ID, second.ID)
|
||||
s.Require().NotNil(second.IdentityID)
|
||||
s.Require().Equal(identity.Identity.ID, *second.IdentityID)
|
||||
s.Require().True(second.AdoptAvatar)
|
||||
|
||||
loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(second.ID, loaded.ID)
|
||||
s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
|
||||
user := s.mustCreateUser("avatar")
|
||||
|
||||
inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
|
||||
StorageProvider: "inline",
|
||||
URL: "data:image/png;base64,QUJD",
|
||||
ContentType: "image/png",
|
||||
ByteSize: 3,
|
||||
SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("inline", inlineAvatar.StorageProvider)
|
||||
s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
|
||||
|
||||
loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(loadedAvatar)
|
||||
s.Require().Equal("image/png", loadedAvatar.ContentType)
|
||||
s.Require().Equal(3, loadedAvatar.ByteSize)
|
||||
|
||||
_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
|
||||
StorageProvider: "remote_url",
|
||||
URL: "https://cdn.example.com/avatar.png",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(loadedAvatar)
|
||||
s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
|
||||
s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
|
||||
s.Require().Zero(loadedAvatar.ByteSize)
|
||||
|
||||
s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
|
||||
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(loadedAvatar)
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
|
||||
_, err := integrationDB.ExecContext(s.ctx, `
|
||||
INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
|
||||
VALUES
|
||||
('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
|
||||
('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
|
||||
('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
|
||||
s.Require().NoError(err)
|
||||
|
||||
summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), summary.Total)
|
||||
s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
|
||||
s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
|
||||
|
||||
reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
|
||||
ReportType: "wechat_openid_only_requires_remediation",
|
||||
Limit: 10,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(reports, 2)
|
||||
s.Require().Equal("u-2", reports[0].ReportKey)
|
||||
s.Require().Equal(float64(2), reports[0].Details["user_id"])
|
||||
|
||||
report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("u-3", report.ReportKey)
|
||||
s.Require().Equal(float64(3), report.Details["user_id"])
|
||||
}
|
||||
|
||||
func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
|
||||
user := s.mustCreateUser("activity")
|
||||
loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
|
||||
activeAt := loginAt.Add(5 * time.Minute)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
|
||||
s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
|
||||
|
||||
var storedLoginAt sqlNullTime
|
||||
var storedActiveAt sqlNullTime
|
||||
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
|
||||
SELECT last_login_at, last_active_at
|
||||
FROM users
|
||||
WHERE id = $1`,
|
||||
user.ID,
|
||||
).Scan(&storedLoginAt, &storedActiveAt))
|
||||
s.Require().True(storedLoginAt.Valid)
|
||||
s.Require().True(storedActiveAt.Valid)
|
||||
s.Require().True(storedLoginAt.Time.Equal(loginAt))
|
||||
s.Require().True(storedActiveAt.Time.Equal(activeAt))
|
||||
}
|
||||
|
||||
type sqlNullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (t *sqlNullTime) Scan(value any) error {
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
t.Time = v
|
||||
t.Valid = true
|
||||
return nil
|
||||
case nil:
|
||||
t.Time = time.Time{}
|
||||
t.Valid = false
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type %T", value)
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
@@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
|
||||
SetNillableLastLoginAt(userIn.LastLoginAt).
|
||||
SetNillableLastActiveAt(userIn.LastActiveAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||
SetTotalRecharged(userIn.TotalRecharged)
|
||||
if userIn.SignupSource != "" {
|
||||
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
|
||||
}
|
||||
if userIn.LastLoginAt != nil {
|
||||
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
|
||||
}
|
||||
if userIn.LastActiveAt != nil {
|
||||
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
|
||||
}
|
||||
if userIn.BalanceNotifyThreshold == nil {
|
||||
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||||
}
|
||||
@@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
|
||||
|
||||
var field string
|
||||
defaultField := true
|
||||
nullsLastField := false
|
||||
switch sortBy {
|
||||
case "email":
|
||||
field = dbuser.FieldEmail
|
||||
@@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
|
||||
case "created_at":
|
||||
field = dbuser.FieldCreatedAt
|
||||
defaultField = false
|
||||
case "last_login_at":
|
||||
field = dbuser.FieldLastLoginAt
|
||||
defaultField = false
|
||||
nullsLastField = true
|
||||
case "last_active_at":
|
||||
field = dbuser.FieldLastActiveAt
|
||||
defaultField = false
|
||||
nullsLastField = true
|
||||
default:
|
||||
field = dbuser.FieldID
|
||||
}
|
||||
@@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
|
||||
if defaultField && field == dbuser.FieldID {
|
||||
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
|
||||
}
|
||||
if nullsLastField {
|
||||
return []func(*entsql.Selector){
|
||||
entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
|
||||
dbent.Asc(dbuser.FieldID),
|
||||
}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
|
||||
}
|
||||
if defaultField && field == dbuser.FieldID {
|
||||
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
|
||||
}
|
||||
if nullsLastField {
|
||||
return []func(*entsql.Selector){
|
||||
entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
|
||||
dbent.Desc(dbuser.FieldID),
|
||||
}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
|
||||
}
|
||||
|
||||
@@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.SignupSource = src.SignupSource
|
||||
dst.LastLoginAt = src.LastLoginAt
|
||||
dst.LastActiveAt = src.LastActiveAt
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func userSignupSourceOrDefault(signupSource string) string {
|
||||
signupSource = strings.TrimSpace(signupSource)
|
||||
if signupSource == "" {
|
||||
return "email"
|
||||
}
|
||||
return signupSource
|
||||
}
|
||||
|
||||
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||||
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
|
||||
return service.MarshalNotifyEmails(entries)
|
||||
|
||||
@@ -4,6 +4,7 @@ package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
|
||||
s.Require().Equal(first.ID, users[1].ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
|
||||
lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
|
||||
lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
|
||||
|
||||
created := s.mustCreateUser(&service.User{
|
||||
Email: "identity-meta@example.com",
|
||||
SignupSource: "github",
|
||||
LastLoginAt: &lastLoginAt,
|
||||
LastActiveAt: &lastActiveAt,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("github", got.SignupSource)
|
||||
s.Require().NotNil(got.LastLoginAt)
|
||||
s.Require().NotNil(got.LastActiveAt)
|
||||
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
|
||||
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
|
||||
created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
|
||||
lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
|
||||
lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
|
||||
|
||||
created.SignupSource = "oidc"
|
||||
created.LastLoginAt = &lastLoginAt
|
||||
created.LastActiveAt = &lastActiveAt
|
||||
|
||||
s.Require().NoError(s.repo.Update(s.ctx, created))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("oidc", got.SignupSource)
|
||||
s.Require().NotNil(got.LastLoginAt)
|
||||
s.Require().NotNil(got.LastActiveAt)
|
||||
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
|
||||
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
|
||||
older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
|
||||
newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
|
||||
|
||||
s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
|
||||
s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
|
||||
s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "last_login_at",
|
||||
SortOrder: "desc",
|
||||
}, service.UserListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 3)
|
||||
s.Require().Equal("newer-login@example.com", users[0].Email)
|
||||
s.Require().Equal("older-login@example.com", users[1].Email)
|
||||
s.Require().Equal("nil-login@example.com", users[2].Email)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
|
||||
earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
|
||||
later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
|
||||
|
||||
s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
|
||||
s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
|
||||
s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "last_active_at",
|
||||
SortOrder: "asc",
|
||||
}, service.UserListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 3)
|
||||
s.Require().Equal("earlier-active@example.com", users[0].Email)
|
||||
s.Require().Equal("later-active@example.com", users[1].Email)
|
||||
s.Require().Equal("nil-active@example.com", users[2].Email)
|
||||
}
|
||||
|
||||
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
|
||||
|
||||
Reference in New Issue
Block a user