fix(auth): scrub legacy pending oauth tokens on upgrade

This commit is contained in:
IanShaw027
2026-04-22 11:29:05 +08:00
parent 9d5e9bbc18
commit be9df2bea7
6 changed files with 123 additions and 0 deletions

View File

@@ -1290,6 +1290,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
normalized := clonePendingMap(payload)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":

View File

@@ -851,6 +851,22 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require.Nil(t, storedSession.ConsumedAt)
}
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
payload := normalizePendingOAuthCompletionResponse(map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
})
require.NotContains(t, payload, "access_token")
require.NotContains(t, payload, "refresh_token")
require.NotContains(t, payload, "expires_in")
require.NotContains(t, payload, "token_type")
require.Equal(t, "/dashboard", payload["redirect"])
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()

View File

@@ -236,6 +236,7 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, err
}
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
now := time.Now().UTC()
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
Where(
@@ -247,6 +248,7 @@ func (s *AuthPendingIdentityService) consumeSession(
),
).
SetConsumedAt(now).
SetLocalFlowState(sanitizedLocalFlowState).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt()
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
@@ -273,6 +275,29 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, consumedErr
}
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
sanitized := copyPendingMap(localFlowState)
if len(sanitized) == 0 {
return sanitized
}
rawCompletion, ok := sanitized["completion_response"]
if !ok {
return sanitized
}
completion, ok := rawCompletion.(map[string]any)
if !ok {
return sanitized
}
cleanedCompletion := copyPendingMap(completion)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(cleanedCompletion, key)
}
sanitized["completion_response"] = cleanedCompletion
return sanitized
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil {
return ErrPendingAuthSessionNotFound