diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 76ca153d..9801b3b3 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,6 +1,7 @@ package handler import ( + "context" "log/slog" "strings" @@ -105,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { }) } +func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error { + if user == nil { + return infraerrors.Unauthorized("INVALID_USER", "user not found") + } + if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.") +} + +func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error { + if h == nil || !h.isBackendModeEnabled(ctx) { + return nil + } + return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.") +} + +func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + settings, err := h.settingSvc.GetPublicSettings(ctx) + if err == nil && settings != nil { + return settings.BackendModeEnabled + } + return h.settingSvc.IsBackendModeEnabled(ctx) +} + // Register handles user registration // POST /api/v1/auth/register func (h *AuthHandler) Register(c *gin.Context) { @@ -178,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) { } _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } + // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { // Create a temporary login session for 2FA @@ -195,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - // Backend mode: only admin can login - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") - return - } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) h.respondWithTokenPair(c, user) } @@ -264,9 +294,8 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Backend mode: only admin can login (check BEFORE deleting session) - if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { - response.Forbidden(c, "Backend mode is active. Only admin login is allowed.") + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) return } @@ -330,6 +359,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { // Delete the login session (only after all checks pass) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + if session.PendingOAuthBind == nil { + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + } + h.respondWithTokenPair(c, user) } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 175b1e1f..c3a9041b 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -474,6 +474,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } email := strings.TrimSpace(session.ResolvedEmail) username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") @@ -499,6 +507,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index fb57e570..7938f3e7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -591,6 +591,58 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi require.NotNil(t, consumed.ConsumedAt) } +func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("linuxdo-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { t.Helper() handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 461810f1..6041e5dd 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -253,6 +253,35 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } +func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool { + if len(payload) == 0 { + return false + } + for _, key := range []string{"access_token", "refresh_token"} { + if value := pendingSessionStringValue(payload, key); value != "" { + return true + } + } + return false +} + +func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error { + if session == nil { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if strings.TrimSpace(session.Intent) != oauthIntentLogin { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + payload, _ := readCompletionResponse(session.LocalFlowState) + if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") { + return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid") + } + return nil +} + func (r oauthAdoptionDecisionRequest) hasDecision() bool { return r.AdoptDisplayName != nil || r.AdoptAvatar != nil } @@ -1090,6 +1119,10 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) return } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + response.ErrorFrom(c, err) + return + } decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) if err != nil { @@ -1192,6 +1225,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) return } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( c.Request.Context(), @@ -1215,6 +1252,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { clearCookies() @@ -1279,6 +1317,25 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { } } applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + if pendingOAuthCompletionIncludesTokenPayload(payload) { + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + } if pendingSessionWantsInvitation(payload) { if adoptionDecision.hasDecision() { diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index d29e4b88..c2b83c73 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -523,6 +523,60 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti require.NotNil(t, storedSession.ConsumedAt) } +func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("blocked@example.com"). + SetUsername("blocked-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("blocked-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("blocked-subject-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("blocked-backend-mode-browser-session-key"). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, true) ctx := context.Background() @@ -773,6 +827,60 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } +func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + emailVerifyEnabled: true, + emailCache: &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + "fresh@example.com": { + Code: "246810", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + }, + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-backend-mode-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-backend-mode-123"). + SetBrowserSessionKey("create-account-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -842,6 +950,70 @@ func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { require.NotNil(t, storedSession.ConsumedAt) } +func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }, + }) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-backend-mode-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-backend-mode-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-backend-mode-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusForbidden, recorder.Code) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0f9f1895..5901a953 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -516,6 +516,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } email := strings.TrimSpace(session.ResolvedEmail) username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") @@ -541,6 +549,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 07f5ef68..ba736db2 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -431,6 +431,58 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing. require.NotNil(t, consumed.ConsumedAt) } +func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-invalid-subject-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("oidc-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + type oidcProviderFixture struct { Subject string PreferredUsername string diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 45de30a8..5e697fb5 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -506,6 +506,14 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } email := strings.TrimSpace(session.ResolvedEmail) username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") @@ -531,6 +539,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index c65f4cd1..cd34f52f 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -846,6 +846,59 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { require.Equal(t, repairedIdentity.ID, channel.IdentityID) } +func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("wechat-complete-invalid-session"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("wechat"). + SetProviderKey("wechat-main"). + SetProviderSubject("union-invalid-1"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("wechat-invalid-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "wechat_user", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "step": "bind_login_required", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) + recorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")}) + completeCtx.Request = req + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index 4ab4e245..2e0107ae 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -104,7 +104,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, ErrServiceUnavailable } - s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 00fefd82..d63a8753 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -430,8 +430,6 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } - s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) - s.touchUserLogin(ctx, user.ID) // 生成JWT token token, err := s.GenerateToken(user) @@ -507,7 +505,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { @@ -527,8 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - s.touchUserLogin(ctx, user.ID) - token, err := s.GenerateToken(user) if err != nil { return "", nil, fmt.Errorf("generate token: %w", err) @@ -634,7 +630,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { @@ -651,7 +647,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.postAuthUserBootstrap(ctx, user, signupSource, false) s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -676,8 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } - s.touchUserLogin(ctx, user.ID) - tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { return nil, nil, fmt.Errorf("generate token pair: %w", err) diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 95c9c933..85c13604 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -170,24 +170,26 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { require.NotNil(t, identity.VerifiedAt) } -func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { - svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ +func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t, map[string]string{ service.SettingKeyRegistrationEnabled: "true", }, nil) ctx := context.Background() - user := &service.User{ - Email: "login@example.com", - Role: service.RoleUser, - Status: service.StatusActive, - Balance: 1, - Concurrency: 1, - } - require.NoError(t, user.SetPassword("password")) - require.NoError(t, repo.Create(ctx, user)) + passwordHash, err := svc.HashPassword("password") + require.NoError(t, err) + user, err := client.User.Create(). + SetEmail("login@example.com"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + SetBalance(1). + SetConcurrency(1). + Save(ctx) + require.NoError(t, err) old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) - _, err := client.User.UpdateOneID(user.ID). + _, err = client.User.UpdateOneID(user.ID). SetLastLoginAt(old). SetLastActiveAt(old). Save(ctx) @@ -202,8 +204,20 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { require.NoError(t, err) require.NotNil(t, storedUser.LastLoginAt) require.NotNil(t, storedUser.LastActiveAt) - require.True(t, storedUser.LastLoginAt.After(old)) - require.True(t, storedUser.LastActiveAt.After(old)) + require.True(t, storedUser.LastLoginAt.Equal(old)) + require.True(t, storedUser.LastActiveAt.Equal(old)) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + svc.RecordSuccessfulLogin(ctx, user.ID) identity, err := client.AuthIdentity.Query(). Where( @@ -273,6 +287,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe require.NoError(t, err) require.NotEmpty(t, token) require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID) require.NoError(t, err) @@ -343,6 +358,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE require.NoError(t, err) require.NotEmpty(t, token) require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID) require.NoError(t, err) @@ -380,6 +396,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t require.NoError(t, err) require.NotEmpty(t, token) require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID) require.NoError(t, err) @@ -392,6 +409,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t require.NoError(t, err) require.NotEmpty(t, token) require.NotNil(t, gotUser) + svc.RecordSuccessfulLogin(ctx, user.ID) storedUser, err = client.User.Get(ctx, user.ID) require.NoError(t, err)