fix auth pending session hardening

This commit is contained in:
IanShaw027
2026-04-21 01:45:25 +08:00
parent 1d8432b8a4
commit 7c6491c2d3
12 changed files with 490 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
package handler
import (
"context"
"log/slog"
"strings"
@@ -105,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
if h == nil || !h.isBackendModeEnabled(ctx) {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
settings, err := h.settingSvc.GetPublicSettings(ctx)
if err == nil && settings != nil {
return settings.BackendModeEnabled
}
return h.settingSvc.IsBackendModeEnabled(ctx)
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -178,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
@@ -195,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
@@ -264,9 +294,8 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return
}
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
@@ -330,6 +359,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
if session.PendingOAuthBind == nil {
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
h.respondWithTokenPair(c, user)
}

View File

@@ -474,6 +474,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -499,6 +507,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)

View File

@@ -591,6 +591,58 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi
require.NotNil(t, consumed.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-invalid-subject-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("linuxdo-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)

View File

@@ -253,6 +253,35 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
if len(payload) == 0 {
return false
}
for _, key := range []string{"access_token", "refresh_token"} {
if value := pendingSessionStringValue(payload, key); value != "" {
return true
}
}
return false
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
if session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if strings.TrimSpace(session.Intent) != oauthIntentLogin {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload, _ := readCompletionResponse(session.LocalFlowState)
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
return nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
@@ -1090,6 +1119,10 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
@@ -1192,6 +1225,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
@@ -1215,6 +1252,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
@@ -1279,6 +1317,25 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingOAuthCompletionIncludesTokenPayload(payload) {
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
return
}
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {

View File

@@ -523,6 +523,60 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti
require.NotNil(t, storedSession.ConsumedAt)
}
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("blocked@example.com").
SetUsername("blocked-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("blocked-backend-mode-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("blocked-subject-123").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.ExchangePendingOAuthCompletion(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
@@ -773,6 +827,60 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
emailVerifyEnabled: true,
emailCache: &oauthPendingFlowEmailCacheStub{
verificationCodes: map[string]*service.VerificationCodeData{
"fresh@example.com": {
Code: "246810",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
},
},
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("create-account-backend-mode-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-create-backend-mode-123").
SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.CreateOIDCOAuthAccount(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
@@ -842,6 +950,70 @@ func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
require.NotNil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
passwordHash, err := handler.authService.HashPassword("secret-123")
require.NoError(t, err)
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("bind-login-backend-mode-session-token").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-bind-backend-mode-123").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.BindOIDCOAuthLogin(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()

View File

@@ -516,6 +516,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -541,6 +549,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)

View File

@@ -431,6 +431,58 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.
require.NotNil(t, consumed.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-invalid-subject-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("oidc-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct {
Subject string
PreferredUsername string

View File

@@ -506,6 +506,14 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
@@ -531,6 +539,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)

View File

@@ -846,6 +846,59 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
}
func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-invalid-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("wechat-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")

View File

@@ -104,7 +104,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {

View File

@@ -430,8 +430,6 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@@ -507,7 +505,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
@@ -527,8 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -634,7 +630,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
@@ -651,7 +647,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@@ -676,8 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)

View File

@@ -170,24 +170,26 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
require.NotNil(t, identity.VerifiedAt)
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true",
}, nil)
ctx := context.Background()
user := &service.User{
Email: "login@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 1,
Concurrency: 1,
}
require.NoError(t, user.SetPassword("password"))
require.NoError(t, repo.Create(ctx, user))
passwordHash, err := svc.HashPassword("password")
require.NoError(t, err)
user, err := client.User.Create().
SetEmail("login@example.com").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
SetBalance(1).
SetConcurrency(1).
Save(ctx)
require.NoError(t, err)
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err := client.User.UpdateOneID(user.ID).
_, err = client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
@@ -202,8 +204,20 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old))
require.True(t, storedUser.LastActiveAt.After(old))
require.True(t, storedUser.LastLoginAt.Equal(old))
require.True(t, storedUser.LastActiveAt.Equal(old))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("login@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
svc.RecordSuccessfulLogin(ctx, user.ID)
identity, err := client.AuthIdentity.Query().
Where(
@@ -273,6 +287,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -343,6 +358,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -380,6 +396,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
@@ -392,6 +409,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err)