diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 402a6cfc..3a21c69a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1214,6 +1214,67 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream } } +func pendingOAuthIdentityExistsForUser( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + userID int64, +) (bool, error) { + if client == nil || session == nil || userID <= 0 { + return false, nil + } + + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + if providerType == "" || providerSubject == "" { + return false, nil + } + + query := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderSubjectEQ(providerSubject), + authidentity.UserIDEQ(userID), + ) + if strings.EqualFold(providerType, "wechat") { + query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...)) + } else if providerKey != "" { + query = query.Where(authidentity.ProviderKeyEQ(providerKey)) + } + + count, err := query.Count(ctx) + if err != nil { + return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return count > 0, nil +} + +func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt( + ctx context.Context, + session *dbent.PendingAuthSession, + payload map[string]any, +) (bool, error) { + if session == nil || len(payload) == 0 { + return false, nil + } + if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) { + return false, nil + } + if !pendingOAuthCompletionIncludesTokenPayload(payload) { + return false, nil + } + if session.TargetUserID == nil || *session.TargetUserID <= 0 { + return false, nil + } + if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" && + pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" { + return false, nil + } + + return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID) +} + func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { secureCookie := isRequestHTTPS(c) clearCookies := func() { @@ -1634,6 +1695,15 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { } } applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + if skipAdoptionPrompt { + delete(payload, "adoption_required") + } if pendingOAuthCompletionIncludesTokenPayload(payload) { if session.TargetUserID == nil || *session.TargetUserID <= 0 { clearCookies() diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index bf16b48d..b1ac1c4b 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -599,6 +599,86 @@ func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t require.NotNil(t, storedSession.ConsumedAt) } +func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("existing-login@example.com"). + SetUsername("existing-login-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(userEntity.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetMetadata(map[string]any{ + "username": "existing-login-user", + }). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-login-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("existing-login-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("existing-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "suggested_display_name": "Existing Login Example", + "suggested_avatar_url": "https://cdn.example/existing-login.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": float64(3600), + "token_type": "Bearer", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + payload := decodeJSONResponseData(t, recorder) + require.Equal(t, "access-token", payload["access_token"]) + require.Equal(t, "refresh-token", payload["refresh_token"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, "Existing Login Example", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"]) + require.NotContains(t, payload, "adoption_required") + + decisionCount, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Count(ctx) + require.NoError(t, err) + require.Equal(t, 1, decisionCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) { handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ settingValues: map[string]string{