feat(auth): support unbinding third-party identities
This commit is contained in:
@@ -2735,6 +2735,10 @@ func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, u
|
|||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
panic("unexpected UnbindUserAuthProvider call")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
update := r.client.User.UpdateOneID(userID)
|
update := r.client.User.UpdateOneID(userID)
|
||||||
if encryptedSecret == nil {
|
if encryptedSecret == nil {
|
||||||
|
|||||||
@@ -240,6 +240,34 @@ func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
|
|||||||
response.Success(c, profileResp)
|
response.Success(c, profileResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnbindIdentity removes a third-party sign-in provider from the current user.
|
||||||
|
// DELETE /api/v1/user/account-bindings/:provider
|
||||||
|
func (h *UserHandler) UnbindIdentity(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedUser, err := h.userService.UnbindUserAuthProvider(
|
||||||
|
c.Request.Context(),
|
||||||
|
subject.UserID,
|
||||||
|
c.Param("provider"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, profileResp)
|
||||||
|
}
|
||||||
|
|
||||||
// SendEmailBindingCode sends a verification code for the current user's email binding flow.
|
// SendEmailBindingCode sends a verification code for the current user's email binding flow.
|
||||||
// POST /api/v1/user/account-bindings/email/send-code
|
// POST /api/v1/user/account-bindings/email/send-code
|
||||||
func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
|
func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
type userHandlerRepoStub struct {
|
type userHandlerRepoStub struct {
|
||||||
user *service.User
|
user *service.User
|
||||||
identities []service.UserAuthIdentityRecord
|
identities []service.UserAuthIdentityRecord
|
||||||
|
unbound []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
|
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
|
||||||
@@ -116,6 +117,18 @@ func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]
|
|||||||
copy(out, s.identities)
|
copy(out, s.identities)
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
|
||||||
|
s.unbound = append(s.unbound, provider)
|
||||||
|
filtered := s.identities[:0]
|
||||||
|
for _, identity := range s.identities {
|
||||||
|
if identity.ProviderType == provider {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, identity)
|
||||||
|
}
|
||||||
|
s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
@@ -428,6 +441,60 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
|||||||
require.True(t, resp.Data.EmailBound)
|
require.True(t, resp.Data.EmailBound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
repo := &userHandlerRepoStub{
|
||||||
|
user: &service.User{
|
||||||
|
ID: 21,
|
||||||
|
Email: "identity@example.com",
|
||||||
|
Username: "identity-user",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
},
|
||||||
|
identities: []service.UserAuthIdentityRecord{
|
||||||
|
{
|
||||||
|
ProviderType: "email",
|
||||||
|
ProviderKey: "email",
|
||||||
|
ProviderSubject: "identity@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ProviderType: "linuxdo",
|
||||||
|
ProviderKey: "linuxdo",
|
||||||
|
ProviderSubject: "linuxdo-subject-21",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"username": "linuxdo-handle",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
|
||||||
|
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
|
||||||
|
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
|
||||||
|
|
||||||
|
handler.UnbindIdentity(c)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
require.Equal(t, []string{"linuxdo"}, repo.unbound)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||||
|
require.Equal(t, 0, resp.Code)
|
||||||
|
|
||||||
|
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, false, linuxdoBinding["bound"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
|
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -249,6 +249,48 @@ func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int6
|
|||||||
return records, nil
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
|
||||||
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if provider == "" || provider == "email" {
|
||||||
|
return service.ErrIdentityProviderInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
|
||||||
|
client := clientFromContext(txCtx, r.client)
|
||||||
|
identityIDs, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.UserIDEQ(userID),
|
||||||
|
authidentity.ProviderTypeEQ(provider),
|
||||||
|
).
|
||||||
|
IDs(txCtx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(identityIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.IdentityAdoptionDecision.Update().
|
||||||
|
Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
|
||||||
|
ClearIdentityID().
|
||||||
|
Save(txCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := client.AuthIdentityChannel.Delete().
|
||||||
|
Where(authidentitychannel.IdentityIDIn(identityIDs...)).
|
||||||
|
Exec(txCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = client.AuthIdentity.Delete().
|
||||||
|
Where(
|
||||||
|
authidentity.UserIDEQ(userID),
|
||||||
|
authidentity.ProviderTypeEQ(provider),
|
||||||
|
).
|
||||||
|
Exec(txCtx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
|
||||||
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
|
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -941,6 +941,10 @@ func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64)
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
|
func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
|
||||||
return map[int64]*time.Time{}, nil
|
return map[int64]*time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -218,6 +218,10 @@ func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64)
|
|||||||
panic("unexpected ListUserAuthIdentities call")
|
panic("unexpected ListUserAuthIdentities call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
panic("unexpected UnbindUserAuthProvider call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
panic("unexpected UpdateTotpSecret call")
|
panic("unexpected UpdateTotpSecret call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ func RegisterUserRoutes(
|
|||||||
user.PUT("", h.User.UpdateProfile)
|
user.PUT("", h.User.UpdateProfile)
|
||||||
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
|
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
|
||||||
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
|
||||||
|
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
|
||||||
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
|
||||||
|
|
||||||
// 通知邮箱管理
|
// 通知邮箱管理
|
||||||
|
|||||||
@@ -82,6 +82,11 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error {
|
|||||||
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,6 +154,10 @@ func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64)
|
|||||||
panic("unexpected ListUserAuthIdentities call")
|
panic("unexpected ListUserAuthIdentities call")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
panic("unexpected UnbindUserAuthProvider call")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
panic("unexpected UpdateTotpSecret call")
|
panic("unexpected UpdateTotpSecret call")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,6 +123,8 @@ func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]Us
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||||
|
|
||||||
func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
|
func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
|
||||||
|
|||||||
@@ -90,6 +90,10 @@ func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64)
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||||
cache := &billingCacheMissStub{}
|
cache := &billingCacheMissStub{}
|
||||||
userRepo := &balanceLoadUserRepoStub{
|
userRepo := &balanceLoadUserRepoStub{
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ var (
|
|||||||
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
||||||
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
|
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
|
||||||
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
|
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
|
||||||
|
ErrIdentityUnbindLastMethod = infraerrors.Conflict(
|
||||||
|
"IDENTITY_UNBIND_LAST_METHOD",
|
||||||
|
"bind another sign-in method before unbinding this provider",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -99,6 +103,7 @@ type UserRepository interface {
|
|||||||
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
|
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
|
||||||
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
|
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
|
||||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||||
|
UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error
|
||||||
|
|
||||||
// TOTP 双因素认证
|
// TOTP 双因素认证
|
||||||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||||||
@@ -249,9 +254,9 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
|
|||||||
|
|
||||||
return UserIdentitySummarySet{
|
return UserIdentitySummarySet{
|
||||||
Email: s.buildEmailIdentitySummary(user, records),
|
Email: s.buildEmailIdentitySummary(user, records),
|
||||||
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
|
LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
|
||||||
OIDC: s.buildProviderIdentitySummary("oidc", records),
|
OIDC: s.buildProviderIdentitySummary("oidc", user, records),
|
||||||
WeChat: s.buildProviderIdentitySummary("wechat", records),
|
WeChat: s.buildProviderIdentitySummary("wechat", user, records),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,6 +279,42 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
|
||||||
|
provider = normalizeUserIdentityProvider(provider)
|
||||||
|
if provider == "" || provider == "email" {
|
||||||
|
return nil, ErrIdentityProviderInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
records, err := s.listUserAuthIdentities(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(filterUserAuthIdentities(records, provider)) == 0 {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
if !s.canUnbindProvider(provider, user, records) {
|
||||||
|
return nil, ErrIdentityUnbindLastMethod
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.authCacheInvalidator != nil {
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedUser, err := s.GetProfile(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return updatedUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateProfile 更新用户资料
|
// UpdateProfile 更新用户资料
|
||||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
||||||
if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
|
if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
|
||||||
@@ -552,7 +593,7 @@ func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthId
|
|||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) buildProviderIdentitySummary(provider string, records []UserAuthIdentityRecord) UserIdentitySummary {
|
func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
|
||||||
summary := UserIdentitySummary{
|
summary := UserIdentitySummary{
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
CanUnbind: false,
|
CanUnbind: false,
|
||||||
@@ -574,10 +615,36 @@ func (s *UserService) buildProviderIdentitySummary(provider string, records []Us
|
|||||||
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
|
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
|
||||||
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
|
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
|
||||||
summary.VerifiedAt = primary.VerifiedAt
|
summary.VerifiedAt = primary.VerifiedAt
|
||||||
summary.Note = "Unbind is not available yet."
|
summary.CanUnbind = s.canUnbindProvider(provider, user, records)
|
||||||
|
if summary.CanUnbind {
|
||||||
|
summary.Note = "You can unbind this sign-in method."
|
||||||
|
} else {
|
||||||
|
summary.Note = "Bind another sign-in method before unbinding."
|
||||||
|
}
|
||||||
return summary
|
return summary
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool {
|
||||||
|
if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.buildEmailIdentitySummary(user, records).Bound {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, candidate := range []string{"linuxdo", "oidc", "wechat"} {
|
||||||
|
if candidate == provider {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(filterUserAuthIdentities(records, candidate)) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
|
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
|
||||||
if userID <= 0 || s == nil || s.userRepo == nil {
|
if userID <= 0 || s == nil || s.userRepo == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ type mockUserRepo struct {
|
|||||||
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||||
getByIDUser *User
|
getByIDUser *User
|
||||||
getByIDErr error
|
getByIDErr error
|
||||||
|
identities []UserAuthIdentityRecord
|
||||||
|
unbindIdentityErr error
|
||||||
|
unboundProviders []string
|
||||||
updateLastActiveErr error
|
updateLastActiveErr error
|
||||||
updateLastActiveUserIDs []int64
|
updateLastActiveUserIDs []int64
|
||||||
updateLastActiveAt []time.Time
|
updateLastActiveAt []time.Time
|
||||||
@@ -160,7 +163,9 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
|
|||||||
}
|
}
|
||||||
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
|
||||||
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
|
||||||
return nil, nil
|
out := make([]UserAuthIdentityRecord, len(m.identities))
|
||||||
|
copy(out, m.identities)
|
||||||
|
return out, nil
|
||||||
}
|
}
|
||||||
func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
|
||||||
return map[int64]*time.Time{}, nil
|
return map[int64]*time.Time{}, nil
|
||||||
@@ -174,6 +179,21 @@ func (m *mockUserRepo) DisableTotp(context.Context, int64) error {
|
|||||||
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
|
||||||
|
if m.unbindIdentityErr != nil {
|
||||||
|
return m.unbindIdentityErr
|
||||||
|
}
|
||||||
|
m.unboundProviders = append(m.unboundProviders, provider)
|
||||||
|
filtered := m.identities[:0]
|
||||||
|
for _, identity := range m.identities {
|
||||||
|
if identity.ProviderType == provider {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, identity)
|
||||||
|
}
|
||||||
|
m.identities = append([]UserAuthIdentityRecord(nil), filtered...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
|
||||||
m.txCalls++
|
m.txCalls++
|
||||||
@@ -274,6 +294,94 @@ func TestUpdateBalance_Success(t *testing.T) {
|
|||||||
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{
|
||||||
|
getByIDUser: &User{
|
||||||
|
ID: 7,
|
||||||
|
Email: "alice@example.com",
|
||||||
|
},
|
||||||
|
identities: []UserAuthIdentityRecord{
|
||||||
|
{
|
||||||
|
ProviderType: "email",
|
||||||
|
ProviderKey: "email",
|
||||||
|
ProviderSubject: "alice@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ProviderType: "linuxdo",
|
||||||
|
ProviderKey: "linuxdo",
|
||||||
|
ProviderSubject: "linuxdo-subject-123456",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"username": "linuxdo-handle",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewUserService(repo, nil, nil, nil)
|
||||||
|
|
||||||
|
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, summaries.LinuxDo.Bound)
|
||||||
|
require.True(t, summaries.LinuxDo.CanUnbind)
|
||||||
|
require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName)
|
||||||
|
require.NotEmpty(t, summaries.LinuxDo.SubjectHint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{
|
||||||
|
getByIDUser: &User{
|
||||||
|
ID: 9,
|
||||||
|
Email: "only-user@linuxdo-connect.invalid",
|
||||||
|
},
|
||||||
|
identities: []UserAuthIdentityRecord{
|
||||||
|
{
|
||||||
|
ProviderType: "linuxdo",
|
||||||
|
ProviderKey: "linuxdo",
|
||||||
|
ProviderSubject: "linuxdo-only-subject",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewUserService(repo, nil, nil, nil)
|
||||||
|
|
||||||
|
_, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo")
|
||||||
|
|
||||||
|
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
|
||||||
|
require.Empty(t, repo.unboundProviders)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{
|
||||||
|
getByIDUser: &User{
|
||||||
|
ID: 12,
|
||||||
|
Email: "alice@example.com",
|
||||||
|
},
|
||||||
|
identities: []UserAuthIdentityRecord{
|
||||||
|
{
|
||||||
|
ProviderType: "email",
|
||||||
|
ProviderKey: "email",
|
||||||
|
ProviderSubject: "alice@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ProviderType: "linuxdo",
|
||||||
|
ProviderKey: "linuxdo",
|
||||||
|
ProviderSubject: "linuxdo-subject-12",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewUserService(repo, nil, nil, nil)
|
||||||
|
|
||||||
|
user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"linuxdo"}, repo.unboundProviders)
|
||||||
|
require.Equal(t, int64(12), user.ID)
|
||||||
|
|
||||||
|
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, summaries.LinuxDo.Bound)
|
||||||
|
require.True(t, summaries.LinuxDo.CanBind)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
||||||
repo := &mockUserRepo{}
|
repo := &mockUserRepo{}
|
||||||
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
|
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
|
||||||
|
|||||||
Reference in New Issue
Block a user