fix(auth): harden pending oauth and backend mode flows

This commit is contained in:
IanShaw027
2026-04-22 12:30:00 +08:00
parent 1ffebbb568
commit 767f2f2dfe
12 changed files with 568 additions and 44 deletions

View File

@@ -678,6 +678,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程 // 不影响登出流程
} }
} }
h.consumePendingOAuthSessionOnLogout(c)
clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{ response.Success(c, LogoutResponse{
Message: "Logged out successfully", Message: "Logged out successfully",

View File

@@ -469,6 +469,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -757,6 +757,61 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-choice-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-choice-subject-1").
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper() t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)

View File

@@ -0,0 +1,68 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("logout-pending-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("logout-subject-123").
SetBrowserSessionKey("logout-browser-session-key").
SetResolvedEmail("logout@example.com").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
cookies := recorder.Result().Cookies()
for _, name := range []string{
oauthPendingSessionCookieName,
oauthPendingBrowserCookieName,
oauthBindAccessTokenCookieName,
linuxDoOAuthStateCookieName,
oidcOAuthStateCookieName,
wechatOAuthStateCookieName,
wechatPaymentOAuthStateName,
} {
cookie := findCookie(cookies, name)
require.NotNil(t, cookie, name)
require.Equal(t, -1, cookie.MaxAge, name)
require.True(t, cookie.HttpOnly, name)
}
storedSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}

View File

@@ -310,6 +310,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return nil return nil
} }
func buildLegacyCompleteRegistrationPendingResponse(
session *dbent.PendingAuthSession,
forceEmailOnSignup bool,
emailVerificationRequired bool,
) map[string]any {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
}))
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
if _, exists := completionResponse["email"]; !exists {
completionResponse["email"] = email
}
if _, exists := completionResponse["resolved_email"]; !exists {
completionResponse["resolved_email"] = email
}
}
if _, exists := completionResponse["choice_reason"]; !exists {
switch {
case forceEmailOnSignup:
completionResponse["choice_reason"] = "force_email_on_signup"
case emailVerificationRequired:
completionResponse["choice_reason"] = "email_verification_required"
default:
completionResponse["choice_reason"] = "third_party_signup"
}
}
return completionResponse
}
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
c *gin.Context,
session *dbent.PendingAuthSession,
) (*dbent.PendingAuthSession, bool, error) {
if session == nil {
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
if step := pendingSessionStringValue(payload, "step"); step != "" {
return session, true, nil
}
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
if !emailVerificationRequired && !forceEmailOnSignup {
return session, false, nil
}
client := h.entClient()
if client == nil {
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
updatedSession, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
strings.TrimSpace(session.ResolvedEmail),
nil,
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
)
if err != nil {
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return updatedSession, true, nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool { func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
} }
@@ -1272,6 +1344,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return svc, session, clearCookies, nil return svc, session, clearCookies, nil
} }
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
if c == nil || c.Request == nil {
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
return
}
svc, err := h.pendingIdentityService()
if err != nil {
return
}
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
}
func clearOAuthLogoutCookies(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
clearOAuthBindAccessTokenCookie(c, secureCookie)
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil)) completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
payload := gin.H{ payload := gin.H{
@@ -1451,6 +1576,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch") response.BadRequest(c, "Pending oauth session provider mismatch")
return return

View File

@@ -1228,6 +1228,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
handler, _ := newOAuthPendingFlowTestHandler(t, false)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
}
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) { func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810") handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
ctx := context.Background() ctx := context.Background()

View File

@@ -374,19 +374,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
ProviderSubject: subject, ProviderSubject: subject,
} }
upstreamClaims := map[string]any{ upstreamClaims := map[string]any{
"email": email, "email": email,
"username": username, "username": username,
"subject": subject, "subject": subject,
"issuer": issuer, "issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified, "email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName), "provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string { "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
if idClaims != nil { if idClaims != nil {
return idClaims.Name return idClaims.Name
} }
return "" return ""
}(), username), }(), username),
"suggested_avatar_url": userInfoClaims.AvatarURL, "suggested_avatar_url": userInfoClaims.AvatarURL,
} }
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
upstreamClaims["compat_email"] = compatEmail upstreamClaims["compat_email"] = compatEmail
@@ -622,6 +622,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -692,6 +692,62 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-choice-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-choice-subject-1").
SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
SetBrowserSessionKey("oidc-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct { type oidcProviderFixture struct {
Subject string Subject string
PreferredUsername string PreferredUsername string

View File

@@ -525,6 +525,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
@@ -700,7 +699,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require.Zero(t, count) require.Zero(t, count)
} }
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() { t.Cleanup(func() {
@@ -773,27 +772,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, http.StatusOK, completeRecorder.Code) require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder) responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"]) require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["adoption_required"])
require.Empty(t, responseData["access_token"])
userEntity, err := client.User.Query(). consumed, err := client.PendingAuthSession.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx) Only(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username) require.Nil(t, consumed.ConsumedAt)
identity, err := client.AuthIdentity.Query(). userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
identityCount, err := client.AuthIdentity.Query().
Where( Where(
authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"), authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"), authidentity.ProviderSubjectEQ("union-456"),
). ).
Only(ctx) Count(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID) require.Zero(t, identityCount)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
channel, err := client.AuthIdentityChannel.Query(). channelCount, err := client.AuthIdentityChannel.Query().
Where( Where(
authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"), authidentitychannel.ProviderKeyEQ("wechat-main"),
@@ -801,25 +805,15 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"), authidentitychannel.ChannelSubjectEQ("openid-123"),
). ).
Only(ctx) Count(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID) require.Zero(t, channelCount)
require.Equal(t, "union-456", channel.Metadata["unionid"])
decision, err := client.IdentityAdoptionDecision.Query(). decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Only(ctx) Count(ctx)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, decision.IdentityID) require.Zero(t, decisionCount)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
} }
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
@@ -981,6 +975,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
} }
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-choice-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("wechat-choice-subject-1").
SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
SetBrowserSessionKey("wechat-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL originalUserInfoURL := wechatOAuthUserInfoURL

View File

@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
} }
} }
func backendModeAllowsAuthPath(path string) bool {
path = strings.ToLower(strings.TrimSpace(path))
for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
if strings.HasSuffix(path, suffix) {
return true
}
}
for _, suffix := range []string{
"/auth/oauth/linuxdo/callback",
"/auth/oauth/wechat/callback",
"/auth/oauth/wechat/payment/callback",
"/auth/oauth/oidc/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
"/auth/oauth/linuxdo/create-account",
"/auth/oauth/wechat/create-account",
"/auth/oauth/oidc/create-account",
"/auth/oauth/linuxdo/bind-login",
"/auth/oauth/wechat/bind-login",
"/auth/oauth/oidc/bind-login",
} {
if strings.HasSuffix(path, suffix) {
return true
}
}
return strings.Contains(path, "/auth/oauth/pending/")
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled. // BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these). // Allows the minimal auth surface admins still need in backend mode, including
// Blocks: register, forgot-password, reset-password, OAuth, etc. // OAuth callbacks and pending continuations. Handler-level backend mode checks
// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc { func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) { if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next() c.Next()
return return
} }
path := c.Request.URL.Path if backendModeAllowsAuthPath(c.Request.URL.Path) {
// Allow login, 2FA, logout, refresh, public settings c.Next()
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} return
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
} }
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.") response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort() c.Abort()

View File

@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh", path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
}, },
{
name: "enabled_blocks_linuxdo_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_linuxdo_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_payment_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_payment_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_oidc_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_oidc_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",
path: "/api/v1/auth/oauth/pending/exchange",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_send_verify_code",
enabled: "true",
path: "/api/v1/auth/oauth/pending/send-verify-code",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/pending/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/pending/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_legacy_complete_registration",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/complete-registration",
wantStatus: http.StatusOK,
},
{ {
name: "enabled_blocks_register", name: "enabled_blocks_register",
enabled: "true", enabled: "true",