Fix repeated OAuth adoption prompt for existing logins
This commit is contained in:
@@ -1214,6 +1214,67 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func pendingOAuthIdentityExistsForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
client *dbent.Client,
|
||||||
|
session *dbent.PendingAuthSession,
|
||||||
|
userID int64,
|
||||||
|
) (bool, error) {
|
||||||
|
if client == nil || session == nil || userID <= 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := strings.TrimSpace(session.ProviderType)
|
||||||
|
providerKey := strings.TrimSpace(session.ProviderKey)
|
||||||
|
providerSubject := strings.TrimSpace(session.ProviderSubject)
|
||||||
|
if providerType == "" || providerSubject == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
|
authidentity.UserIDEQ(userID),
|
||||||
|
)
|
||||||
|
if strings.EqualFold(providerType, "wechat") {
|
||||||
|
query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...))
|
||||||
|
} else if providerKey != "" {
|
||||||
|
query = query.Where(authidentity.ProviderKeyEQ(providerKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
|
||||||
|
ctx context.Context,
|
||||||
|
session *dbent.PendingAuthSession,
|
||||||
|
payload map[string]any,
|
||||||
|
) (bool, error) {
|
||||||
|
if session == nil || len(payload) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !pendingOAuthCompletionIncludesTokenPayload(payload) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
|
||||||
|
pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID)
|
||||||
|
}
|
||||||
|
|
||||||
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
|
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
|
||||||
secureCookie := isRequestHTTPS(c)
|
secureCookie := isRequestHTTPS(c)
|
||||||
clearCookies := func() {
|
clearCookies := func() {
|
||||||
@@ -1634,6 +1695,15 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
|
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
|
||||||
|
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
|
||||||
|
if err != nil {
|
||||||
|
clearCookies()
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if skipAdoptionPrompt {
|
||||||
|
delete(payload, "adoption_required")
|
||||||
|
}
|
||||||
if pendingOAuthCompletionIncludesTokenPayload(payload) {
|
if pendingOAuthCompletionIncludesTokenPayload(payload) {
|
||||||
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
|
||||||
clearCookies()
|
clearCookies()
|
||||||
|
|||||||
@@ -599,6 +599,86 @@ func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t
|
|||||||
require.NotNil(t, storedSession.ConsumedAt)
|
require.NotNil(t, storedSession.ConsumedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) {
|
||||||
|
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
userEntity, err := client.User.Create().
|
||||||
|
SetEmail("existing-login@example.com").
|
||||||
|
SetUsername("existing-login-user").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(userEntity.ID).
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("existing-login-123").
|
||||||
|
SetMetadata(map[string]any{
|
||||||
|
"username": "existing-login-user",
|
||||||
|
}).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session, err := client.PendingAuthSession.Create().
|
||||||
|
SetSessionToken("existing-login-session-token").
|
||||||
|
SetIntent("login").
|
||||||
|
SetProviderType("linuxdo").
|
||||||
|
SetProviderKey("linuxdo").
|
||||||
|
SetProviderSubject("existing-login-123").
|
||||||
|
SetTargetUserID(userEntity.ID).
|
||||||
|
SetResolvedEmail(userEntity.Email).
|
||||||
|
SetBrowserSessionKey("existing-login-browser-session-key").
|
||||||
|
SetUpstreamIdentityClaims(map[string]any{
|
||||||
|
"suggested_display_name": "Existing Login Example",
|
||||||
|
"suggested_avatar_url": "https://cdn.example/existing-login.png",
|
||||||
|
}).
|
||||||
|
SetLocalFlowState(map[string]any{
|
||||||
|
oauthCompletionResponseKey: map[string]any{
|
||||||
|
"access_token": "access-token",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
"expires_in": float64(3600),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"redirect": "/dashboard",
|
||||||
|
},
|
||||||
|
}).
|
||||||
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||||
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")})
|
||||||
|
ginCtx.Request = req
|
||||||
|
|
||||||
|
handler.ExchangePendingOAuthCompletion(ginCtx)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
payload := decodeJSONResponseData(t, recorder)
|
||||||
|
require.Equal(t, "access-token", payload["access_token"])
|
||||||
|
require.Equal(t, "refresh-token", payload["refresh_token"])
|
||||||
|
require.Equal(t, "/dashboard", payload["redirect"])
|
||||||
|
require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
|
||||||
|
require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
|
||||||
|
require.NotContains(t, payload, "adoption_required")
|
||||||
|
|
||||||
|
decisionCount, err := client.IdentityAdoptionDecision.Query().
|
||||||
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, decisionCount)
|
||||||
|
|
||||||
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, storedSession.ConsumedAt)
|
||||||
|
}
|
||||||
|
|
||||||
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
|
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
|
||||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||||
settingValues: map[string]string{
|
settingValues: map[string]string{
|
||||||
|
|||||||
Reference in New Issue
Block a user