feat(auth): support unbinding third-party identities
This commit is contained in:
@@ -82,6 +82,11 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error {
|
||||
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@@ -154,6 +154,10 @@ func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64)
|
||||
panic("unexpected ListUserAuthIdentities call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||
panic("unexpected UnbindUserAuthProvider call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
@@ -123,6 +123,8 @@ func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]Us
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
|
||||
|
||||
func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
|
||||
func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
@@ -90,6 +90,10 @@ func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
cache := &billingCacheMissStub{}
|
||||
userRepo := &balanceLoadUserRepoStub{
|
||||
|
||||
@@ -29,15 +29,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||||
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
||||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||||
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
|
||||
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
|
||||
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
|
||||
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
||||
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
|
||||
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
|
||||
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||||
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
||||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||||
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
|
||||
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
|
||||
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
|
||||
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
||||
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
|
||||
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
|
||||
ErrIdentityUnbindLastMethod = infraerrors.Conflict(
|
||||
"IDENTITY_UNBIND_LAST_METHOD",
|
||||
"bind another sign-in method before unbinding this provider",
|
||||
)
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -99,6 +103,7 @@ type UserRepository interface {
|
||||
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
|
||||
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
|
||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||
UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error
|
||||
|
||||
// TOTP 双因素认证
|
||||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||||
@@ -249,9 +254,9 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
|
||||
|
||||
return UserIdentitySummarySet{
|
||||
Email: s.buildEmailIdentitySummary(user, records),
|
||||
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
|
||||
OIDC: s.buildProviderIdentitySummary("oidc", records),
|
||||
WeChat: s.buildProviderIdentitySummary("wechat", records),
|
||||
LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
|
||||
OIDC: s.buildProviderIdentitySummary("oidc", user, records),
|
||||
WeChat: s.buildProviderIdentitySummary("wechat", user, records),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -274,6 +279,42 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
|
||||
provider = normalizeUserIdentityProvider(provider)
|
||||
if provider == "" || provider == "email" {
|
||||
return nil, ErrIdentityProviderInvalid
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
records, err := s.listUserAuthIdentities(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(filterUserAuthIdentities(records, provider)) == 0 {
|
||||
return user, nil
|
||||
}
|
||||
if !s.canUnbindProvider(provider, user, records) {
|
||||
return nil, ErrIdentityUnbindLastMethod
|
||||
}
|
||||
|
||||
if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
updatedUser, err := s.GetProfile(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updatedUser, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户资料
|
||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
||||
if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
|
||||
@@ -552,7 +593,7 @@ func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthId
|
||||
return summary
|
||||
}
|
||||
|
||||
func (s *UserService) buildProviderIdentitySummary(provider string, records []UserAuthIdentityRecord) UserIdentitySummary {
|
||||
func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
|
||||
summary := UserIdentitySummary{
|
||||
Provider: provider,
|
||||
CanUnbind: false,
|
||||
@@ -574,10 +615,36 @@ func (s *UserService) buildProviderIdentitySummary(provider string, records []Us
|
||||
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
|
||||
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
|
||||
summary.VerifiedAt = primary.VerifiedAt
|
||||
summary.Note = "Unbind is not available yet."
|
||||
summary.CanUnbind = s.canUnbindProvider(provider, user, records)
|
||||
if summary.CanUnbind {
|
||||
summary.Note = "You can unbind this sign-in method."
|
||||
} else {
|
||||
summary.Note = "Bind another sign-in method before unbinding."
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool {
|
||||
if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.buildEmailIdentitySummary(user, records).Bound {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, candidate := range []string{"linuxdo", "oidc", "wechat"} {
|
||||
if candidate == provider {
|
||||
continue
|
||||
}
|
||||
if len(filterUserAuthIdentities(records, candidate)) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
|
||||
if userID <= 0 || s == nil || s.userRepo == nil {
|
||||
return nil, nil
|
||||
|
||||
@@ -27,6 +27,9 @@ type mockUserRepo struct {
|
||||
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||
getByIDUser *User
|
||||
getByIDErr error
|
||||
identities []UserAuthIdentityRecord
|
||||
unbindIdentityErr error
|
||||
unboundProviders []string
|
||||
updateLastActiveErr error
|
||||
updateLastActiveUserIDs []int64
|
||||
updateLastActiveAt []time.Time
|
||||
@@ -160,7 +163,9 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
|
||||
}
|
||||
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||
return nil, nil
|
||||
out := make([]UserAuthIdentityRecord, len(m.identities))
|
||||
copy(out, m.identities)
|
||||
return out, nil
|
||||
}
|
||||
func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||
return map[int64]*time.Time{}, nil
|
||||
@@ -174,6 +179,21 @@ func (m *mockUserRepo) DisableTotp(context.Context, int64) error {
|
||||
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
|
||||
if m.unbindIdentityErr != nil {
|
||||
return m.unbindIdentityErr
|
||||
}
|
||||
m.unboundProviders = append(m.unboundProviders, provider)
|
||||
filtered := m.identities[:0]
|
||||
for _, identity := range m.identities {
|
||||
if identity.ProviderType == provider {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, identity)
|
||||
}
|
||||
m.identities = append([]UserAuthIdentityRecord(nil), filtered...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||
m.txCalls++
|
||||
@@ -274,6 +294,94 @@ func TestUpdateBalance_Success(t *testing.T) {
|
||||
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
||||
}
|
||||
|
||||
func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) {
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 7,
|
||||
Email: "alice@example.com",
|
||||
},
|
||||
identities: []UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "email",
|
||||
ProviderKey: "email",
|
||||
ProviderSubject: "alice@example.com",
|
||||
},
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-123456",
|
||||
Metadata: map[string]any{
|
||||
"username": "linuxdo-handle",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, summaries.LinuxDo.Bound)
|
||||
require.True(t, summaries.LinuxDo.CanUnbind)
|
||||
require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName)
|
||||
require.NotEmpty(t, summaries.LinuxDo.SubjectHint)
|
||||
}
|
||||
|
||||
func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 9,
|
||||
Email: "only-user@linuxdo-connect.invalid",
|
||||
},
|
||||
identities: []UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-only-subject",
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
_, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo")
|
||||
|
||||
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
|
||||
require.Empty(t, repo.unboundProviders)
|
||||
}
|
||||
|
||||
func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 12,
|
||||
Email: "alice@example.com",
|
||||
},
|
||||
identities: []UserAuthIdentityRecord{
|
||||
{
|
||||
ProviderType: "email",
|
||||
ProviderKey: "email",
|
||||
ProviderSubject: "alice@example.com",
|
||||
},
|
||||
{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "linuxdo-subject-12",
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"linuxdo"}, repo.unboundProviders)
|
||||
require.Equal(t, int64(12), user.ID)
|
||||
|
||||
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user)
|
||||
require.NoError(t, err)
|
||||
require.False(t, summaries.LinuxDo.Bound)
|
||||
require.True(t, summaries.LinuxDo.CanBind)
|
||||
}
|
||||
|
||||
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
|
||||
|
||||
Reference in New Issue
Block a user