diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index a6a7be9a..497a23c4 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -167,7 +167,7 @@ type StartIdentityBindingRequest struct { type BindEmailIdentityRequest struct { Email string `json:"email" binding:"required,email"` VerifyCode string `json:"verify_code" binding:"required"` - Password string `json:"password" binding:"required,min=6"` + Password string `json:"password" binding:"required"` } type SendEmailBindingCodeRequest struct { diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 72b28293..24f715d4 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -422,6 +422,59 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { require.True(t, resp.Data.EmailBound) } +func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Email: "current@example.com", + Username: "bound-user", + Role: service.RoleUser, + Status: service.StatusActive, + } + require.NoError(t, user.SetPassword("current-password")) + + repo := &userHandlerRepoStub{user: user} + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "PASSWORD_INCORRECT", resp.Reason) + require.Equal(t, "current password is incorrect", resp.Message) + require.Equal(t, "current@example.com", repo.user.Email) +} + func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go index 58f8e647..b060ab76 100644 --- a/backend/internal/service/auth_email_binding.go +++ b/backend/internal/service/auth_email_binding.go @@ -13,7 +13,8 @@ import ( infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) -// BindEmailIdentity verifies and binds a local email/password identity to the current user. +// BindEmailIdentity verifies and binds a local email/password identity to the +// current user, or replaces the existing bound primary email. func (s *AuthService) BindEmailIdentity( ctx context.Context, userID int64, @@ -43,6 +44,13 @@ func (s *AuthService) BindEmailIdentity( if err != nil { return nil, err } + firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) + if firstRealEmailBind && len(password) < 6 { + return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters") + } + if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) { + return nil, ErrPasswordIncorrect + } existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) switch { @@ -57,9 +65,8 @@ func (s *AuthService) BindEmailIdentity( return nil, fmt.Errorf("hash password: %w", err) } - firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) - if firstRealEmailBind && s.entClient != nil { - if err := s.bindEmailIdentityWithDefaultsTx(ctx, currentUser, normalizedEmail, hashedPassword); err != nil { + if s.entClient != nil { + if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil { return nil, err } return currentUser, nil @@ -137,14 +144,15 @@ func hasBindableEmailIdentitySubject(email string) bool { return normalized != "" && !isReservedEmail(normalized) } -func (s *AuthService) bindEmailIdentityWithDefaultsTx( +func (s *AuthService) updateBoundEmailIdentityTx( ctx context.Context, currentUser *User, email string, hashedPassword string, + applyFirstBindDefaults bool, ) error { if tx := dbent.TxFromContext(ctx); tx != nil { - return s.bindEmailIdentityWithDefaults(ctx, tx.Client(), currentUser, email, hashedPassword) + return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults) } tx, err := s.entClient.Tx(ctx) @@ -154,7 +162,7 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx( defer func() { _ = tx.Rollback() }() txCtx := dbent.NewTxContext(ctx, tx) - if err := s.bindEmailIdentityWithDefaults(txCtx, tx.Client(), currentUser, email, hashedPassword); err != nil { + if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil { return err } if err := tx.Commit(); err != nil { @@ -163,12 +171,13 @@ func (s *AuthService) bindEmailIdentityWithDefaultsTx( return nil } -func (s *AuthService) bindEmailIdentityWithDefaults( +func (s *AuthService) updateBoundEmailIdentityWithClient( ctx context.Context, client *dbent.Client, currentUser *User, email string, hashedPassword string, + applyFirstBindDefaults bool, ) error { if client == nil || currentUser == nil || currentUser.ID <= 0 { return ErrServiceUnavailable @@ -192,8 +201,10 @@ func (s *AuthService) bindEmailIdentityWithDefaults( return ErrServiceUnavailable } - if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil { - return fmt.Errorf("apply email first bind defaults: %w", err) + if applyFirstBindDefaults { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil { + return fmt.Errorf("apply email first bind defaults: %w", err) + } } updatedUser, err := client.User.Get(ctx, currentUser.ID) diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index fd5f499b..d32a4a40 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -285,6 +285,148 @@ func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) { require.Nil(t, updatedUser) } +func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(7.5). + SetConcurrency(3). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). + Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "new@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "new@example.com", storedUser.Email) + require.Equal(t, 7.5, storedUser.Balance) + require.Equal(t, 3, storedUser.Concurrency) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, newIdentityCount) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, oldIdentityCount) + + require.Empty(t, assigner.calls) + require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + hashedPassword, err := svc.HashPassword("current-password") + require.NoError(t, err) + + user, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("bound-user"). + SetPasswordHash(hashedPassword). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + require.NoError(t, client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject("current@example.com"). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{"source": "test"}). + Exec(ctx)) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password") + require.ErrorIs(t, err, service.ErrPasswordIncorrect) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "current@example.com", storedUser.Email) + require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash)) + + oldIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("current@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, oldIdentityCount) + + newIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("new@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 0, newIdentityCount) +} + type emailBindSettingRepoStub struct { values map[string]string } diff --git a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue index 653b4e33..ee582a60 100644 --- a/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue +++ b/frontend/src/components/user/profile/ProfileIdentityBindingsSection.vue @@ -34,7 +34,7 @@