diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1d39fa1e..3b474c4a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -79,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService, emailService, emailCache) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index b7ff17c3..1d3b113f 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -27,7 +28,7 @@ import ( const ( oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" oauthPendingBrowserCookieName = "oauth_pending_browser_session" - oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth" oauthPendingSessionCookieName = "oauth_pending_session" oauthPendingCookieMaxAgeSec = 10 * 60 @@ -66,6 +67,13 @@ type createPendingOAuthAccountRequest struct { AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } +type sendPendingOAuthVerifyCodeRequest struct { + Email string `json:"email" binding:"required,email"` + TurnstileToken string `json:"turnstile_token,omitempty"` + PendingAuthToken string `json:"pending_auth_token,omitempty"` + PendingOAuthToken string `json:"pending_oauth_token,omitempty"` +} + func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { return oauthAdoptionDecisionRequest{ AdoptDisplayName: r.AdoptDisplayName, @@ -448,6 +456,43 @@ func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "") } +// SendPendingOAuthVerifyCode sends a verification code for a browser-bound +// pending OAuth account-creation flow. +// POST /api/v1/auth/oauth/pending/send-verify-code +func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { + var req sendPendingOAuthVerifyCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return + } + + _, session, _, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + + result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, SendVerifyCodeResponse{ + Message: "Verification code sent successfully", + Countdown: result.Countdown, + }) +} + func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( c *gin.Context, sessionID int64, @@ -1084,6 +1129,41 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi return payload } +func (h *AuthHandler) transitionPendingOAuthAccountToBindLogin( + c *gin.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + email string, + decision oauthAdoptionDecisionRequest, +) (*dbent.PendingAuthSession, error) { + existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email) + if err != nil { + return nil, err + } + + completionResponse := mergePendingCompletionResponse(session, map[string]any{ + "step": "bind_login_required", + "email": email, + }) + session, err = updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + "adopt_existing_user_by_email", + email, + &existingUser.ID, + completionResponse, + ) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err) + } + + if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decision); err != nil { + return nil, err + } + return session, nil +} + func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -1199,29 +1279,11 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) return } if existingUser != nil { - completionResponse := mergePendingCompletionResponse(session, map[string]any{ - "step": "bind_login_required", - "email": email, - }) - session, err = updatePendingOAuthSessionProgress( - c.Request.Context(), - client, - session, - "adopt_existing_user_by_email", - email, - &existingUser.ID, - completionResponse, - ) + session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision()) if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)) - return - } - - if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil { response.ErrorFrom(c, err) return } - c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) return } @@ -1239,27 +1301,77 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) strings.TrimSpace(session.ProviderType), ) if err != nil { + if errors.Is(err, service.ErrEmailExists) { + session, err = h.transitionPendingOAuthAccountToBindLogin(c, client, session, email, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } response.ErrorFrom(c, err) return } + rollbackCreatedUser := func(originalErr error) bool { + if user == nil || user.ID <= 0 { + return false + } + if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation( + c.Request.Context(), + user.ID, + strings.TrimSpace(req.InvitationCode), + ); rollbackErr != nil { + response.ErrorFrom(c, infraerrors.InternalServer( + "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED", + "failed to rollback pending oauth account creation", + ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr))) + return true + } + user = nil + return false + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) if err != nil { + if rollbackCreatedUser(err) { + return + } response.ErrorFrom(c, err) return } if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { + if rollbackCreatedUser(err) { + return + } response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } - h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + + if err := h.authService.FinalizeOAuthEmailAccount( + c.Request.Context(), + user, + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ); err != nil { + if rollbackCreatedUser(err) { + return + } + response.ErrorFrom(c, err) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + if rollbackCreatedUser(err) { + return + } clearCookies() response.ErrorFrom(c, err) return } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) clearCookies() writeOAuthTokenPairResponse(c, tokenPair) } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 913acddc..1013a082 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -15,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/redeemcode" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -61,6 +63,18 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t * require.Equal(t, true, payload["adoption_required"]) } +func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false) + + cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, cookie) + require.Equal(t, "/api/v1/auth/oauth", cookie.Path) +} + func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -943,6 +957,81 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) require.Nil(t, storedSession.ConsumedAt) } +func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") + ctx := context.Background() + + conflictOwner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(conflictOwner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetMetadata(map[string]any{ + "username": "owner-user", + }). + Save(ctx) + require.NoError(t, err) + + invitation, err := client.RedeemCode.Create(). + SetCode("INVITE123"). + SetType(service.RedeemTypeInvitation). + SetStatus(service.StatusUnused). + SetValue(0). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-conflict-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-conflict-123"). + SetBrowserSessionKey("create-account-conflict-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID) + require.NoError(t, err) + require.Equal(t, service.StatusUnused, storedInvitation.Status) + require.Nil(t, storedInvitation.UsedBy) + require.Nil(t, storedInvitation.UsedAt) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -1529,6 +1618,8 @@ type oauthPendingFlowTestHandlerOptions struct { defaultSubAssigner service.DefaultSubscriptionAssigner totpCache service.TotpCache totpEncryptor service.SecretEncryptor + redeemRepoFactory func(client *dbent.Client) service.RedeemCodeRepository + userRepoOptions oauthPendingFlowUserRepoOptions } func newOAuthPendingFlowTestHandlerWithDependencies( @@ -1590,7 +1681,17 @@ CREATE TABLE IF NOT EXISTS user_avatars ( settingValues[key] = value } settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg) - userRepo := &oauthPendingFlowUserRepo{client: client} + userRepo := &oauthPendingFlowUserRepo{ + client: client, + options: options.userRepoOptions, + } + redeemRepo := service.RedeemCodeRepository(nil) + if options.redeemRepoFactory != nil { + redeemRepo = options.redeemRepoFactory(client) + } + if redeemRepo == nil { + redeemRepo = &oauthPendingFlowRedeemCodeRepo{client: client} + } var emailService *service.EmailService if options.emailCache != nil { emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ @@ -1602,7 +1703,7 @@ CREATE TABLE IF NOT EXISTS user_avatars ( authSvc := service.NewAuthService( client, userRepo, - nil, + redeemRepo, &oauthPendingFlowRefreshTokenCacheStub{}, cfg, settingSvc, @@ -1797,6 +1898,127 @@ func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, return false, nil } +type oauthPendingFlowRedeemCodeRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error { + panic("unexpected Create call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrRedeemCodeNotFound + } + return nil, err + } + notes := "" + if entity.Notes != nil { + notes = *entity.Notes + } + return &service.RedeemCode{ + ID: entity.ID, + Code: entity.Code, + Type: entity.Type, + Value: entity.Value, + Status: entity.Status, + UsedBy: entity.UsedBy, + UsedAt: entity.UsedAt, + Notes: notes, + CreatedAt: entity.CreatedAt, + GroupID: entity.GroupID, + ValidityDays: entity.ValidityDays, + }, nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error { + if code == nil { + return nil + } + update := r.client.RedeemCode.UpdateOneID(code.ID). + SetCode(code.Code). + SetType(code.Type). + SetValue(code.Value). + SetStatus(code.Status). + SetNotes(code.Notes). + SetValidityDays(code.ValidityDays) + if code.UsedBy != nil { + update = update.SetUsedBy(*code.UsedBy) + } else { + update = update.ClearUsedBy() + } + if code.UsedAt != nil { + update = update.SetUsedAt(*code.UsedAt) + } else { + update = update.ClearUsedAt() + } + if code.GroupID != nil { + update = update.SetGroupID(*code.GroupID) + } else { + update = update.ClearGroupID() + } + _, err := update.Save(ctx) + return err +} + +func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error { + affected, err := r.client.RedeemCode.Update(). + Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)). + SetStatus(service.StatusUsed). + SetUsedBy(userID). + SetUsedAt(time.Now().UTC()). + Save(ctx) + if err != nil { + return err + } + if affected == 0 { + return service.ErrRedeemCodeUsed + } + return nil +} + +func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +type oauthPendingFlowFailingUseRedeemRepo struct { + *oauthPendingFlowRedeemCodeRepo +} + +func (r *oauthPendingFlowFailingUseRedeemRepo) Use(context.Context, int64, int64) error { + return errors.New("forced invitation use failure") +} + func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { t.Helper() @@ -1872,6 +2094,11 @@ func countProviderGrantRecords( type oauthPendingFlowUserRepo struct { client *dbent.Client + options oauthPendingFlowUserRepoOptions +} + +type oauthPendingFlowUserRepoOptions struct { + rejectDeleteWhileAuthIdentityExists bool } func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { @@ -1953,6 +2180,15 @@ func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.Use } func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + if r.options.rejectDeleteWhileAuthIdentityExists { + count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx) + if err != nil { + return err + } + if count > 0 { + return errors.New("cannot delete user while auth identities still exist") + } + } return r.client.User.DeleteOneID(id).Exec(ctx) } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index b1ade5c0..a6a7be9a 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -15,14 +15,21 @@ import ( // UserHandler handles user-related requests type UserHandler struct { userService *service.UserService + authService *service.AuthService emailService *service.EmailService emailCache service.EmailCache } // NewUserHandler creates a new UserHandler -func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler { +func NewUserHandler( + userService *service.UserService, + authService *service.AuthService, + emailService *service.EmailService, + emailCache service.EmailCache, +) *UserHandler { return &UserHandler{ userService: userService, + authService: authService, emailService: emailService, emailCache: emailCache, } @@ -157,6 +164,16 @@ type StartIdentityBindingRequest struct { RedirectTo string `json:"redirect_to"` } +type BindEmailIdentityRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code" binding:"required"` + Password string `json:"password" binding:"required,min=6"` +} + +type SendEmailBindingCodeRequest struct { + Email string `json:"email" binding:"required,email"` +} + // StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow. // POST /api/v1/user/auth-identities/bind/start func (h *UserHandler) StartIdentityBinding(c *gin.Context) { @@ -183,6 +200,73 @@ func (h *UserHandler) StartIdentityBinding(c *gin.Context) { response.Success(c, result) } +// BindEmailIdentity verifies and binds a local email identity for the current user. +// POST /api/v1/user/account-bindings/email +func (h *UserHandler) BindEmailIdentity(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req BindEmailIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + updatedUser, err := h.authService.BindEmailIdentity( + c.Request.Context(), + subject.UserID, + req.Email, + req.VerifyCode, + req.Password, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, profileResp) +} + +// SendEmailBindingCode sends a verification code for the current user's email binding flow. +// POST /api/v1/user/account-bindings/email/send-code +func (h *UserHandler) SendEmailBindingCode(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + if h.authService == nil { + response.InternalError(c, "Auth service not configured") + return + } + + var req SendEmailBindingCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Verification code sent successfully"}) +} + // SendNotifyEmailCodeRequest represents the request to send notify email verification code type SendNotifyEmailCodeRequest struct { Email string `json:"email" binding:"required,email"` diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index aacfc332..72b28293 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -122,7 +123,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) recorder := httptest.NewRecorder() @@ -180,7 +181,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -262,7 +263,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -311,6 +312,116 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { require.Equal(t, "linuxdo", usernameSource["source"]) } +type userHandlerEmailCacheStub struct { + data *service.VerificationCodeData +} + +func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + return s.data, nil +} + +func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain, + Username: "legacy-user", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + emailCache := &userHandlerEmailCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + }, + } + emailService := service.NewEmailService(nil, emailCache) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + + body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Params = gin.Params{{Key: "provider", Value: "email"}} + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.BindEmailIdentity(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Email string `json:"email"` + EmailBound bool `json:"email_bound"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "new@example.com", resp.Data.Email) + require.True(t, resp.Data.EmailBound) +} + func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { gin.SetMode(gin.TestMode) @@ -323,7 +434,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) recorder := httptest.NewRecorder() diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 20d3d9b4..f1032eb5 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -74,6 +74,12 @@ func RegisterAuthRoutes( }), h.Auth.ExchangePendingOAuthCompletion, ) + auth.POST("/oauth/pending/send-verify-code", + rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.SendPendingOAuthVerifyCode, + ) auth.POST("/oauth/pending/create-account", rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go index 4f411cec..07a66efb 100644 --- a/backend/internal/server/routes/auth_rate_limit_test.go +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { "/api/v1/auth/login", "/api/v1/auth/login/2fa", "/api/v1/auth/send-verify-code", + "/api/v1/auth/oauth/pending/send-verify-code", } for _, path := range paths { diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index ccbe23ce..46baa80a 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -25,6 +25,8 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode) + user.POST("/account-bindings/email", h.User.BindEmailIdentity) user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding) // 通知邮箱管理 diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go new file mode 100644 index 00000000..b999660b --- /dev/null +++ b/backend/internal/service/auth_email_binding.go @@ -0,0 +1,128 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/mail" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +// BindEmailIdentity verifies and binds a local email/password identity to the current user. +func (s *AuthService) BindEmailIdentity( + ctx context.Context, + userID int64, + email string, + verifyCode string, + password string, +) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return nil, err + } + if isReservedEmail(normalizedEmail) { + return nil, ErrEmailReserved + } + if strings.TrimSpace(password) == "" { + return nil, ErrPasswordRequired + } + if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil { + return nil, err + } + + currentUser, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return nil, ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return nil, ErrServiceUnavailable + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, fmt.Errorf("hash password: %w", err) + } + + firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email) + currentUser.Email = normalizedEmail + currentUser.PasswordHash = hashedPassword + if err := s.userRepo.Update(ctx, currentUser); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, ErrEmailExists + } + return nil, ErrServiceUnavailable + } + + if firstRealEmailBind { + if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil { + return nil, fmt.Errorf("apply email first bind defaults: %w", err) + } + } + + return currentUser, nil +} + +// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows. +func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error { + if s == nil { + return ErrServiceUnavailable + } + + normalizedEmail, err := normalizeEmailForIdentityBinding(email) + if err != nil { + return err + } + if isReservedEmail(normalizedEmail) { + return ErrEmailReserved + } + if s.emailService == nil { + return ErrServiceUnavailable + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + if errors.Is(err, ErrUserNotFound) { + return ErrUserNotFound + } + return ErrServiceUnavailable + } + + existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail) + switch { + case err == nil && existingUser != nil && existingUser.ID != userID: + return ErrEmailExists + case err != nil && !errors.Is(err, ErrUserNotFound): + return ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName) +} + +func normalizeEmailForIdentityBinding(email string) (string, error) { + normalized := strings.ToLower(strings.TrimSpace(email)) + if normalized == "" || len(normalized) > 255 { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(normalized); err != nil { + return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + return normalized, nil +} + +func hasBindableEmailIdentitySubject(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return normalized != "" && !isReservedEmail(normalized) +} diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index 2e0107ae..ce25222c 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -4,9 +4,71 @@ import ( "context" "errors" "fmt" + "net/mail" "strings" + "time" ) +func normalizeOAuthSignupSource(signupSource string) string { + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + return "email" + } + return signupSource +} + +// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth +// account-creation flows without relying on the public registration gate. +func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + return nil, ErrEmailVerifyRequired + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, ErrEmailVerifyRequired + } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + if s == nil || s.emailService == nil { + return nil, ErrServiceUnavailable + } + + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil { + return nil, err + } + return &SendVerifyCodeResult{ + Countdown: int(verifyCodeCooldown / time.Second), + }, nil +} + +func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil, nil + } + if s.redeemRepo == nil { + return nil, ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" { + return nil, ErrInvitationCodeRequired + } + + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, ErrInvitationCodeInvalid + } + return redeemCode, nil +} + // VerifyOAuthEmailCode verifies the locally entered email verification code for // third-party signup and binding flows. This is intentionally independent from // the global registration email verification toggle. @@ -54,19 +116,8 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, err } - var invitationRedeemCode *RedeemCode - if s.settingService.IsInvitationCodeEnabled(ctx) { - if invitationCode == "" { - return nil, nil, ErrInvitationCodeRequired - } - redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) - if err != nil { - return nil, nil, ErrInvitationCodeInvalid - } - if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { - return nil, nil, ErrInvitationCodeInvalid - } - invitationRedeemCode = redeemCode + if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil { + return nil, nil, err } existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) @@ -104,22 +155,91 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, ErrServiceUnavailable } - s.postAuthUserBootstrap(ctx, user, signupSource, false) - s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") - - if invitationRedeemCode != nil { - if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { - return nil, nil, ErrInvitationCodeInvalid - } - } - tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { + _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "") return nil, nil, fmt.Errorf("generate token pair: %w", err) } return tokenPair, user, nil } +// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap +// only after the pending OAuth flow has fully reached its last reversible step. +func (s *AuthService) FinalizeOAuthEmailAccount( + ctx context.Context, + user *User, + invitationCode string, + signupSource string, +) error { + if s == nil || user == nil || user.ID <= 0 { + return ErrServiceUnavailable + } + + signupSource = normalizeOAuthSignupSource(signupSource) + invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode) + if err != nil { + return err + } + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return ErrInvitationCodeInvalid + } + } + + s.postAuthUserBootstrap(ctx, user, signupSource, false) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + return nil +} + +// RollbackOAuthEmailAccountCreation removes a partially-created local account +// and restores any invitation code already consumed by that account. +func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error { + if s == nil || s.userRepo == nil || userID <= 0 { + return ErrServiceUnavailable + } + if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil { + return err + } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete created oauth user: %w", err) + } + return nil +} + +func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error { + if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) { + return nil + } + if s.redeemRepo == nil { + return ErrServiceUnavailable + } + + invitationCode = strings.TrimSpace(invitationCode) + if invitationCode == "" || userID <= 0 { + return nil + } + + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + if errors.Is(err, ErrRedeemCodeNotFound) { + return nil + } + return fmt.Errorf("load invitation code: %w", err) + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID { + return nil + } + + redeemCode.Status = StatusUnused + redeemCode.UsedBy = nil + redeemCode.UsedAt = nil + if err := s.redeemRepo.Update(ctx, redeemCode); err != nil { + return fmt.Errorf("restore invitation code: %w", err) + } + return nil +} + // ValidatePasswordCredentials checks the local password without completing the // login flow. This is used by pending third-party account adoption flows before // the external identity has been bound. diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go new file mode 100644 index 00000000..a77dda72 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -0,0 +1,251 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type redeemCodeRepoStub struct { + codesByCode map[string]*RedeemCode + useCalls []struct { + id int64 + userID int64 + } + updateCalls []*RedeemCode +} + +func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error { + panic("unexpected Create call") +} + +func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error { + panic("unexpected CreateBatch call") +} + +func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) { + panic("unexpected GetByID call") +} + +func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) { + if s.codesByCode == nil { + return nil, ErrRedeemCodeNotFound + } + redeemCode, ok := s.codesByCode[code] + if !ok { + return nil, ErrRedeemCodeNotFound + } + cloned := *redeemCode + return &cloned, nil +} + +func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + cloned := *code + s.updateCalls = append(s.updateCalls, &cloned) + if s.codesByCode == nil { + s.codesByCode = make(map[string]*RedeemCode) + } + s.codesByCode[cloned.Code] = &cloned + return nil +} + +func (s *redeemCodeRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} + +func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error { + for code, redeemCode := range s.codesByCode { + if redeemCode.ID != id { + continue + } + now := time.Now().UTC() + redeemCode.Status = StatusUsed + redeemCode.UsedBy = &userID + redeemCode.UsedAt = &now + s.codesByCode[code] = redeemCode + s.useCalls = append(s.useCalls, struct { + id int64 + userID int64 + }{id: id, userID: userID}) + return nil + } + return ErrRedeemCodeNotFound +} + +func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { + panic("unexpected ListByUser call") +} + +func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + +func newOAuthEmailFlowAuthService( + userRepo UserRepository, + redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, + settings map[string]string, + emailCache EmailCache, +) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache) + + return NewAuthService( + nil, + userRepo, + redeemRepo, + refreshTokenCache, + cfg, + settingService, + emailService, + nil, + nil, + nil, + nil, + ) +} + +func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) { + userRepo := &userRepoStub{nextID: 42} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUnused, + }, + }, + } + emailCache := &emailCacheStub{ + data: &VerificationCodeData{ + Code: "246810", + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + nil, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, + emailCache, + ) + + tokenPair, user, err := authService.RegisterOAuthEmailAccount( + context.Background(), + "fresh@example.com", + "secret-123", + "246810", + "INVITE123", + "oidc", + ) + + require.Nil(t, tokenPair) + require.Nil(t, user) + require.Error(t, err) + require.Contains(t, err.Error(), "generate token pair") + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, userRepo.created, 1) + require.Empty(t, redeemRepo.useCalls) + require.Empty(t, redeemRepo.updateCalls) +} + +func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { + userRepo := &userRepoStub{} + redeemRepo := &redeemCodeRepoStub{ + codesByCode: map[string]*RedeemCode{ + "INVITE123": { + ID: 7, + Code: "INVITE123", + Type: RedeemTypeInvitation, + Status: StatusUsed, + UsedBy: func() *int64 { + v := int64(42) + return &v + }(), + UsedAt: func() *time.Time { + v := time.Now().UTC() + return &v + }(), + }, + }, + } + authService := newOAuthEmailFlowAuthService( + userRepo, + redeemRepo, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyInvitationCodeEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123") + + require.NoError(t, err) + require.Equal(t, []int64{42}, userRepo.deletedIDs) + require.Len(t, redeemRepo.updateCalls, 1) + require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status) + require.Nil(t, redeemRepo.updateCalls[0].UsedBy) + require.Nil(t, redeemRepo.updateCalls[0].UsedAt) +} + +func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) { + userRepo := &userRepoStub{deleteErr: errors.New("delete failed")} + authService := newOAuthEmailFlowAuthService( + userRepo, + &redeemCodeRepoStub{}, + &refreshTokenCacheStub{}, + map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, + &emailCacheStub{}, + ) + + err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "") + + require.Error(t, err) + require.Contains(t, err.Error(), "delete created oauth user") +} diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go new file mode 100644 index 00000000..899a736d --- /dev/null +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -0,0 +1,316 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type emailBindDefaultSubAssignerStub struct { + calls []*service.AssignSubscriptionInput +} + +func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription( + _ context.Context, + input *service.AssignSubscriptionInput, +) (*service.UserSubscription, bool, error) { + cloned := *input + s.calls = append(s.calls, &cloned) + return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + +func newAuthServiceForEmailBind( + t *testing.T, + settings map[string]string, + emailCache service.EmailCache, + defaultSubAssigner service.DefaultSubscriptionAssigner, +) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + provider_type TEXT NOT NULL, + grant_reason TEXT NOT NULL DEFAULT 'first_bind', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, provider_type, grant_reason) +)`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-bind-email-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + + settingRepo := &emailBindSettingRepoStub{values: settings} + settingSvc := service.NewSettingService(settingRepo, cfg) + + var emailSvc *service.EmailService + if emailCache != nil { + emailSvc = service.NewEmailService(settingRepo, emailCache) + } + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + return svc, repo, client +} + +func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) { + assigner := &emailBindDefaultSubAssignerStub{} + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, map[string]string{ + service.SettingKeyAuthSourceDefaultEmailBalance: "8.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "4", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true", + }, cache, assigner) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain). + SetUsername("legacy-user"). + SetPasswordHash("old-hash"). + SetBalance(2.5). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password") + require.NoError(t, err) + require.NotNil(t, updatedUser) + require.Equal(t, "newemail@example.com", updatedUser.Email) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "newemail@example.com", storedUser.Email) + require.Equal(t, 11.0, storedUser.Balance) + require.Equal(t, 5, storedUser.Concurrency) + require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.UserIDEQ(user.ID), + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("newemail@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, identityCount) + + require.Len(t, assigner.calls, 1) + require.Equal(t, user.ID, assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + sourceUser, err := client.User.Create(). + SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + _, err = client.User.Create(). + SetEmail("taken@example.com"). + SetUsername("taken-user"). + SetPasswordHash("hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailExists) + require.Nil(t, updatedUser) + + storedUser, err := client.User.Get(ctx, sourceUser.ID) + require.NoError(t, err) + require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email) + require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind")) +} + +func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) { + cache := &emailBindCacheStub{ + data: &service.VerificationCodeData{ + Code: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(10 * time.Minute), + }, + } + svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil) + + ctx := context.Background() + user, err := client.User.Create(). + SetEmail("source-user@example.com"). + SetUsername("source-user"). + SetPasswordHash("old-hash"). + SetBalance(1). + SetConcurrency(1). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password") + require.ErrorIs(t, err, service.ErrEmailReserved) + require.Nil(t, updatedUser) +} + +type emailBindSettingRepoStub struct { + values map[string]string +} + +func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + out[key] = v + } + } + return out, nil +} + +func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *emailBindSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +type emailBindCacheStub struct { + data *service.VerificationCodeData + err error +} + +func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) { + if s.err != nil { + return nil, s.err + } + return s.data, nil +} + +func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + +func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 9cf0c7ad..0f768018 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -449,6 +449,16 @@ export async function sendVerifyCode( return data } +export async function sendPendingOAuthVerifyCode( + request: SendVerifyCodeRequest +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/send-verify-code', + request + ) + return data +} + /** * Validate promo code response */ @@ -638,6 +648,7 @@ export const authAPI = { clearAuthToken, getPublicSettings, sendVerifyCode, + sendPendingOAuthVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index 7b498303..502bf151 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -89,6 +89,19 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi return data } +export async function sendEmailBindingCode(email: string): Promise { + await apiClient.post('/user/account-bindings/email/send-code', { email }) +} + +export async function bindEmailIdentity(payload: { + email: string + verify_code: string + password: string +}): Promise { + const { data } = await apiClient.post('/user/account-bindings/email', payload) + return data +} + export type BindableOAuthProvider = Exclude interface BuildOAuthBindingStartURLOptions { @@ -158,6 +171,8 @@ export const userAPI = { verifyNotifyEmail, removeNotifyEmail, toggleNotifyEmail, + sendEmailBindingCode, + bindEmailIdentity, buildOAuthBindingStartURL, startOAuthBinding } diff --git a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue index 36e78d36..8e05617f 100644 --- a/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue +++ b/frontend/src/components/auth/PendingOAuthCreateAccountForm.vue @@ -58,11 +58,20 @@

{{ t('auth.verificationCodeHint') }}

+ + + + + - -
- - {{ - item.bound - ? t('profile.authBindings.status.bound') - : t('profile.authBindings.status.notBound') - }} - - - +
+ +
@@ -49,7 +110,7 @@ diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts index 4e194a39..ec4aed5d 100644 --- a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -2,7 +2,7 @@ import { mount } from '@vue/test-utils' import { createPinia, setActivePinia } from 'pinia' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' -import { useAppStore } from '@/stores' +import { useAppStore, useAuthStore } from '@/stores' import type { User } from '@/types' const routeState = vi.hoisted(() => ({ @@ -15,10 +15,24 @@ const locationState = vi.hoisted(() => ({ let pinia: ReturnType +const userApiMocks = vi.hoisted(() => ({ + sendEmailBindingCode: vi.fn(), + bindEmailIdentity: vi.fn(), +})) + vi.mock('vue-router', () => ({ useRoute: () => routeState, })) +vi.mock('@/api/user', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + sendEmailBindingCode: (...args: any[]) => userApiMocks.sendEmailBindingCode(...args), + bindEmailIdentity: (...args: any[]) => userApiMocks.bindEmailIdentity(...args), + } +}) + vi.mock('vue-i18n', async (importOriginal) => { const actual = await importOriginal() return { @@ -34,6 +48,13 @@ vi.mock('vue-i18n', async (importOriginal) => { if (key === 'profile.authBindings.providers.wechat') return 'WeChat' if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC' if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim() + if (key === 'profile.authBindings.emailPlaceholder') return 'Email address' + if (key === 'profile.authBindings.codePlaceholder') return 'Verification code' + if (key === 'profile.authBindings.passwordPlaceholder') return 'Set password' + if (key === 'profile.authBindings.sendCodeAction') return 'Send code' + if (key === 'profile.authBindings.confirmEmailBindAction') return 'Bind email' + if (key === 'profile.authBindings.codeSentTo') return `Code sent to ${params?.email || ''}`.trim() + if (key === 'profile.authBindings.bindSuccess') return 'Bind success' return key }, }), @@ -76,6 +97,8 @@ describe('ProfileIdentityBindingsSection', () => { const appStore = useAppStore() appStore.cachedPublicSettings = null appStore.publicSettingsLoaded = false + userApiMocks.sendEmailBindingCode.mockReset() + userApiMocks.bindEmailIdentity.mockReset() }) afterEach(() => { @@ -224,4 +247,58 @@ describe('ProfileIdentityBindingsSection', () => { expect(wrapper.find('[data-testid="profile-binding-wechat-action"]').exists()).toBe(true) }) + + it('sends email verification code and binds email from the profile card', async () => { + userApiMocks.sendEmailBindingCode.mockResolvedValue(undefined) + userApiMocks.bindEmailIdentity.mockResolvedValue( + createUser({ + email: 'bound@example.com', + email_bound: true, + auth_bindings: { + email: { bound: true }, + }, + }) + ) + + const appStore = useAppStore() + const authStore = useAuthStore() + authStore.user = createUser({ + email: 'legacy-user@linuxdo-connect.invalid', + email_bound: false, + auth_bindings: { + email: { bound: false }, + }, + }) + const showSuccessSpy = vi.spyOn(appStore, 'showSuccess') + + const wrapper = mount(ProfileIdentityBindingsSection, { + global: { + plugins: [pinia], + }, + props: { + user: authStore.user, + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: false, + }, + }) + + await wrapper.get('[data-testid="profile-binding-email-input"]').setValue('bound@example.com') + await wrapper.get('[data-testid="profile-binding-email-send-code"]').trigger('click') + + expect(userApiMocks.sendEmailBindingCode).toHaveBeenCalledWith('bound@example.com') + expect(showSuccessSpy).toHaveBeenCalledWith('Code sent to bound@example.com') + + await wrapper.get('[data-testid="profile-binding-email-code-input"]').setValue('123456') + await wrapper.get('[data-testid="profile-binding-email-password-input"]').setValue('new-password') + await wrapper.get('[data-testid="profile-binding-email-submit"]').trigger('click') + + expect(userApiMocks.bindEmailIdentity).toHaveBeenCalledWith({ + email: 'bound@example.com', + verify_code: '123456', + password: 'new-password', + }) + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound') + expect(authStore.user?.email).toBe('bound@example.com') + }) }) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index ec9c1ea3..2b41a3c3 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -964,6 +964,12 @@ export default { description: 'View current bindings and connect another provider to this account.', bindAction: 'Bind {providerName}', bindSuccess: 'Account linked successfully', + emailPlaceholder: 'Enter email address', + codePlaceholder: 'Enter verification code', + passwordPlaceholder: 'Set a login password', + sendCodeAction: 'Send code', + confirmEmailBindAction: 'Bind email', + codeSentTo: 'Code sent to {email}', status: { bound: 'Bound', notBound: 'Not bound', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 9941d323..b60a69d6 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -968,6 +968,12 @@ export default { description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。', bindAction: '绑定 {providerName}', bindSuccess: '账号绑定成功', + emailPlaceholder: '输入邮箱地址', + codePlaceholder: '输入验证码', + passwordPlaceholder: '设置登录密码', + sendCodeAction: '发送验证码', + confirmEmailBindAction: '绑定邮箱', + codeSentTo: '验证码已发送到 {email}', status: { bound: '已绑定', notBound: '未绑定', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 07341919..bfc11cb2 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -118,6 +118,8 @@ export interface RegisterRequest { export interface SendVerifyCodeRequest { email: string turnstile_token?: string + pending_auth_token?: string + pending_oauth_token?: string } export interface SendVerifyCodeResponse { diff --git a/frontend/src/views/auth/EmailVerifyView.vue b/frontend/src/views/auth/EmailVerifyView.vue index 84dd4667..d7bf6b7a 100644 --- a/frontend/src/views/auth/EmailVerifyView.vue +++ b/frontend/src/views/auth/EmailVerifyView.vue @@ -176,7 +176,12 @@ import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue' import { useAuthStore, useAppStore } from '@/stores' -import { persistOAuthTokenContext, getPublicSettings, sendVerifyCode } from '@/api/auth' +import { + persistOAuthTokenContext, + getPublicSettings, + sendPendingOAuthVerifyCode, + sendVerifyCode, +} from '@/api/auth' import { apiClient } from '@/api/client' import { buildAuthErrorMessage } from '@/utils/authError' import { @@ -355,18 +360,21 @@ async function sendCode(): Promise { errorMessage.value = '' try { - if (!isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { + if (!pendingAuthToken.value && !isRegistrationEmailSuffixAllowed(email.value, registrationEmailSuffixWhitelist.value)) { errorMessage.value = buildEmailSuffixNotAllowedMessage() appStore.showError(errorMessage.value) return } - const response = await sendVerifyCode({ + const requestPayload = { email: email.value, [pendingAuthTokenField.value]: pendingAuthToken.value || undefined, // 优先使用重发时新获取的 token(因为初始 token 可能已被使用) turnstile_token: resendTurnstileToken.value || initialTurnstileToken.value || undefined - } as Parameters[0]) + } as Parameters[0] + const response = pendingAuthToken.value + ? await sendPendingOAuthVerifyCode(requestPayload) + : await sendVerifyCode(requestPayload) codeSent.value = true startCountdown(response.countdown) diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 6c923b0a..735c6582 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -444,6 +444,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string { return err.response?.data?.detail || err.response?.data?.message || err.message || fallback } +function isCreateAccountRecoveryError(error: unknown): boolean { + const data = (error as { + response?: { + data?: { + reason?: string + error?: string + code?: string + step?: string + intent?: string + } + } + }).response?.data + const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent] + .map(value => value?.trim().toLowerCase()) + .filter((value): value is string => Boolean(value)) + + return states.includes('email_exists') || + states.includes('bind_login_required') || + states.includes('bind_login') || + states.includes('adopt_existing_user_by_email') +} + async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { if (getOAuthCompletionKind(completion) === 'bind') { const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') @@ -540,10 +562,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) { email: payload.email, password: payload.password, verify_code: payload.verifyCode || undefined, + invitation_code: payload.invitationCode || undefined, ...serializeAdoptionDecision(currentAdoptionDecision()) }) await finalizePendingAccountResponse(data) } catch (e: unknown) { + if (isCreateAccountRecoveryError(e)) { + switchToBindLoginMode(payload.email) + return + } accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) } finally { isSubmitting.value = false diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index 840e4964..019cab54 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -488,6 +488,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string { return err.response?.data?.detail || err.response?.data?.message || err.message || fallback } +function isCreateAccountRecoveryError(error: unknown): boolean { + const data = (error as { + response?: { + data?: { + reason?: string + error?: string + code?: string + step?: string + intent?: string + } + } + }).response?.data + const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent] + .map(value => value?.trim().toLowerCase()) + .filter((value): value is string => Boolean(value)) + + return states.includes('email_exists') || + states.includes('bind_login_required') || + states.includes('bind_login') || + states.includes('adopt_existing_user_by_email') +} + async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { if (getOAuthCompletionKind(completion) === 'bind') { const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') @@ -584,10 +606,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) { email: payload.email, password: payload.password, verify_code: payload.verifyCode || undefined, + invitation_code: payload.invitationCode || undefined, ...serializeAdoptionDecision(currentAdoptionDecision()) }) await finalizePendingAccountResponse(data) } catch (e: unknown) { + if (isCreateAccountRecoveryError(e)) { + switchToBindLoginMode(payload.email) + return + } accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) } finally { isSubmitting.value = false diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 35cd0032..36e3140c 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -647,6 +647,28 @@ function getRequestErrorMessage(error: unknown, fallback: string): string { return err.response?.data?.detail || err.response?.data?.message || err.message || fallback } +function isCreateAccountRecoveryError(error: unknown): boolean { + const data = (error as { + response?: { + data?: { + reason?: string + error?: string + code?: string + step?: string + intent?: string + } + } + }).response?.data + const states = [data?.reason, data?.error, data?.code, data?.step, data?.intent] + .map(value => value?.trim().toLowerCase()) + .filter((value): value is string => Boolean(value)) + + return states.includes('email_exists') || + states.includes('bind_login_required') || + states.includes('bind_login') || + states.includes('adopt_existing_user_by_email') +} + async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { if (getOAuthCompletionKind(completion) === 'bind') { const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') @@ -739,10 +761,15 @@ async function handleCreateAccount(payload: PendingOAuthCreateAccountPayload) { email: payload.email, password: payload.password, verify_code: payload.verifyCode || undefined, + invitation_code: payload.invitationCode || undefined, ...serializeAdoptionDecision(currentAdoptionDecision()) }) await finalizePendingAccountResponse(data) } catch (e: unknown) { + if (isCreateAccountRecoveryError(e)) { + switchToBindLoginMode(payload.email) + return + } accountActionError.value = getRequestErrorMessage(e, t('auth.loginFailed')) } finally { isSubmitting.value = false diff --git a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts index f6dff076..c231d6e7 100644 --- a/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts +++ b/frontend/src/views/auth/__tests__/EmailVerifyView.spec.ts @@ -11,6 +11,7 @@ const { clearPendingAuthSessionMock, getPublicSettingsMock, sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, persistOAuthTokenContextMock, apiClientPostMock, authStoreState, @@ -23,6 +24,7 @@ const { clearPendingAuthSessionMock: vi.fn(), getPublicSettingsMock: vi.fn(), sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), persistOAuthTokenContextMock: vi.fn(), apiClientPostMock: vi.fn(), authStoreState: { @@ -80,6 +82,7 @@ vi.mock('@/api/auth', async () => { ...actual, getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), persistOAuthTokenContext: (...args: any[]) => persistOAuthTokenContextMock(...args), } }) @@ -100,6 +103,7 @@ describe('EmailVerifyView', () => { clearPendingAuthSessionMock.mockReset() getPublicSettingsMock.mockReset() sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() persistOAuthTokenContextMock.mockReset() apiClientPostMock.mockReset() authStoreState.pendingAuthSession = null @@ -112,9 +116,86 @@ describe('EmailVerifyView', () => { registration_email_suffix_whitelist: [], }) sendVerifyCodeMock.mockResolvedValue({ countdown: 60 }) + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ countdown: 60 }) setTokenMock.mockResolvedValue({}) }) + it('uses the pending oauth verify-code endpoint when register data carries a pending auth session', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-1', + token_field: 'pending_auth_token', + provider: 'wechat', + redirect: '/profile', + } + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-1', + }) + expect(sendVerifyCodeMock).not.toHaveBeenCalled() + }) + + it('skips the registration email suffix whitelist for pending oauth verification', async () => { + authStoreState.pendingAuthSession = { + token: 'pending-token-2', + token_field: 'pending_auth_token', + provider: 'oidc', + redirect: '/profile', + } + getPublicSettingsMock.mockResolvedValue({ + turnstile_enabled: false, + turnstile_site_key: '', + site_name: 'Sub2API', + registration_email_suffix_whitelist: ['allowed.com'], + }) + sessionStorage.setItem( + 'register_data', + JSON.stringify({ + email: 'fresh@example.com', + password: 'secret-123', + }) + ) + + mount(EmailVerifyView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + TurnstileWidget: true, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ + email: 'fresh@example.com', + pending_auth_token: 'pending-token-2', + }) + expect(showErrorMock).not.toHaveBeenCalled() + }) + it('submits pending auth account creation when session storage has no pending metadata but auth store does', async () => { authStoreState.pendingAuthSession = { token: 'pending-token-1', diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index a04915b7..f612681a 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -15,6 +15,7 @@ const getPublicSettings = vi.fn() const login2FA = vi.fn() const apiClientPost = vi.fn() const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() vi.mock('vue-router', () => ({ useRoute: () => ({ @@ -61,7 +62,8 @@ vi.mock('@/api/auth', async () => { completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args), getPublicSettings: (...args: any[]) => getPublicSettings(...args), login2FA: (...args: any[]) => login2FA(...args), - sendVerifyCode: (...args: any[]) => sendVerifyCode(...args) + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) } }) @@ -79,6 +81,7 @@ describe('LinuxDoCallbackView', () => { login2FA.mockReset() apiClientPost.mockReset() sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() getPublicSettings.mockResolvedValue({ turnstile_enabled: false, turnstile_site_key: '' @@ -334,6 +337,11 @@ describe('LinuxDoCallbackView', () => { }) it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) exchangePendingOAuthCompletion.mockResolvedValue({ error: 'email_required', redirect: '/welcome', @@ -370,6 +378,7 @@ describe('LinuxDoCallbackView', () => { await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ') await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') await flushPromises() @@ -377,6 +386,7 @@ describe('LinuxDoCallbackView', () => { email: 'new@example.com', password: 'secret-123', verify_code: '246810', + invitation_code: 'INVITE123', adopt_display_name: true, adopt_avatar: false }) @@ -384,12 +394,48 @@ describe('LinuxDoCallbackView', () => { expect(replace).toHaveBeenCalledWith('/welcome') }) + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="linuxdo-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="linuxdo-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + it('sends a verify code for pending oauth account creation', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'email_required', redirect: '/welcome' }) - sendVerifyCode.mockResolvedValue({ + sendPendingOAuthVerifyCode.mockResolvedValue({ message: 'sent', countdown: 60 }) @@ -411,7 +457,7 @@ describe('LinuxDoCallbackView', () => { await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click') await flushPromises() - expect(sendVerifyCode).toHaveBeenCalledWith({ + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ email: 'new@example.com' }) }) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts index 259fb282..0edcb931 100644 --- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -15,6 +15,7 @@ const getPublicSettings = vi.fn() const login2FA = vi.fn() const apiClientPost = vi.fn() const sendVerifyCode = vi.fn() +const sendPendingOAuthVerifyCode = vi.fn() vi.mock('vue-router', () => ({ useRoute: () => ({ @@ -66,7 +67,8 @@ vi.mock('@/api/auth', async () => { completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), getPublicSettings: (...args: any[]) => getPublicSettings(...args), login2FA: (...args: any[]) => login2FA(...args), - sendVerifyCode: (...args: any[]) => sendVerifyCode(...args) + sendVerifyCode: (...args: any[]) => sendVerifyCode(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args) } }) @@ -84,6 +86,7 @@ describe('OidcCallbackView', () => { login2FA.mockReset() apiClientPost.mockReset() sendVerifyCode.mockReset() + sendPendingOAuthVerifyCode.mockReset() getPublicSettings.mockResolvedValue({ oidc_oauth_provider_name: 'ExampleID', turnstile_enabled: false, @@ -312,6 +315,12 @@ describe('OidcCallbackView', () => { }) it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID', + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '' + }) exchangePendingOAuthCompletion.mockResolvedValue({ error: 'email_required', redirect: '/welcome', @@ -348,6 +357,7 @@ describe('OidcCallbackView', () => { await wrapper.get('[data-testid="oidc-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="oidc-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="oidc-create-account-invitation-code"]').setValue(' INVITE123 ') await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') await flushPromises() @@ -355,6 +365,7 @@ describe('OidcCallbackView', () => { email: 'new@example.com', password: 'secret-123', verify_code: '246810', + invitation_code: 'INVITE123', adopt_display_name: true, adopt_avatar: false }) @@ -362,12 +373,48 @@ describe('OidcCallbackView', () => { expect(replace).toHaveBeenCalledWith('/welcome') }) + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome' + }) + apiClientPost.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists' + } + } + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + await wrapper.get('[data-testid="oidc-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="oidc-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="oidc-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="oidc-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + it('sends a verify code for pending oauth account creation', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'email_required', redirect: '/welcome' }) - sendVerifyCode.mockResolvedValue({ + sendPendingOAuthVerifyCode.mockResolvedValue({ message: 'sent', countdown: 60 }) @@ -389,7 +436,7 @@ describe('OidcCallbackView', () => { await wrapper.get('[data-testid="oidc-create-account-send-code"]').trigger('click') await flushPromises() - expect(sendVerifyCode).toHaveBeenCalledWith({ + expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({ email: 'new@example.com' }) }) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index fed88890..e02060f6 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -8,6 +8,8 @@ const { login2FAMock, apiClientPostMock, sendVerifyCodeMock, + sendPendingOAuthVerifyCodeMock, + getPublicSettingsMock, prepareOAuthBindAccessTokenCookieMock, getAuthTokenMock, replaceMock, @@ -24,6 +26,8 @@ const { login2FAMock: vi.fn(), apiClientPostMock: vi.fn(), sendVerifyCodeMock: vi.fn(), + sendPendingOAuthVerifyCodeMock: vi.fn(), + getPublicSettingsMock: vi.fn(), prepareOAuthBindAccessTokenCookieMock: vi.fn(), getAuthTokenMock: vi.fn(), replaceMock: vi.fn(), @@ -130,6 +134,8 @@ vi.mock('@/api/auth', async () => { completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), login2FA: (...args: any[]) => login2FAMock(...args), sendVerifyCode: (...args: any[]) => sendVerifyCodeMock(...args), + sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCodeMock(...args), + getPublicSettings: (...args: any[]) => getPublicSettingsMock(...args), prepareOAuthBindAccessTokenCookie: (...args: any[]) => prepareOAuthBindAccessTokenCookieMock(...args), getAuthToken: (...args: any[]) => getAuthTokenMock(...args), } @@ -142,6 +148,8 @@ describe('WechatCallbackView', () => { login2FAMock.mockReset() apiClientPostMock.mockReset() sendVerifyCodeMock.mockReset() + sendPendingOAuthVerifyCodeMock.mockReset() + getPublicSettingsMock.mockReset() replaceMock.mockReset() setTokenMock.mockReset() showSuccessMock.mockReset() @@ -167,6 +175,11 @@ describe('WechatCallbackView', () => { configurable: true, value: 'Mozilla/5.0', }) + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + }) }) it('overrides an incompatible query mode with the configured open capability during bind recovery', async () => { @@ -478,6 +491,11 @@ describe('WechatCallbackView', () => { }) it('collects email, password, and verify code for pending oauth account creation and submits adoption decisions', async () => { + getPublicSettingsMock.mockResolvedValue({ + invitation_code_enabled: true, + turnstile_enabled: false, + turnstile_site_key: '', + }) exchangePendingOAuthCompletionMock.mockResolvedValue({ error: 'email_required', redirect: '/welcome', @@ -514,6 +532,7 @@ describe('WechatCallbackView', () => { await wrapper.get('[data-testid="wechat-create-account-email"]').setValue(' new@example.com ') await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') await wrapper.get('[data-testid="wechat-create-account-verify-code"]').setValue('246810') + await wrapper.get('[data-testid="wechat-create-account-invitation-code"]').setValue(' INVITE123 ') await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') await flushPromises() @@ -521,6 +540,7 @@ describe('WechatCallbackView', () => { email: 'new@example.com', password: 'secret-123', verify_code: '246810', + invitation_code: 'INVITE123', adopt_display_name: true, adopt_avatar: false, }) @@ -528,12 +548,48 @@ describe('WechatCallbackView', () => { expect(replaceMock).toHaveBeenCalledWith('/welcome') }) + it('switches to bind-login when create-account returns EMAIL_EXISTS', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'email_required', + redirect: '/welcome', + }) + apiClientPostMock.mockRejectedValue({ + response: { + data: { + reason: 'EMAIL_EXISTS', + message: 'email already exists', + }, + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + await wrapper.get('[data-testid="wechat-create-account-email"]').setValue('existing@example.com') + await wrapper.get('[data-testid="wechat-create-account-password"]').setValue('secret-123') + await wrapper.get('[data-testid="wechat-create-account-submit"]').trigger('click') + await flushPromises() + + expect((wrapper.get('[data-testid="wechat-bind-login-email"]').element as HTMLInputElement).value).toBe( + 'existing@example.com' + ) + }) + it('sends a verify code for pending oauth account creation', async () => { exchangePendingOAuthCompletionMock.mockResolvedValue({ error: 'email_required', redirect: '/welcome', }) - sendVerifyCodeMock.mockResolvedValue({ + sendPendingOAuthVerifyCodeMock.mockResolvedValue({ message: 'sent', countdown: 60, }) @@ -555,7 +611,7 @@ describe('WechatCallbackView', () => { await wrapper.get('[data-testid="wechat-create-account-send-code"]').trigger('click') await flushPromises() - expect(sendVerifyCodeMock).toHaveBeenCalledWith({ + expect(sendPendingOAuthVerifyCodeMock).toHaveBeenCalledWith({ email: 'new@example.com', }) })