From fbd0a2e3c488720025a3408e9db234407d8aef9b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 16:27:23 +0800 Subject: [PATCH] feat: carry suggested third-party profile through pending oauth --- .../internal/handler/auth_linuxdo_oauth.go | 201 +++++++++---- .../handler/auth_linuxdo_oauth_test.go | 12 +- .../handler/auth_oauth_pending_flow.go | 263 ++++++++++++++++++ .../handler/auth_oauth_pending_flow_test.go | 40 +++ backend/internal/handler/auth_oidc_oauth.go | 47 +++- .../internal/handler/auth_oidc_oauth_test.go | 20 ++ frontend/src/api/auth.ts | 24 +- 7 files changed, 534 insertions(+), 73 deletions(-) create mode 100644 backend/internal/handler/auth_oauth_pending_flow.go create mode 100644 backend/internal/handler/auth_oauth_pending_flow_test.go diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0c7c2da7..2f182642 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -87,20 +87,25 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) - codeChallenge := "" - if cfg.UsePKCE { - verifier, err := oauth.GenerateCodeVerifier() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) - return - } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return } + codeChallenge := oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -161,14 +166,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } - codeVerifier := "" - if cfg.UsePKCE { - codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return - } + codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -198,7 +205,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") @@ -215,16 +222,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "error": "invitation_required", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectToFrontendCallback(c, frontendCallback) return } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 @@ -232,18 +255,39 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) } type completeLinuxDoOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -256,9 +300,38 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) return } @@ -267,6 +340,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -303,9 +384,7 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { - form.Set("code_verifier", codeVerifier) - } + form.Set("code_verifier", codeVerifier) r := client.R(). SetContext(ctx). @@ -353,11 +432,11 @@ func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, -) (email string, username string, subject string, err error) { +) (email string, username string, subject string, displayName string, avatarURL string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { - return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). @@ -366,16 +445,16 @@ func linuxDoFetchUserInfo( SetHeader("Authorization", authorization). Get(cfg.UserInfoURL) if err != nil { - return "", "", "", fmt.Errorf("request userinfo: %w", err) + return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { - return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } -func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), @@ -400,12 +479,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s getGJSON(body, "user.id"), ) + displayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "user.name"), + getGJSON(body, "user.username"), + username, + ) + avatarURL = firstNonEmpty( + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "picture"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) + subject = strings.TrimSpace(subject) if subject == "" { - return "", "", "", errors.New("userinfo missing id field") + return "", "", "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { - return "", "", "", errors.New("userinfo returned invalid id field") + return "", "", "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) @@ -418,8 +514,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s if username == "" { username = "linuxdo_" + subject } + displayName = strings.TrimSpace(displayName) + if displayName == "" { + displayName = username + } + avatarURL = strings.TrimSpace(avatarURL) - return email, username, subject, nil + return email, username, subject, displayName, avatarURL, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { @@ -436,10 +537,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - if cfg.UsePKCE { - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - } + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") u.RawQuery = q.Encode() return u.String(), nil diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index ff169c52..90bc10d1 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -41,11 +41,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "alice", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "Alice", displayName) + require.Equal(t, "https://cdn.example/avatar.png", avatarURL) } func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { @@ -53,11 +55,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "linuxdo_123", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "linuxdo_123", displayName) + require.Equal(t, "", avatarURL) } func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { @@ -65,11 +69,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) require.Error(t, err) tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) - _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) require.Error(t, err) } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go new file mode 100644 index 00000000..a758c0b9 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -0,0 +1,263 @@ +package handler + +import ( + "net/http" + "net/url" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" + oauthPendingBrowserCookieName = "oauth_pending_browser_session" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending" + oauthPendingSessionCookieName = "oauth_pending_session" + oauthPendingCookieMaxAgeSec = 10 * 60 + + oauthCompletionResponseKey = "completion_response" +) + +type oauthPendingSessionPayload struct { + Intent string + Identity service.PendingAuthIdentityKey + ResolvedEmail string + RedirectTo string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + CompletionResponse map[string]any +} + +func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { + if h == nil || h.authService == nil || h.authService.EntClient() == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil +} + +func generateOAuthPendingBrowserSession() (string, error) { + return oauth.GenerateState() +} + +func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: encodeCookieValue(sessionKey), + Path: oauthPendingBrowserCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: "", + Path: oauthPendingBrowserCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingBrowserCookieName) +} + +func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: encodeCookieValue(sessionToken), + Path: oauthPendingSessionCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: "", + Path: oauthPendingSessionCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingSessionCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingSessionCookieName) +} + +func redirectToFrontendCallback(c *gin.Context, frontendCallback string) { + u, err := url.Parse(frontendCallback) + if err != nil { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = "" + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error { + svc, err := h.pendingIdentityService() + if err != nil { + return err + } + + session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ + Intent: strings.TrimSpace(payload.Intent), + Identity: payload.Identity, + ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), + RedirectTo: strings.TrimSpace(payload.RedirectTo), + BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), + UpstreamIdentityClaims: payload.UpstreamIdentityClaims, + LocalFlowState: map[string]any{ + oauthCompletionResponseKey: payload.CompletionResponse, + }, + }) + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + + setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c)) + return nil +} + +func readCompletionResponse(session map[string]any) (map[string]any, bool) { + if len(session) == 0 { + return nil, false + } + value, ok := session[oauthCompletionResponseKey] + if !ok { + return nil, false + } + result, ok := value.(map[string]any) + if !ok { + return nil, false + } + return result, true +} + +func pendingSessionStringValue(values map[string]any, key string) string { + if len(values) == 0 { + return "" + } + raw, ok := values[key] + if !ok { + return "" + } + value, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(value) +} + +func pendingSessionWantsInvitation(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") +} + +func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { + if len(payload) == 0 || len(upstream) == 0 { + return + } + + displayName := pendingSessionStringValue(upstream, "suggested_display_name") + avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url") + + if displayName != "" { + if _, exists := payload["suggested_display_name"]; !exists { + payload["suggested_display_name"] = displayName + } + } + if avatarURL != "" { + if _, exists := payload["suggested_avatar_url"]; !exists { + payload["suggested_avatar_url"] = avatarURL + } + } + if displayName != "" || avatarURL != "" { + payload["adoption_required"] = true + } +} + +// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. +// POST /api/v1/auth/oauth/pending/exchange +func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + payload, ok := readCompletionResponse(session.LocalFlowState) + if !ok { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := payload["redirect"]; !exists { + payload["redirect"] = session.RedirectTo + } + } + applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + if pendingSessionWantsInvitation(payload) { + response.Success(c, payload) + return + } + + if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + response.Success(c, payload) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go new file mode 100644 index 00000000..5517bae2 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -0,0 +1,40 @@ +package handler + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { + payload := map[string]any{ + "access_token": "token", + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Alice", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) { + payload := map[string]any{ + "suggested_display_name": "Existing", + "adoption_required": false, + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Existing", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 37ef6833..e3694c8f 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -87,6 +87,8 @@ type oidcUserInfoClaims struct { Username string Subject string EmailVerified *bool + DisplayName string + AvatarURL string } type oidcJWKSet struct { @@ -338,12 +340,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "error": "invitation_required", @@ -371,12 +375,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, @@ -643,9 +649,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC if verified, ok := getGJSONBool(body, "email_verified"); ok { claims.EmailVerified = &verified } + claims.DisplayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + ) + claims.AvatarURL = firstNonEmpty( + getGJSON(body, "picture"), + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) claims.Email = strings.TrimSpace(claims.Email) claims.Username = strings.TrimSpace(claims.Username) claims.Subject = strings.TrimSpace(claims.Subject) + claims.DisplayName = strings.TrimSpace(claims.DisplayName) + claims.AvatarURL = strings.TrimSpace(claims.AvatarURL) return claims } diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a4cf776a..c389db51 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -91,6 +91,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) { require.Error(t, err) } +func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) { + cfg := config.OIDCConnectConfig{} + + claims := oidcParseUserInfo(`{ + "sub":"subject-1", + "preferred_username":"alice", + "name":"Alice Example", + "picture":"https://cdn.example/avatar.png", + "email":"alice@example.com", + "email_verified":true + }`, cfg) + + require.Equal(t, "subject-1", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Example", claims.DisplayName) + require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL) + require.NotNil(t, claims.EmailVerified) + require.True(t, *claims.EmailVerified) +} + func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 837c4f4c..d7abcd6a 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,6 +186,18 @@ export interface RefreshTokenResponse { token_type: string } +export interface PendingOAuthExchangeResponse { + access_token?: string + refresh_token?: string + expires_in?: number + token_type?: string + redirect?: string + error?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -337,12 +349,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ @@ -351,7 +361,6 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data @@ -359,12 +368,10 @@ export async function completeLinuxDoOAuthRegistration( /** * Complete OIDC OAuth registration by supplying an invitation code - * @param pendingOAuthToken - Short-lived JWT from the OAuth callback * @param invitationCode - Invitation code entered by the user * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - pendingOAuthToken: string, invitationCode: string ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ @@ -373,12 +380,16 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data } +export async function exchangePendingOAuthCompletion(): Promise { + const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) + return data +} + export const authAPI = { login, login2FA, @@ -402,6 +413,7 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, + exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration }