fix(auth): harden pending oauth session consumption

This commit is contained in:
IanShaw027
2026-04-22 11:17:38 +08:00
parent 84628108fc
commit ca1f30a911
4 changed files with 146 additions and 38 deletions

View File

@@ -277,6 +277,22 @@ func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
return false
}
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
return false
}
if pendingSessionWantsInvitation(payload) {
return false
}
return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
if session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
@@ -1212,13 +1228,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
if session == nil || len(payload) == 0 {
return false, nil
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false, nil
}
if !pendingOAuthCompletionIncludesTokenPayload(payload) {
return false, nil
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
return false, nil
}
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
@@ -1649,6 +1659,22 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
var loginUser *service.User
if canIssueTokenPair {
loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
}
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
if err != nil {
clearCookies()
@@ -1658,25 +1684,6 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
if skipAdoptionPrompt {
delete(payload, "adoption_required")
}
if pendingOAuthCompletionIncludesTokenPayload(payload) {
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
return
}
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
@@ -1724,6 +1731,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
}
if canIssueTokenPair {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
if err != nil {
clearCookies()
response.InternalError(c, "Failed to generate token pair")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
payload["access_token"] = tokenPair.AccessToken
payload["refresh_token"] = tokenPair.RefreshToken
payload["expires_in"] = tokenPair.ExpiresIn
payload["token_type"] = "Bearer"
}
clearCookies()
response.Success(c, payload)
}

View File

@@ -746,11 +746,7 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
@@ -769,8 +765,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
require.Equal(t, http.StatusOK, recorder.Code)
payload := decodeJSONResponseData(t, recorder)
require.Equal(t, "access-token", payload["access_token"])
require.Equal(t, "refresh-token", payload["refresh_token"])
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
require.Equal(t, "/dashboard", payload["redirect"])
require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])